/// <summary>
        /// Scores the unlabeled mention pairs of <paramref name="document"/> with the
        /// classification, ranking and anaphoricity models, hands the scores to the clusterer,
        /// and applies every merge the clusterer returns to the document.
        /// </summary>
        /// <param name="document">Document to resolve; its clusters are updated through
        /// <c>CorefUtils.MergeCoreferenceClusters</c>.</param>
        public virtual void RunCoref(Document document)
        {
            IDictionary<Pair<int, int>, bool> candidatePairs = CorefUtils.GetUnlabeledMentionPairs(document);
            if (candidatePairs.Count == 0)
            {
                // No candidate pairs: nothing to score or merge.
                return;
            }

            Compressor<string> featureCompressor = new Compressor<string>();
            DocumentExamples extracted = extractor.Extract(0, document, candidatePairs, featureCompressor);

            ICounter<Pair<int, int>> classification = new ClassicCounter<Pair<int, int>>();
            ICounter<Pair<int, int>> ranking = new ClassicCounter<Pair<int, int>>();
            ICounter<int> anaphoricity = new ClassicCounter<int>();

            foreach (Example ex in extracted.examples)
            {
                CorefUtils.CheckForInterrupt();

                Pair<int, int> key = new Pair<int, int>(ex.mentionId1, ex.mentionId2);
                classification.IncrementCount(key, classificationModel.Predict(ex, extracted.mentionFeatures, featureCompressor));
                ranking.IncrementCount(key, rankingModel.Predict(ex, extracted.mentionFeatures, featureCompressor));

                // The anaphoricity score is computed once per second mention.
                if (anaphoricity.ContainsKey(ex.mentionId2))
                {
                    continue;
                }
                anaphoricity.IncrementCount(ex.mentionId2, anaphoricityModel.Predict(new Example(ex, false), extracted.mentionFeatures, featureCompressor));
            }

            // NOTE(review): both ToMap selectors are null -- this looks like lambdas dropped by an
            // automated Java-to-C# conversion. Confirm the intended key/value selectors.
            ClustererDataLoader.ClustererDoc clustererInput = new ClustererDataLoader.ClustererDoc(
                0,
                classification,
                ranking,
                anaphoricity,
                candidatePairs,
                null,
                document.predictedMentionsByID.Stream().Collect(Collectors.ToMap(null, null)));

            foreach (Pair<int, int> merge in clusterer.GetClusterMerges(clustererInput))
            {
                CorefUtils.MergeCoreferenceClusters(merge, document);
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Scores every example with <paramref name="model"/>, writes one line per example to
        /// <paramref name="writer"/>, and accumulates each score into the per-document counter
        /// held in <paramref name="scores"/>.
        /// </summary>
        /// <param name="examples">Pairs of an example and the mention features used to score it.</param>
        /// <param name="compressor">Feature compressor passed through to <c>model.Predict</c>.</param>
        /// <param name="model">Model that produces the score for each example.</param>
        /// <param name="writer">Receives one line per example:
        /// <c>docId mentionId1,mentionId2 score label</c>.</param>
        /// <param name="scores">Map from document id to that document's pair scores; a counter is
        /// created and stored here the first time a document id is seen.</param>
        public static void WriteScores(IList <Pair <Example, IDictionary <int, CompressedFeatureVector> > > examples, Compressor <string> compressor, PairwiseModel model, PrintWriter writer, IDictionary <int, ICounter <Pair <int, int> > > scores)
        {
            int i = 0;

            foreach (Pair <Example, IDictionary <int, CompressedFeatureVector> > pair in examples)
            {
                // Progress is logged on the 1st, 10001st, ... example; i is already incremented when formatted.
                if (i++ % 10000 == 0)
                {
                    // .NET composite format items replace the Java printf specifiers, which
                    // string.Format would otherwise emit verbatim.
                    Redwood.Log("scoref-train", string.Format(System.Globalization.CultureInfo.InvariantCulture, "On test example {0}/{1} = {2:F2}%", i, examples.Count, 100.0 * i / examples.Count));
                }
                Example example = pair.first;
                IDictionary <int, CompressedFeatureVector> mentionFeatures = pair.second;
                double p = model.Predict(example, mentionFeatures, compressor);
                writer.Println(example.docId + " " + example.mentionId1 + "," + example.mentionId2 + " " + p + " " + example.label);

                // The IDictionary indexer throws on a missing key, so look the counter up with
                // TryGetValue and create it on first use (a stored null is replaced as well).
                ICounter <Pair <int, int> > docScores;
                if (!scores.TryGetValue(example.docId, out docScores) || docScores == null)
                {
                    docScores             = new ClassicCounter <Pair <int, int> >();
                    scores[example.docId] = docScores;
                }
                docScores.IncrementCount(new Pair <int, int>(example.mentionId1, example.mentionId2), p);
            }
        }
Exemplo n.º 3
0
        /// <summary>
        /// Trains <paramref name="model"/> as a ranker. It reads the compressor and the extracted
        /// training documents from the files named by <c>StatisticalCorefTrainer</c>, then for each
        /// epoch shuffles the documents, groups each document's examples by <c>mentionId2</c>, and
        /// for every group calls <c>Learn</c> with the highest-scoring positive example and the
        /// highest-scoring negative example. The model is written out once all epochs are done.
        /// </summary>
        /// <param name="model">Model to train. When it is a <c>MaxMarginMentionRanker</c>, an extra
        /// new-link example is added to each group and negative scores are adjusted by the
        /// ranker's per-error-type costs before the maximum is chosen.</param>
        /// <exception cref="System.Exception"/>
        public static void TrainRanking(PairwiseModel model)
        {
            Redwood.Log("scoref-train", "Reading compression...");
            Compressor <string> compressor = IOUtils.ReadObjectFromFile(StatisticalCorefTrainer.compressorFile);

            Redwood.Log("scoref-train", "Reading train data...");
            IList <DocumentExamples> trainDocuments = IOUtils.ReadObjectFromFile(StatisticalCorefTrainer.extractedFeaturesFile);

            Redwood.Log("scoref-train", "Training...");
            for (int i = 0; i < model.GetNumEpochs(); i++)
            {
                Java.Util.Collections.Shuffle(trainDocuments);
                int j = 0;
                foreach (DocumentExamples doc in trainDocuments)
                {
                    j++;
                    Redwood.Log("scoref-train", "On epoch: " + i + " / " + model.GetNumEpochs() + ", document: " + j + " / " + trainDocuments.Count);

                    // Group the document's examples by their second mention.
                    IDictionary <int, IList <Example> > mentionToPotentialAntecedents = new Dictionary <int, IList <Example> >();
                    foreach (Example e in doc.examples)
                    {
                        int             mention = e.mentionId2;
                        IList <Example> potentialAntecedents;
                        // The Dictionary indexer throws KeyNotFoundException for an unseen key,
                        // so the get-or-create has to go through TryGetValue.
                        if (!mentionToPotentialAntecedents.TryGetValue(mention, out potentialAntecedents))
                        {
                            potentialAntecedents = new List <Example>();
                            mentionToPotentialAntecedents[mention] = potentialAntecedents;
                        }
                        potentialAntecedents.Add(e);
                    }
                    IList <IList <Example> > examples = new List <IList <Example> >(mentionToPotentialAntecedents.Values);
                    Java.Util.Collections.Shuffle(examples);
                    foreach (IList <Example> es in examples)
                    {
                        if (es.Count == 0)
                        {
                            continue;
                        }
                        if (model is MaxMarginMentionRanker)
                        {
                            MaxMarginMentionRanker ranker = (MaxMarginMentionRanker)model;

                            // True when no example in the group is labelled positive. This replaces
                            // a predicate-less AllMatch(null) call; the reconstruction follows from
                            // the assertions below (a positive is the new-link example exactly when
                            // noAntecedent holds).
                            bool noAntecedent = true;
                            foreach (Example candidate in es)
                            {
                                if (candidate.label == 1)
                                {
                                    noAntecedent = false;
                                    break;
                                }
                            }

                            // Add the new-link example built from the first example of the group;
                            // noAntecedent is passed as its second constructor argument.
                            es.Add(new Example(es[0], noAntecedent));
                            double  maxPositiveScore   = -double.MaxValue;
                            Example maxScoringPositive = null;
                            foreach (Example e_1 in es)
                            {
                                double score = model.Predict(e_1, doc.mentionFeatures, compressor);
                                if (e_1.label == 1)
                                {
                                    System.Diagnostics.Debug.Assert((!noAntecedent ^ e_1.IsNewLink()));
                                    if (score > maxPositiveScore)
                                    {
                                        maxPositiveScore   = score;
                                        maxScoringPositive = e_1;
                                    }
                                }
                            }
                            System.Diagnostics.Debug.Assert((maxScoringPositive != null));
                            double  maxNegativeScore   = -double.MaxValue;
                            Example maxScoringNegative = null;
                            MaxMarginMentionRanker.ErrorType maxScoringEt = null;
                            foreach (Example e_2 in es)
                            {
                                double score = model.Predict(e_2, doc.mentionFeatures, compressor);
                                if (e_2.label != 1)
                                {
                                    System.Diagnostics.Debug.Assert((!(noAntecedent && e_2.IsNewLink())));

                                    // Pick the error type this negative would represent.
                                    // NOTE(review): Wl/Fl/Fn/FnPron presumably stand for wrong link,
                                    // false link, false new and false new (pronoun) -- confirm
                                    // against MaxMarginMentionRanker.ErrorType.
                                    MaxMarginMentionRanker.ErrorType et = MaxMarginMentionRanker.ErrorType.Wl;
                                    if (noAntecedent && !e_2.IsNewLink())
                                    {
                                        et = MaxMarginMentionRanker.ErrorType.Fl;
                                    }
                                    else
                                    {
                                        if (!noAntecedent && e_2.IsNewLink())
                                        {
                                            if (e_2.mentionType2 == Dictionaries.MentionType.Pronominal)
                                            {
                                                et = MaxMarginMentionRanker.ErrorType.FnPron;
                                            }
                                            else
                                            {
                                                et = MaxMarginMentionRanker.ErrorType.Fn;
                                            }
                                        }
                                    }

                                    // Cost-adjust the negative's score before comparing.
                                    if (ranker.multiplicativeCost)
                                    {
                                        score = ranker.costs[et.id] * (1 - maxPositiveScore + score);
                                    }
                                    else
                                    {
                                        score += ranker.costs[et.id];
                                    }
                                    if (score > maxNegativeScore)
                                    {
                                        maxNegativeScore   = score;
                                        maxScoringNegative = e_2;
                                        maxScoringEt       = et;
                                    }
                                }
                            }
                            System.Diagnostics.Debug.Assert((maxScoringNegative != null));
                            ranker.Learn(maxScoringPositive, maxScoringNegative, doc.mentionFeatures, compressor, maxScoringEt);
                        }
                        else
                        {
                            // Generic model: a single pass finds the best positive and best negative.
                            double  maxPositiveScore   = -double.MaxValue;
                            double  maxNegativeScore   = -double.MaxValue;
                            Example maxScoringPositive = null;
                            Example maxScoringNegative = null;
                            foreach (Example e_1 in es)
                            {
                                double score = model.Predict(e_1, doc.mentionFeatures, compressor);
                                if (e_1.label == 1)
                                {
                                    if (score > maxPositiveScore)
                                    {
                                        maxPositiveScore   = score;
                                        maxScoringPositive = e_1;
                                    }
                                }
                                else
                                {
                                    if (score > maxNegativeScore)
                                    {
                                        maxNegativeScore   = score;
                                        maxScoringNegative = e_1;
                                    }
                                }
                            }
                            // NOTE(review): either argument is null when the group has no positive
                            // or no negative example -- confirm Learn tolerates that.
                            model.Learn(maxScoringPositive, maxScoringNegative, doc.mentionFeatures, compressor, 1);
                        }
                    }
                }
            }
            Redwood.Log("scoref-train", "Writing models...");
            model.WriteModel();
        }
        /// <summary>
        /// Builds candidate mention pairs from the heuristic filter, scores each with the
        /// pairwise classifier, then walks the scored pairs and, for the first pair encountered
        /// for each second mention, merges the two mentions' clusters when the score exceeds
        /// the threshold selected by whether each mention is pronominal.
        /// </summary>
        /// <param name="document">Document to resolve; clusters are merged through
        /// <c>CorefUtils.MergeCoreferenceClusters</c>.</param>
        public virtual void RunCoref(Document document)
        {
            Compressor<string> featureCompressor = new Compressor<string>();

            // Allow interrupting
            if (Thread.Interrupted())
            {
                throw new RuntimeInterruptedException();
            }

            // Candidate pairs are (filtered mention, key mention) for every entry of the filter result.
            IDictionary<Pair<int, int>, bool> candidatePairs = new Dictionary<Pair<int, int>, bool>();
            foreach (KeyValuePair<int, IList<int>> entry in CorefUtils.HeuristicFilter(CorefUtils.GetSortedMentions(document), maxMentionDistance, maxMentionDistanceWithStringMatch))
            {
                foreach (int firstMentionId in entry.Value)
                {
                    candidatePairs[new Pair<int, int>(firstMentionId, entry.Key)] = true;
                }
            }

            DocumentExamples extracted = extractor.Extract(0, document, candidatePairs, featureCompressor);
            ICounter<Pair<int, int>> scores = new ClassicCounter<Pair<int, int>>();
            foreach (Example ex in extracted.examples)
            {
                // Allow interrupting
                if (Thread.Interrupted())
                {
                    throw new RuntimeInterruptedException();
                }
                Pair<int, int> key = new Pair<int, int>(ex.mentionId1, ex.mentionId2);
                scores.IncrementCount(key, classifier.Predict(ex, extracted.mentionFeatures, featureCompressor));
            }

            IList<Pair<int, int>> orderedPairs = new List<Pair<int, int>>(scores.KeySet());
            // NOTE(review): the comparator argument is null -- this looks like a lambda dropped by
            // an automated Java-to-C# conversion. Presumably pairs should be ordered by score;
            // confirm the intended ordering.
            orderedPairs.Sort(null);

            ICollection<int> handledSecondMentions = new HashSet<int>();
            foreach (Pair<int, int> candidate in orderedPairs)
            {
                // Only the first pair encountered for a given second mention is considered.
                if (handledSecondMentions.Contains(candidate.second))
                {
                    continue;
                }
                // Allow interrupting
                if (Thread.Interrupted())
                {
                    throw new RuntimeInterruptedException();
                }
                handledSecondMentions.Add(candidate.second);

                Dictionaries.MentionType firstType = document.predictedMentionsByID[candidate.first].mentionType;
                Dictionaries.MentionType secondType = document.predictedMentionsByID[candidate.second].mentionType;
                Pair<bool, bool> thresholdKey = new Pair<bool, bool>(
                    firstType == Dictionaries.MentionType.Pronominal,
                    secondType == Dictionaries.MentionType.Pronominal);

                if (scores.GetCount(candidate) > thresholds[thresholdKey])
                {
                    CorefUtils.MergeCoreferenceClusters(candidate, document);
                }
            }
        }