Code Example #1
        public override void Train(TaggedWord tw, int loc, double weight)
        {
            if (useGT)
            {
                unknownGTTrainer.Train(tw, weight);
            }
            // scan data
            string word      = tw.Word();
            string subString = model.GetSignature(word, loc);
            ILabel tag       = new Tag(tw.Tag());

            if (!c.Contains(tag))
            {
                c[tag] = new ClassicCounter <string>();
            }
            c[tag].IncrementCount(subString, weight);
            tc.IncrementCount(tag, weight);
            seenEnd.Add(subString);
            string        tagStr = tw.Tag();
            IntTaggedWord iW     = new IntTaggedWord(word, IntTaggedWord.Any, wordIndex, tagIndex);

            seenCounter.IncrementCount(iW, weight);
            if (treesRead > indexToStartUnkCounting)
            {
                // start doing this once some way through trees;
                // treesRead is 1 based counting
                if (seenCounter.GetCount(iW) < 2)
                {
                    IntTaggedWord iT = new IntTaggedWord(IntTaggedWord.Any, tagStr, wordIndex, tagIndex);
                    unSeenCounter.IncrementCount(iT, weight);
                    unSeenCounter.IncrementCount(UnknownWordModelTrainerConstants.NullItw, weight);
                }
            }
        }
Code Example #2
        public static void Main(string[] args)
        {
            Edu.Stanford.Nlp.Classify.RVFDataset <string, string> data = new Edu.Stanford.Nlp.Classify.RVFDataset <string, string>();
            ClassicCounter <string> c1 = new ClassicCounter <string>();

            c1.IncrementCount("fever", 3.5);
            c1.IncrementCount("cough", 1.1);
            c1.IncrementCount("congestion", 4.2);
            ClassicCounter <string> c2 = new ClassicCounter <string>();

            c2.IncrementCount("fever", 1.5);
            c2.IncrementCount("cough", 2.1);
            c2.IncrementCount("nausea", 3.2);
            ClassicCounter <string> c3 = new ClassicCounter <string>();

            c3.IncrementCount("cough", 2.5);
            c3.IncrementCount("congestion", 3.2);
            data.Add(new RVFDatum <string, string>(c1, "cold"));
            data.Add(new RVFDatum <string, string>(c2, "flu"));
            data.Add(new RVFDatum <string, string>(c3, "cold"));
            data.SummaryStatistics();
            LinearClassifierFactory <string, string> factory = new LinearClassifierFactory <string, string>();

            factory.UseQuasiNewton();
            LinearClassifier <string, string> c  = factory.TrainClassifier(data);
            ClassicCounter <string>           c4 = new ClassicCounter <string>();

            c4.IncrementCount("cough", 2.3);
            c4.IncrementCount("fever", 1.3);
            RVFDatum <string, string> datum = new RVFDatum <string, string>(c4);

            c.JustificationOf((IDatum <string, string>)datum);
        }
Code Example #3
        public virtual void Train(TaggedWord tw, int loc, double weight)
        {
            uwModelTrainer.Train(tw, loc, weight);
            IntTaggedWord iTW = new IntTaggedWord(tw.Word(), tw.Tag(), wordIndex, tagIndex);

            seenCounter.IncrementCount(iTW, weight);
            IntTaggedWord iT = new IntTaggedWord(nullWord, iTW.tag);

            seenCounter.IncrementCount(iT, weight);
            IntTaggedWord iW = new IntTaggedWord(iTW.word, nullTag);

            seenCounter.IncrementCount(iW, weight);
            IntTaggedWord i = new IntTaggedWord(nullWord, nullTag);

            seenCounter.IncrementCount(i, weight);
            // rules.add(iTW);
            tags.Add(iT);
            words.Add(iW);
            string            tag     = tw.Tag();
            string            baseTag = op.Langpack().BasicCategory(tag);
            ICounter <string> counts  = baseTagCounts[baseTag];

            if (counts == null)
            {
                counts = new ClassicCounter <string>();
                baseTagCounts[baseTag] = counts;
            }
            counts.IncrementCount(tag, weight);
        }
Code Example #4
        public virtual ICounter <string> GetTopSpeakers(IList <Sieve.MentionData> closestMentions, IList <Sieve.MentionData> closestMentionsBackward, Person.Gender gender, ICoreMap quote, bool overrideGender)
        {
            ICounter <string> topSpeakerInRange               = new ClassicCounter <string>();
            ICounter <string> topSpeakerInRangeIgnoreGender   = new ClassicCounter <string>();
            ICollection <Sieve.MentionData> backwardsMentions = new HashSet <Sieve.MentionData>(closestMentionsBackward);

            foreach (Sieve.MentionData mention in closestMentions)
            {
                double weight = backwardsMentions.Contains(mention) ? BackwardWeight : ForwardWeight;
                if (mention.type.Equals(Name))
                {
                    if (!characterMap.Keys.Contains(mention.text))
                    {
                        continue;
                    }
                    Person p = characterMap[mention.text][0];
                    if ((gender == Person.Gender.Male && p.gender == Person.Gender.Male) || (gender == Person.Gender.Female && p.gender == Person.Gender.Female) || (gender == Person.Gender.Unk))
                    {
                        topSpeakerInRange.IncrementCount(p.name, weight);
                    }
                    topSpeakerInRangeIgnoreGender.IncrementCount(p.name, weight);
                    if (closestMentions.Count == 128 && closestMentionsBackward.Count == 94)
                    {
                        System.Console.Out.WriteLine(p.name + " " + weight + " name");
                    }
                }
                else
                {
                    if (mention.type.Equals(Pronoun))
                    {
                        int    charBeginKey = doc.Get(typeof(CoreAnnotations.TokensAnnotation))[mention.begin].BeginPosition();
                        Person p            = DoCoreference(charBeginKey, quote);
                        if (p != null)
                        {
                            if ((gender == Person.Gender.Male && p.gender == Person.Gender.Male) || (gender == Person.Gender.Female && p.gender == Person.Gender.Female) || (gender == Person.Gender.Unk))
                            {
                                topSpeakerInRange.IncrementCount(p.name, weight);
                            }
                            topSpeakerInRangeIgnoreGender.IncrementCount(p.name, weight);
                            if (closestMentions.Count == 128 && closestMentionsBackward.Count == 94)
                            {
                                System.Console.Out.WriteLine(p.name + " " + weight + " pronoun");
                            }
                        }
                    }
                }
            }
            if (topSpeakerInRange.Size() > 0)
            {
                return(topSpeakerInRange);
            }
            else
            {
                if (gender != Person.Gender.Unk && !overrideGender)
                {
                    return(topSpeakerInRange);
                }
            }
            return(topSpeakerInRangeIgnoreGender);
        }
Code Example #5
        /// <summary>Featurize a given sentence.</summary>
        /// <param name="sentence">The sentence to featurize.</param>
        /// <returns>A counter encoding the featurized sentence.</returns>
        private static ICounter <string> Featurize(ICoreMap sentence)
        {
            ClassicCounter <string> features = new ClassicCounter <string>();
            string lastLemma = "^";

            foreach (CoreLabel token in sentence.Get(typeof(CoreAnnotations.TokensAnnotation)))
            {
                string lemma = token.Lemma().ToLower();
                if (number.Matcher(lemma).Matches())
                {
                    features.IncrementCount("**num**");
                }
                else
                {
                    features.IncrementCount(lemma);
                }
                if (alpha.Matcher(lemma).Matches())
                {
                    features.IncrementCount(lastLemma + "__" + lemma);
                    lastLemma = lemma;
                }
            }
            features.IncrementCount(lastLemma + "__$");
            return(features);
        }
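
A rough illustration of the feature space this produces (hypothetical lowercased lemmas "i", "ate", "3", "apples", assuming the number pattern matches only the digit token and the alpha pattern matches the alphabetic ones): the counter would end up holding the keys sketched below.

            // Illustration only: the counter Featurize would build for the lemmas "i ate 3 apples".
            ClassicCounter <string> expected = new ClassicCounter <string>();
            expected.IncrementCount("i");             // unigram lemma
            expected.IncrementCount("ate");           // unigram lemma
            expected.IncrementCount("**num**");       // "3" is collapsed by the number pattern
            expected.IncrementCount("apples");        // unigram lemma
            expected.IncrementCount("^__i");          // bigram with the start marker
            expected.IncrementCount("i__ate");        // lemma bigram
            expected.IncrementCount("ate__apples");   // the digit token is skipped in the bigram chain
            expected.IncrementCount("apples__$");     // closing bigram with the end marker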
Code Example #6
        private static ICounter <string> GetFeatures(ClustererDataLoader.ClustererDoc doc, Clusterer.Cluster c1, Clusterer.Cluster c2, Clusterer.GlobalFeatures gf)
        {
            Clusterer.MergeKey      key      = new Clusterer.MergeKey(c1, c2, gf.currentIndex);
            CompressedFeatureVector cfv      = featuresCache[key];
            ICounter <string>       features = cfv == null ? null : compressor.Uncompress(cfv);

            if (features != null)
            {
                featuresCacheHits += isTraining;
                return(features);
            }
            featuresCacheMisses += isTraining;
            features             = new ClassicCounter <string>();
            if (gf.anaphorSeen)
            {
                features.IncrementCount("anaphorSeen");
            }
            features.IncrementCount("docSize", gf.docSize);
            features.IncrementCount("percentComplete", gf.currentIndex / (double)gf.size);
            features.IncrementCount("bias", 1.0);
            int earliest1 = EarliestMention(c1, doc);
            int earliest2 = EarliestMention(c2, doc);

            if (doc.mentionIndices[earliest1] > doc.mentionIndices[earliest2])
            {
                int tmp = earliest1;
                earliest1 = earliest2;
                earliest2 = tmp;
            }
            features.IncrementCount("anaphoricity", doc.anaphoricityScores.GetCount(earliest2));
            if (c1.mentions.Count == 1 && c2.mentions.Count == 1)
            {
                Pair <int, int> mentionPair = new Pair <int, int>(c1.mentions[0], c2.mentions[0]);
                features.AddAll(AddSuffix(GetFeatures(doc, mentionPair, doc.classificationScores), "-classification"));
                features.AddAll(AddSuffix(GetFeatures(doc, mentionPair, doc.rankingScores), "-ranking"));
                features = AddSuffix(features, "-single");
            }
            else
            {
                IList <Pair <int, int> > between = new List <Pair <int, int> >();
                foreach (int m1 in c1.mentions)
                {
                    foreach (int m2 in c2.mentions)
                    {
                        between.Add(new Pair <int, int>(m1, m2));
                    }
                }
                features.AddAll(AddSuffix(GetFeatures(doc, between, doc.classificationScores), "-classification"));
                features.AddAll(AddSuffix(GetFeatures(doc, between, doc.rankingScores), "-ranking"));
            }
            featuresCache[key] = compressor.Compress(features);
            return(features);
        }
Code Example #7
        /// <summary>
        /// Converts the svm_light weight Counter (which uses feature indices) into a weight Counter
        /// using the actual features and labels.
        /// </summary>
        /// <remarks>
        /// Converts the svm_light weight Counter (which uses feature indices) into a weight Counter
        /// using the actual features and labels.  Because this is svm_light, and not svm_struct, the
        /// weights for the +1 class (which correspond to labelIndex.get(0)) and the -1 class
        /// (which correspond to labelIndex.get(1)) are just the negation of one another.
        /// </remarks>
        private ClassicCounter <Pair <F, L> > ConvertSVMLightWeights(ClassicCounter <int> weights, IIndex <F> featureIndex, IIndex <L> labelIndex)
        {
            ClassicCounter <Pair <F, L> > newWeights = new ClassicCounter <Pair <F, L> >();

            foreach (int i in weights.KeySet())
            {
                F      f = featureIndex.Get(i - 1);
                double w = weights.GetCount(i);
                // the first guy in the labelIndex was the +1 class and the second guy
                // was the -1 class
                newWeights.IncrementCount(new Pair <F, L>(f, labelIndex.Get(0)), w);
                newWeights.IncrementCount(new Pair <F, L>(f, labelIndex.Get(1)), -w);
            }
            return(newWeights);
        }
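
The remarks above amount to a sign flip: each svm_light weight is recorded once for the +1 label and once, negated, for the -1 label. A tiny illustrative check of that symmetry (hypothetical value, plain string keys standing in for the Pair <F, L> keys):

            // Illustration only: mimic the +1/-1 symmetry with string keys instead of Pair <F, L>.
            ClassicCounter <string> toy = new ClassicCounter <string>();
            double w = 0.75;                          // hypothetical svm_light weight for one feature
            toy.IncrementCount("featureA|+1", w);     // weight paired with labelIndex.Get(0)
            toy.IncrementCount("featureA|-1", -w);    // negated weight paired with labelIndex.Get(1)
            // toy.GetCount("featureA|+1") == -toy.GetCount("featureA|-1")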
Code Example #8
        public override void PrintResults(PrintWriter pw, IList <ICoreMap> goldStandard, IList <ICoreMap> extractorOutput)
        {
            ResultsPrinter.Align(goldStandard, extractorOutput);
            // the mention factory cannot be null here
            System.Diagnostics.Debug.Assert(relationMentionFactory != null, "ERROR: RelationExtractorResultsPrinter.relationMentionFactory cannot be null in printResults!");
            // Count predicted-actual relation type pairs
            ICounter <Pair <string, string> > results    = new ClassicCounter <Pair <string, string> >();
            ClassicCounter <string>           labelCount = new ClassicCounter <string>();

            // TODO: assumes binary relations
            for (int goldSentenceIndex = 0; goldSentenceIndex < goldStandard.Count; goldSentenceIndex++)
            {
                foreach (RelationMention goldRelation in AnnotationUtils.GetAllRelations(relationMentionFactory, goldStandard[goldSentenceIndex], createUnrelatedRelations))
                {
                    ICoreMap extractorSentence = extractorOutput[goldSentenceIndex];
                    IList <RelationMention> extractorRelations = AnnotationUtils.GetRelations(relationMentionFactory, extractorSentence, goldRelation.GetArg(0), goldRelation.GetArg(1));
                    labelCount.IncrementCount(goldRelation.GetType());
                    foreach (RelationMention extractorRelation in extractorRelations)
                    {
                        results.IncrementCount(new Pair <string, string>(extractorRelation.GetType(), goldRelation.GetType()));
                    }
                }
            }
            PrintResultsInternal(pw, results, labelCount);
        }
Code Example #9
            public virtual void ComputeFinalValues()
            {
                double denom = (double)numTrees;

                meanDepth           = depth2 / denom;
                meanLength          = length2 / denom;
                meanBreadth         = breadth2 / denom;
                meanConstituents    = phrasalBranchingNum2.TotalCount() / denom;
                meanBranchingFactor = phrasalBranching2.TotalCount() / phrasalBranchingNum2.TotalCount();
                //Compute *actual* stddev (we iterate over the whole population)
                foreach (int d in depths)
                {
                    stddevDepth += Math.Pow(d - meanDepth, 2);
                }
                stddevDepth = Math.Sqrt(stddevDepth / denom);
                foreach (int l in lengths)
                {
                    stddevLength += Math.Pow(l - meanLength, 2);
                }
                stddevLength = Math.Sqrt(stddevLength / denom);
                foreach (int b in breadths)
                {
                    stddevBreadth += Math.Pow(b - meanBreadth, 2);
                }
                stddevBreadth        = Math.Sqrt(stddevBreadth / denom);
                meanBranchingByLabel = new ClassicCounter <string>();
                foreach (string label in phrasalBranching2.KeySet())
                {
                    double mean = phrasalBranching2.GetCount(label) / phrasalBranchingNum2.GetCount(label);
                    meanBranchingByLabel.IncrementCount(label, mean);
                }
                oovWords = Generics.NewHashSet(words.KeySet());
                oovWords.RemoveAll(trainVocab);
                OOVRate = (double)oovWords.Count / (double)words.KeySet().Count;
            }
Code Example #10
        /// <exception cref="System.IO.IOException"/>
        private void WriteObject(ObjectOutputStream stream)
        {
            //    log.info("\nBefore compression:");
            //    log.info("arg size: " + argCounter.size() + "  total: " + argCounter.totalCount());
            //    log.info("stop size: " + stopCounter.size() + "  total: " + stopCounter.totalCount());
            ClassicCounter <IntDependency> fullArgCounter = argCounter;

            argCounter = new ClassicCounter <IntDependency>();
            foreach (IntDependency dependency in fullArgCounter.KeySet())
            {
                if (dependency.head != wildTW && dependency.arg != wildTW && dependency.head.word != -1 && dependency.arg.word != -1)
                {
                    argCounter.IncrementCount(dependency, fullArgCounter.GetCount(dependency));
                }
            }
            ClassicCounter <IntDependency> fullStopCounter = stopCounter;

            stopCounter = new ClassicCounter <IntDependency>();
            foreach (IntDependency dependency_1 in fullStopCounter.KeySet())
            {
                if (dependency_1.head.word != -1)
                {
                    stopCounter.IncrementCount(dependency_1, fullStopCounter.GetCount(dependency_1));
                }
            }
            //    log.info("After compression:");
            //    log.info("arg size: " + argCounter.size() + "  total: " + argCounter.totalCount());
            //    log.info("stop size: " + stopCounter.size() + "  total: " + stopCounter.totalCount());
            stream.DefaultWriteObject();
            argCounter  = fullArgCounter;
            stopCounter = fullStopCounter;
        }
Code Example #11
            public override Pair <double, double> GetScore(IList <IList <int> > clusters, IDictionary <int, IList <int> > mentionToGold)
            {
                double num = 0;
                int    dem = 0;

                foreach (IList <int> c in clusters)
                {
                    if (c.Count == 1)
                    {
                        continue;
                    }
                    ICounter <IList <int> > goldCounts = new ClassicCounter <IList <int> >();
                    double correct = 0;
                    foreach (int m in c)
                    {
                        IList <int> goldCluster = mentionToGold[m];
                        if (goldCluster != null)
                        {
                            goldCounts.IncrementCount(goldCluster);
                        }
                    }
                    foreach (KeyValuePair <IList <int>, double> e in goldCounts.EntrySet())
                    {
                        if (e.Key.Count != 1)
                        {
                            correct += e.Value * e.Value;
                        }
                    }
                    num += correct / c.Count;
                    dem += c.Count;
                }
                return(new Pair <double, double>(num, (double)dem));
            }
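
A worked instance of the arithmetic above (hypothetical clusters): for a predicted cluster of three mentions in which two map to the same multi-mention gold cluster and the third has no gold cluster, the single gold entry contributes 2 * 2 = 4 to correct, so the cluster adds 4/3 to num and 3 to dem.

                // Illustration only, for one hypothetical predicted cluster of three mentions.
                double goldEntryCount  = 2;                                 // two mentions share gold cluster A
                double correctToy      = goldEntryCount * goldEntryCount;   // 2 * 2 = 4; the unaligned mention adds nothing
                double numContribution = correctToy / 3;                    // 4 / 3 is added to num
                int    demContribution = 3;                                 // the cluster size is added to dem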
Code Example #12
        public virtual RVFDatum <L, F> ScaleDatumGaussian(RVFDatum <L, F> datum)
        {
            // scale this dataset before scaling the datum
            if (means == null || stdevs == null)
            {
                ScaleFeaturesGaussian();
            }
            ICounter <F> scaledFeatures = new ClassicCounter <F>();

            foreach (F feature in datum.AsFeatures())
            {
                int fID = this.featureIndex.IndexOf(feature);
                if (fID >= 0)
                {
                    double oldVal = datum.AsFeaturesCounter().GetCount(feature);
                    double newVal;
                    if (stdevs[fID] != 0)
                    {
                        newVal = (oldVal - means[fID]) / stdevs[fID];
                    }
                    else
                    {
                        newVal = oldVal;
                    }
                    scaledFeatures.IncrementCount(feature, newVal);
                }
            }
            return(new RVFDatum <L, F>(scaledFeatures, datum.Label()));
        }
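
The per-feature transform above is a plain z-score, newVal = (oldVal - means[fID]) / stdevs[fID], with the raw value passed through when the standard deviation is zero. A quick hypothetical check of the arithmetic:

            // Illustration only: hypothetical feature with mean 4.0 and standard deviation 2.0.
            double oldValToy = 7.0;
            double meanToy   = 4.0;
            double stdevToy  = 2.0;
            double newValToy = stdevToy != 0 ? (oldValToy - meanToy) / stdevToy : oldValToy;   // (7 - 4) / 2 = 1.5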
Code Example #13
        private static void ModifyUsingCoreNLPNER(Annotation doc)
        {
            Properties ann = new Properties();

            ann.SetProperty("annotators", "pos, lemma, ner");
            StanfordCoreNLP pipeline = new StanfordCoreNLP(ann, false);

            pipeline.Annotate(doc);
            foreach (ICoreMap sentence in doc.Get(typeof(CoreAnnotations.SentencesAnnotation)))
            {
                IList <EntityMention> entities = sentence.Get(typeof(MachineReadingAnnotations.EntityMentionsAnnotation));
                if (entities != null)
                {
                    IList <CoreLabel> tokens = sentence.Get(typeof(CoreAnnotations.TokensAnnotation));
                    foreach (EntityMention en in entities)
                    {
                        //System.out.println("old ner tag for " + en.getExtentString() + " was " + en.getType());
                        Span s = en.GetExtent();
                        ICounter <string> allNertagforSpan = new ClassicCounter <string>();
                        for (int i = s.Start(); i < s.End(); i++)
                        {
                            allNertagforSpan.IncrementCount(tokens[i].Ner());
                        }
                        string entityNertag = Counters.Argmax(allNertagforSpan);
                        en.SetType(entityNertag);
                    }
                }
            }
        }
Code Example #14
        public virtual void RunCoref(Document document)
        {
            IDictionary <Pair <int, int>, bool> mentionPairs = CorefUtils.GetUnlabeledMentionPairs(document);

            if (mentionPairs.Count == 0)
            {
                return;
            }
            Compressor <string>         compressor           = new Compressor <string>();
            DocumentExamples            examples             = extractor.Extract(0, document, mentionPairs, compressor);
            ICounter <Pair <int, int> > classificationScores = new ClassicCounter <Pair <int, int> >();
            ICounter <Pair <int, int> > rankingScores        = new ClassicCounter <Pair <int, int> >();
            ICounter <int> anaphoricityScores = new ClassicCounter <int>();

            foreach (Example example in examples.examples)
            {
                CorefUtils.CheckForInterrupt();
                Pair <int, int> mentionPair = new Pair <int, int>(example.mentionId1, example.mentionId2);
                classificationScores.IncrementCount(mentionPair, classificationModel.Predict(example, examples.mentionFeatures, compressor));
                rankingScores.IncrementCount(mentionPair, rankingModel.Predict(example, examples.mentionFeatures, compressor));
                if (!anaphoricityScores.ContainsKey(example.mentionId2))
                {
                    anaphoricityScores.IncrementCount(example.mentionId2, anaphoricityModel.Predict(new Example(example, false), examples.mentionFeatures, compressor));
                }
            }
            ClustererDataLoader.ClustererDoc doc = new ClustererDataLoader.ClustererDoc(0, classificationScores, rankingScores, anaphoricityScores, mentionPairs, null, document.predictedMentionsByID.Stream().Collect(Collectors.ToMap(null, null)));
            foreach (Pair <int, int> mentionPair_1 in clusterer.GetClusterMerges(doc))
            {
                CorefUtils.MergeCoreferenceClusters(mentionPair_1, document);
            }
        }
Code Example #15
        public virtual ICounter <L> ProbabilityOf(IDatum <L, F> example)
        {
            // calculate the feature indices and feature values
            int[]    featureIndices = LogisticUtils.IndicesOf(example.AsFeatures(), featureIndex);
            double[] featureValues;
            if (example is RVFDatum <object, object> )
            {
                ICollection <double> featureValuesCollection = ((RVFDatum <object, object>)example).AsFeaturesCounter().Values();
                featureValues = LogisticUtils.ConvertToArray(featureValuesCollection);
            }
            else
            {
                featureValues = new double[example.AsFeatures().Count];
                Arrays.Fill(featureValues, 1.0);
            }
            // calculate probability of each class
            ICounter <L> result     = new ClassicCounter <L>();
            int          numClasses = labelIndex.Size();

            double[] sigmoids = LogisticUtils.CalculateSigmoids(weights, featureIndices, featureValues);
            for (int c = 0; c < numClasses; c++)
            {
                L label = labelIndex.Get(c);
                result.IncrementCount(label, sigmoids[c]);
            }
            return(result);
        }
Code Example #16
        /// <summary>Method to convert features from counts to L1-normalized TFIDF based features</summary>
        /// <param name="datum">with a collection of features.</param>
        /// <param name="featureDocCounts">a counter of doc-count for each feature.</param>
        /// <returns>RVFDatum with l1-normalized tf-idf features.</returns>
        public virtual RVFDatum <L, F> GetL1NormalizedTFIDFDatum(IDatum <L, F> datum, ICounter <F> featureDocCounts)
        {
            ICounter <F> tfidfFeatures = new ClassicCounter <F>();

            foreach (F feature in datum.AsFeatures())
            {
                if (featureDocCounts.ContainsKey(feature))
                {
                    tfidfFeatures.IncrementCount(feature, 1.0);
                }
            }
            double l1norm = 0;

            foreach (F feature_1 in tfidfFeatures.KeySet())
            {
                double idf = Math.Log(((double)(this.Size() + 1)) / (featureDocCounts.GetCount(feature_1) + 0.5));
                double tf  = tfidfFeatures.GetCount(feature_1);
                tfidfFeatures.SetCount(feature_1, tf * idf);
                l1norm += tf * idf;
            }
            foreach (F feature_2 in tfidfFeatures.KeySet())
            {
                double tfidf = tfidfFeatures.GetCount(feature_2);
                tfidfFeatures.SetCount(feature_2, tfidf / l1norm);
            }
            RVFDatum <L, F> rvfDatum = new RVFDatum <L, F>(tfidfFeatures, datum.Label());

            return(rvfDatum);
        }
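
To make the formula concrete (hypothetical numbers): with this.Size() = 9 documents and a feature that occurs in 2 of them, the idf term is log(10 / 2.5) ≈ 1.386; each present feature starts at tf = 1, is scaled to tf * idf, and the counter is then divided by its L1 norm so the values sum to 1.

            // Illustration only: one feature's tf-idf value before L1 normalization.
            double sizeToy     = 9;
            double docCountToy = 2;
            double tfToy       = 1.0;
            double idfToy      = Math.Log((sizeToy + 1) / (docCountToy + 0.5));   // log(10 / 2.5) ≈ 1.386
            double tfidfToy    = tfToy * idfToy;
            // After every feature is scaled this way, each value is divided by the L1 norm
            // (the sum of the scaled values), so the final feature vector sums to 1.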
Code Example #17
        public virtual ClassicCounter <L> ScoresOf(RVFDatum <L, F> example)
        {
            ClassicCounter <L> scores = new ClassicCounter <L>();

            Counters.AddInPlace(scores, priors);
            if (addZeroValued)
            {
                Counters.AddInPlace(scores, priorZero);
            }
            foreach (L l in labels)
            {
                double       score    = 0.0;
                ICounter <F> features = example.AsFeaturesCounter();
                foreach (F f in features.KeySet())
                {
                    int value = (int)features.GetCount(f);
                    score += Weight(l, f, value);
                    if (addZeroValued)
                    {
                        score -= Weight(l, f, zero);
                    }
                }
                scores.IncrementCount(l, score);
            }
            return(scores);
        }
Code Example #18
        private NaiveBayesClassifier <L, F> TrainClassifier(int[][] data, int[] labels, int numFeatures, int numClasses, IIndex <L> labelIndex, IIndex <F> featureIndex)
        {
            ICollection <L> labelSet = Generics.NewHashSet();

            NaiveBayesClassifierFactory.NBWeights nbWeights = TrainWeights(data, labels, numFeatures, numClasses);
            ICounter <L> priors = new ClassicCounter <L>();

            double[] pr = nbWeights.priors;
            for (int i = 0; i < pr.Length; i++)
            {
                priors.IncrementCount(labelIndex.Get(i), pr[i]);
                labelSet.Add(labelIndex.Get(i));
            }
            ICounter <Pair <Pair <L, F>, Number> > weightsCounter = new ClassicCounter <Pair <Pair <L, F>, Number> >();

            double[][][] wts = nbWeights.weights;
            for (int c = 0; c < numClasses; c++)
            {
                L label = labelIndex.Get(c);
                for (int f = 0; f < numFeatures; f++)
                {
                    F           feature = featureIndex.Get(f);
                    Pair <L, F> p       = new Pair <L, F>(label, feature);
                    for (int val = 0; val < wts[c][f].Length; val++)
                    {
                        Pair <Pair <L, F>, Number> key = new Pair <Pair <L, F>, Number>(p, val);
                        weightsCounter.IncrementCount(key, wts[c][f][val]);
                    }
                }
            }
            return(new NaiveBayesClassifier <L, F>(weightsCounter, priors, labelSet));
        }
Code Example #19
        /// <summary>Trains this lexicon on the Collection of trees.</summary>
        public override void Train(TaggedWord tw, int loc, double weight)
        {
            IntTaggedWord iTW = new IntTaggedWord(tw.Word(), tw.Tag(), wordIndex, tagIndex);
            IntTaggedWord iT  = new IntTaggedWord(UnknownWordModelTrainerConstants.nullWord, iTW.tag);
            IntTaggedWord iW  = new IntTaggedWord(iTW.word, UnknownWordModelTrainerConstants.nullTag);

            seenCounter.IncrementCount(iW, weight);
            IntTaggedWord i = UnknownWordModelTrainerConstants.NullItw;

            if (treesRead > indexToStartUnkCounting)
            {
                // start doing this once some way through trees;
                // treesRead is 1 based counting
                if (seenCounter.GetCount(iW) < 2)
                {
                    // it's an entirely unknown word
                    int           s   = model.GetSignatureIndex(iTW.word, loc, wordIndex.Get(iTW.word));
                    IntTaggedWord iTS = new IntTaggedWord(s, iTW.tag);
                    IntTaggedWord iS  = new IntTaggedWord(s, UnknownWordModelTrainerConstants.nullTag);
                    unSeenCounter.IncrementCount(iTS, weight);
                    unSeenCounter.IncrementCount(iT, weight);
                    unSeenCounter.IncrementCount(iS, weight);
                    unSeenCounter.IncrementCount(i, weight);
                }
            }
        }
Code Example #20
 public virtual void FinishTraining()
 {
     // testing: get some stats here
     log.Info("Total tokens: " + tokens);
     log.Info("Total WordTag types: " + wtCount.KeySet().Count);
     log.Info("Total tag types: " + tagCount.KeySet().Count);
     log.Info("Total word types: " + seenWords.Count);
     /* find # of once-seen words for each tag */
     foreach (Pair <string, string> wt in wtCount.KeySet())
     {
         if (wtCount.GetCount(wt) == 1)
         {
             r1.IncrementCount(wt.Second());
         }
     }
     /* find # of unseen words for each tag */
     foreach (string tag in tagCount.KeySet())
     {
         foreach (string word in seenWords)
         {
             Pair <string, string> wt_1 = new Pair <string, string>(word, tag);
             if (!(wtCount.KeySet().Contains(wt_1)))
             {
                 r0.IncrementCount(tag);
             }
         }
     }
     /* set unseen word probability for each tag */
     foreach (string tag_1 in tagCount.KeySet())
     {
         float logprob = (float)Math.Log(r1.GetCount(tag_1) / (tagCount.GetCount(tag_1) * r0.GetCount(tag_1)));
          unknownGT[tag_1] = logprob;
     }
 }
Code Example #21
        private void PrintResultsInternal(PrintWriter pw, ICounter <Pair <string, string> > results, ClassicCounter <string> labelCount)
        {
            ClassicCounter <string> correct         = new ClassicCounter <string>();
            ClassicCounter <string> predictionCount = new ClassicCounter <string>();
            bool countGoldLabels = false;

            if (labelCount == null)
            {
                labelCount      = new ClassicCounter <string>();
                countGoldLabels = true;
            }
            foreach (Pair <string, string> predictedActual in results.KeySet())
            {
                string predicted = predictedActual.first;
                string actual    = predictedActual.second;
                if (predicted.Equals(actual))
                {
                    correct.IncrementCount(actual, results.GetCount(predictedActual));
                }
                predictionCount.IncrementCount(predicted, results.GetCount(predictedActual));
                if (countGoldLabels)
                {
                    labelCount.IncrementCount(actual, results.GetCount(predictedActual));
                }
            }
            DecimalFormat formatter = new DecimalFormat();

            formatter.SetMaximumFractionDigits(1);
            formatter.SetMinimumFractionDigits(1);
            double totalCount     = 0;
            double totalCorrect   = 0;
            double totalPredicted = 0;

            pw.Println("Label\tCorrect\tPredict\tActual\tPrecn\tRecall\tF");
            IList <string> labels = new List <string>(labelCount.KeySet());

            labels.Sort();
            foreach (string label in labels)
            {
                double numcorrect = correct.GetCount(label);
                double predicted  = predictionCount.GetCount(label);
                double trueCount  = labelCount.GetCount(label);
                double precision  = (predicted > 0) ? (numcorrect / predicted) : 0;
                double recall     = numcorrect / trueCount;
                double f          = (precision + recall > 0) ? 2 * precision * recall / (precision + recall) : 0.0;
                pw.Println(StringUtils.PadOrTrim(label, MaxLabelLength) + "\t" + numcorrect + "\t" + predicted + "\t" + trueCount + "\t" + formatter.Format(precision * 100) + "\t" + formatter.Format(100 * recall) + "\t" + formatter.Format(100 * f));
                if (!RelationMention.IsUnrelatedLabel(label))
                {
                    totalCount     += trueCount;
                    totalCorrect   += numcorrect;
                    totalPredicted += predicted;
                }
            }
            double precision_1 = (totalPredicted > 0) ? (totalCorrect / totalPredicted) : 0;
            double recall_1    = totalCorrect / totalCount;
            double f_1         = (totalPredicted > 0 && totalCorrect > 0) ? 2 * precision_1 * recall_1 / (precision_1 + recall_1) : 0.0;

            pw.Println("Total\t" + totalCorrect + "\t" + totalPredicted + "\t" + totalCount + "\t" + formatter.Format(100 * precision_1) + "\t" + formatter.Format(100 * recall_1) + "\t" + formatter.Format(100 * f_1));
        }
Code Example #22
 protected internal override void TallyInternalNode(Tree lt, double weight)
 {
     if (lt.Children().Length == 1)
     {
         UnaryRule ur = new UnaryRule(stateIndex.AddToIndex(lt.Label().Value()), stateIndex.AddToIndex(lt.Children()[0].Label().Value()));
         symbolCounter.IncrementCount(stateIndex.Get(ur.parent), weight);
         unaryRuleCounter.IncrementCount(ur, weight);
         unaryRules.Add(ur);
     }
     else
     {
         BinaryRule br = new BinaryRule(stateIndex.AddToIndex(lt.Label().Value()), stateIndex.AddToIndex(lt.Children()[0].Label().Value()), stateIndex.AddToIndex(lt.Children()[1].Label().Value()));
         symbolCounter.IncrementCount(stateIndex.Get(br.parent), weight);
         binaryRuleCounter.IncrementCount(br, weight);
         binaryRules.Add(br);
     }
 }
Code Example #23
        private string Convert(string @in, bool unicodeToBuckwalter)
        {
            StringTokenizer st     = new StringTokenizer(@in);
            StringBuilder   result = new StringBuilder(@in.Length);

            while (st.HasMoreTokens())
            {
                string token = st.NextToken();
                for (int i = 0; i < token.Length; i++)
                {
                    if (ATBTreeUtils.reservedWords.Contains(token))
                    {
                        result.Append(token);
                        break;
                    }
                    char  inCh  = token[i];
                    char? outCh = null;
                    if (unicodeToBuckwalter)
                    {
                        outCh = (PassAsciiInUnicode && inCh < 127) ? inCh : u2bMap[inCh];
                    }
                    else
                    {
                        if ((SuppressDigitMappingInB2a && char.IsDigit((char)inCh)) || (SuppressPuncMappingInB2a && latinPunc.Matcher(inCh.ToString()).Matches()))
                        {
                            outCh = inCh;
                        }
                        else
                        {
                            outCh = b2uMap[inCh];
                        }
                    }
                    if (outCh == null)
                    {
                        if (Debug)
                        {
                            string key = inCh + "[U+" + StringUtils.PadLeft(Convert.ToString(inCh, 16).ToUpper(), 4, '0') + ']';
                            unmappable.IncrementCount(key);
                        }
                        result.Append(inCh);
                    }
                    else
                    {
                        // pass through char
                        if (outputUnicodeValues)
                        {
                            result.Append("\\u").Append(StringUtils.PadLeft(Convert.ToString(inCh, 16).ToUpper(), 4, '0'));
                        }
                        else
                        {
                            result.Append(outCh.Value);
                        }
                    }
                }
                result.Append(" ");
            }
            return(result.ToString().Trim());
        }
Code Example #24
        public virtual void ClassifyMentions(IList <IList <Mention> > predictedMentions, Dictionaries dict, Properties props)
        {
            ICollection <string> neStrings = Generics.NewHashSet();

            foreach (IList <Mention> predictedMention in predictedMentions)
            {
                foreach (Mention m in predictedMention)
                {
                    string ne = m.headWord.Ner();
                    if (ne.Equals("O"))
                    {
                        continue;
                    }
                    foreach (CoreLabel cl in m.originalSpan)
                    {
                        if (!cl.Ner().Equals(ne))
                        {
                            continue;
                        }
                    }
                    neStrings.Add(m.LowercaseNormalizedSpanString());
                }
            }
            foreach (IList <Mention> predicts in predictedMentions)
            {
                IDictionary <int, ICollection <Mention> > headPositions = Generics.NewHashMap();
                foreach (Mention p in predicts)
                {
                    if (!headPositions.Contains(p.headIndex))
                    {
                        headPositions[p.headIndex] = Generics.NewHashSet();
                    }
                    headPositions[p.headIndex].Add(p);
                }
                ICollection <Mention> remove = Generics.NewHashSet();
                foreach (int hPos in headPositions.Keys)
                {
                    ICollection <Mention> shares = headPositions[hPos];
                    if (shares.Count > 1)
                    {
                        ICounter <Mention> probs = new ClassicCounter <Mention>();
                        foreach (Mention p_1 in shares)
                        {
                            double trueProb = ProbabilityOf(p_1, shares, neStrings, dict, props);
                            probs.IncrementCount(p_1, trueProb);
                        }
                        // add to remove
                        Mention keep = Counters.Argmax(probs, null);
                        probs.Remove(keep);
                        Sharpen.Collections.AddAll(remove, probs.KeySet());
                    }
                }
                foreach (Mention r in remove)
                {
                    predicts.Remove(r);
                }
            }
        }
Code Example #25
        /// <summary>Reads in a model file in svm light format.</summary>
        /// <remarks>
        /// Reads in a model file in svm light format.  It needs to know if it is multiclass or not
        /// because it affects the number of header lines.  Maybe there is another way to tell and we
        /// can remove this flag?
        /// </remarks>
        private static Pair <double, ClassicCounter <int> > ReadModel(File modelFile, bool multiclass)
        {
            int modelLineCount = 0;

            try
            {
                int            numLinesToSkip = multiclass ? 13 : 10;
                string         stopToken      = "#";
                BufferedReader @in            = new BufferedReader(new FileReader(modelFile));
                for (int i = 0; i < numLinesToSkip; i++)
                {
                    @in.ReadLine();
                    modelLineCount++;
                }
                IList <Pair <double, ClassicCounter <int> > > supportVectors = new List <Pair <double, ClassicCounter <int> > >();
                // Read Threshold
                string thresholdLine = @in.ReadLine();
                modelLineCount++;
                string[] pieces    = thresholdLine.Split("\\s+");
                double   threshold = double.Parse(pieces[0]);
                // Read Support Vectors
                while (@in.Ready())
                {
                    string svLine = @in.ReadLine();
                    modelLineCount++;
                    pieces = svLine.Split("\\s+");
                    // First Element is the alpha_i * y_i
                    double alpha = double.Parse(pieces[0]);
                    ClassicCounter <int> supportVector = new ClassicCounter <int>();
                    for (int i_1 = 1; i_1 < pieces.Length; ++i_1)
                    {
                        string piece = pieces[i_1];
                        if (piece.Equals(stopToken))
                        {
                            break;
                        }
                        // Each in featureIndex:num class
                        string[] indexNum     = piece.Split(":");
                        string   featureIndex = indexNum[0];
                        // mihai: we may see "qid" as indexNum[0]. just skip this piece. this is the block id useful only for reranking, which we don't do here.
                        if (!featureIndex.Equals("qid"))
                        {
                            double count = double.Parse(indexNum[1]);
                            supportVector.IncrementCount(int.Parse(featureIndex), count);
                        }
                    }
                    supportVectors.Add(new Pair <double, ClassicCounter <int> >(alpha, supportVector));
                }
                @in.Close();
                return(new Pair <double, ClassicCounter <int> >(threshold, GetWeights(supportVectors)));
            }
            catch (Exception e)
            {
                Sharpen.Runtime.PrintStackTrace(e);
                throw new Exception("Error reading SVM model (line " + modelLineCount + " in file " + modelFile.GetAbsolutePath() + ")");
            }
        }
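
Each support-vector line the loop above decodes has the svm_light shape "alpha_y featureIndex:value featureIndex:value ... #". The same parsing can be sketched in isolation with plain .NET splitting (hypothetical line, illustration only):

            // Illustration only: decode one hypothetical support-vector line.
            string svLineToy = "0.5 3:1.0 7:2.5 #";
            string[] piecesToy = svLineToy.Split(' ');
            double alphaToy = double.Parse(piecesToy[0]);             // alpha_i * y_i
            ClassicCounter <int> vectorToy = new ClassicCounter <int>();
            for (int k = 1; k < piecesToy.Length && !piecesToy[k].Equals("#"); k++)
            {
                string[] indexNumToy = piecesToy[k].Split(':');       // featureIndex:value
                vectorToy.IncrementCount(int.Parse(indexNumToy[0]), double.Parse(indexNumToy[1]));
            }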
Code Example #26
        // TODO not called any more, but possibly useful as a reference
        /// <summary>
        /// This should be called after the classifier has been trained and
        /// parseAndTrain has been called to accumulate the test set.
        /// This will return precision, recall and F1 measure.
        /// </summary>
        public virtual void RunTestSet(IList <IList <CoreLabel> > testSet)
        {
            ICounter <string> tp     = new ClassicCounter <string>();
            ICounter <string> fp     = new ClassicCounter <string>();
            ICounter <string> fn     = new ClassicCounter <string>();
            ICounter <string> actual = new ClassicCounter <string>();

            foreach (IList <CoreLabel> labels in testSet)
            {
                IList <CoreLabel> unannotatedLabels = new List <CoreLabel>();
                // create a new label without answer annotation
                foreach (CoreLabel label in labels)
                {
                    CoreLabel newLabel = new CoreLabel();
                    newLabel.Set(annotationForWord, label.Get(annotationForWord));
                    newLabel.Set(typeof(CoreAnnotations.PartOfSpeechAnnotation), label.Get(typeof(CoreAnnotations.PartOfSpeechAnnotation)));
                    unannotatedLabels.Add(newLabel);
                }
                IList <CoreLabel> annotatedLabels = this.classifier.Classify(unannotatedLabels);
                int ind = 0;
                foreach (CoreLabel expectedLabel in labels)
                {
                    CoreLabel annotatedLabel = annotatedLabels[ind];
                    string    answer         = annotatedLabel.Get(typeof(CoreAnnotations.AnswerAnnotation));
                    string    expectedAnswer = expectedLabel.Get(typeof(CoreAnnotations.AnswerAnnotation));
                    actual.IncrementCount(expectedAnswer);
                    // match only non background symbols
                    if (!SeqClassifierFlags.DefaultBackgroundSymbol.Equals(expectedAnswer) && expectedAnswer.Equals(answer))
                    {
                        // true positives
                        tp.IncrementCount(answer);
                        System.Console.Out.WriteLine("True Positive:" + annotatedLabel);
                    }
                    else
                    {
                        if (!SeqClassifierFlags.DefaultBackgroundSymbol.Equals(answer))
                        {
                            // false positives
                            fp.IncrementCount(answer);
                            System.Console.Out.WriteLine("False Positive:" + annotatedLabel);
                        }
                        else
                        {
                            if (!SeqClassifierFlags.DefaultBackgroundSymbol.Equals(expectedAnswer))
                            {
                                // false negatives
                                fn.IncrementCount(expectedAnswer);
                                System.Console.Out.WriteLine("False Negative:" + expectedLabel);
                            }
                        }
                    }
                    // else true negatives
                    ind++;
                }
            }
            actual.Remove(SeqClassifierFlags.DefaultBackgroundSymbol);
        }
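
The summary above promises precision, recall, and F1, but the body only accumulates the tp/fp/fn counters. A hedged sketch of how those figures could then be derived for a single label (the label name is a stand-in):

            // Illustration only: per-label precision/recall/F1 from the counters filled above.
            string labelToy     = "PERSON";                            // hypothetical non-background label
            double tpToy        = tp.GetCount(labelToy);
            double fpToy        = fp.GetCount(labelToy);
            double fnToy        = fn.GetCount(labelToy);
            double precisionToy = (tpToy + fpToy) > 0 ? tpToy / (tpToy + fpToy) : 0.0;
            double recallToy    = (tpToy + fnToy) > 0 ? tpToy / (tpToy + fnToy) : 0.0;
            double f1Toy        = (precisionToy + recallToy) > 0 ? 2 * precisionToy * recallToy / (precisionToy + recallToy) : 0.0;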
Code Example #27
        /// <summary>Trains the first-character based unknown word model.</summary>
        /// <param name="tw">The word we are currently training on</param>
        /// <param name="loc">The position of that word</param>
        /// <param name="weight">The weight to give this word in terms of training</param>
        public override void Train(TaggedWord tw, int loc, double weight)
        {
            if (useGT)
            {
                unknownGTTrainer.Train(tw, weight);
            }
            string word  = tw.Word();
            ILabel tagL  = new Tag(tw.Tag());
            string first = Sharpen.Runtime.Substring(word, 0, 1);

            if (useUnicodeType)
            {
                char ch   = word[0];
                int  type = (int)System.Globalization.CharUnicodeInfo.GetUnicodeCategory(ch);
                if (type != (int)System.Globalization.UnicodeCategory.OtherLetter)
                {
                    // standard Chinese characters are of type "OTHER_LETTER"!!
                    first = type.ToString();
                }
            }
            string tag = tw.Tag();

            if (!c.Contains(tagL))
            {
                c[tagL] = new ClassicCounter <string>();
            }
            c[tagL].IncrementCount(first, weight);
            tc.IncrementCount(tagL, weight);
            seenFirst.Add(first);
            IntTaggedWord iW = new IntTaggedWord(word, IntTaggedWord.Any, wordIndex, tagIndex);

            seenCounter.IncrementCount(iW, weight);
            if (treesRead > indexToStartUnkCounting)
            {
                // start doing this once some way through trees;
                // treesRead is 1 based counting
                if (seenCounter.GetCount(iW) < 2)
                {
                    IntTaggedWord iT = new IntTaggedWord(IntTaggedWord.Any, tag, wordIndex, tagIndex);
                    unSeenCounter.IncrementCount(iT, weight);
                    unSeenCounter.IncrementCount(iTotal, weight);
                }
            }
        }
Code Example #28
        private static ICounter <string> GetConjunction(ICounter <string> original, string suffix)
        {
            ICounter <string> conjunction = new ClassicCounter <string>();

            foreach (KeyValuePair <string, double> e in original.EntrySet())
            {
                conjunction.IncrementCount(e.Key + suffix, e.Value);
            }
            return(conjunction);
        }
Code Example #29
        private static ICounter <string> AddSuffix(ICounter <string> features, string suffix)
        {
            ICounter <string> withSuffix = new ClassicCounter <string>();

            foreach (KeyValuePair <string, double> e in features.EntrySet())
            {
                withSuffix.IncrementCount(e.Key + suffix, e.Value);
            }
            return(withSuffix);
        }
Code Example #30
        /// <summary>
        /// If markovOrder is zero, we always transition back to the start state
        /// If markovOrder is negative, we assume that it is infinite
        /// </summary>
        public static TransducerGraph CreateGraphFromPaths(IList paths, int markovOrder)
        {
            ClassicCounter pathCounter = new ClassicCounter();

            foreach (object o in paths)
            {
                pathCounter.IncrementCount(o);
            }
            return(CreateGraphFromPaths(pathCounter, markovOrder));
        }