Code example #1
        /// <exception cref="System.Exception"/>
        public static void TrainClassification(PairwiseModel model, bool anaphoricityModel)
        {
            int numTrainingExamples = model.GetNumTrainingExamples();

            Redwood.Log("scoref-train", "Reading compression...");
            Compressor <string> compressor = IOUtils.ReadObjectFromFile(StatisticalCorefTrainer.compressorFile);

            Redwood.Log("scoref-train", "Reading train data...");
            IList <DocumentExamples> trainDocuments = IOUtils.ReadObjectFromFile(StatisticalCorefTrainer.extractedFeaturesFile);

            Redwood.Log("scoref-train", "Building train set...");
            IList <Pair <Example, IDictionary <int, CompressedFeatureVector> > > allExamples = anaphoricityModel ? GetAnaphoricityExamples(trainDocuments) : GetExamples(trainDocuments);

            Redwood.Log("scoref-train", "Training...");
            Random random       = new Random(0);
            int    i            = 0;
            bool   stopTraining = false;

            while (!stopTraining)
            {
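                // one pass over the shuffled pairwise examples; training stops once numTrainingExamples have been consumed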
                Java.Util.Collections.Shuffle(allExamples, random);
                foreach (Pair <Example, IDictionary <int, CompressedFeatureVector> > pair in allExamples)
                {
                    if (i++ > numTrainingExamples)
                    {
                        stopTraining = true;
                        break;
                    }
                    if (i % 10000 == 0)
                    {
                        Redwood.Log("scoref-train", string.Format("On train example %d/%d = %.2f%%", i, numTrainingExamples, 100.0 * i / numTrainingExamples));
                    }
                    model.Learn(pair.first, pair.second, compressor);
                }
            }
            Redwood.Log("scoref-train", "Writing models...");
            model.WriteModel();
        }
Code example #2
        public virtual ICounter <string> GetFeatures(Example example, IDictionary <int, CompressedFeatureVector> mentionFeatures, Compressor <string> compressor)
        {
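            // the anaphor's features (mentionId2) are always uncompressed; pairwise and antecedent
            // (mentionId1) features exist only for genuine pairs, while new-link examples just get a bias feature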
            ICounter <string> features     = new ClassicCounter <string>();
            ICounter <string> pairFeatures = new ClassicCounter <string>();
            ICounter <string> features1    = new ClassicCounter <string>();
            ICounter <string> features2    = compressor.Uncompress(mentionFeatures[example.mentionId2]);

            if (!example.IsNewLink())
            {
                System.Diagnostics.Debug.Assert((!anaphoricityClassifier));
                pairFeatures = compressor.Uncompress(example.pairwiseFeatures);
                features1    = compressor.Uncompress(mentionFeatures[example.mentionId1]);
            }
            else
            {
                features2.IncrementCount("bias");
            }
            if (!disallowedPrefixes.IsEmpty())
            {
                features1    = FilterOut(features1, disallowedPrefixes);
                features2    = FilterOut(features2, disallowedPrefixes);
                pairFeatures = FilterOut(pairFeatures, disallowedPrefixes);
            }
            IList <string> ids1 = example.IsNewLink() ? new List <string>() : Identifiers(features1, example.mentionType1);
            IList <string> ids2 = Identifiers(features2, example.mentionType2);

            features.AddAll(pairFeatures);
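            // conjoin the pairwise and single-mention features with mention-type identifiers,
            // as selected by the configured pairConjunctions / singleConjunctions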
            foreach (string id1 in ids1)
            {
                foreach (string id2 in ids2)
                {
                    if (pairConjunctions.Contains(MetaFeatureExtractor.PairConjunction.First))
                    {
                        features.AddAll(GetConjunction(pairFeatures, "_m1=" + id1));
                    }
                    if (pairConjunctions.Contains(MetaFeatureExtractor.PairConjunction.Last))
                    {
                        features.AddAll(GetConjunction(pairFeatures, "_m2=" + id2));
                    }
                    if (pairConjunctions.Contains(MetaFeatureExtractor.PairConjunction.Both))
                    {
                        features.AddAll(GetConjunction(pairFeatures, "_ms=" + id1 + "_" + id2));
                    }
                    if (singleConjunctions.Contains(MetaFeatureExtractor.SingleConjunction.Index))
                    {
                        features.AddAll(GetConjunction(features1, "_1"));
                        features.AddAll(GetConjunction(features2, "_2"));
                    }
                    if (singleConjunctions.Contains(MetaFeatureExtractor.SingleConjunction.IndexCurrent))
                    {
                        features.AddAll(GetConjunction(features1, "_1" + "_m=" + id1));
                        features.AddAll(GetConjunction(features2, "_2" + "_m=" + id2));
                    }
                    if (singleConjunctions.Contains(MetaFeatureExtractor.SingleConjunction.IndexLast))
                    {
                        features.AddAll(GetConjunction(features1, "_1" + "_m2=" + id2));
                        features.AddAll(GetConjunction(features2, "_2" + "_m2=" + id2));
                    }
                    if (singleConjunctions.Contains(MetaFeatureExtractor.SingleConjunction.IndexOther))
                    {
                        features.AddAll(GetConjunction(features1, "_1" + "_m=" + id2));
                        features.AddAll(GetConjunction(features2, "_2" + "_m=" + id1));
                    }
                    if (singleConjunctions.Contains(MetaFeatureExtractor.SingleConjunction.IndexBoth))
                    {
                        features.AddAll(GetConjunction(features1, "_1" + "_ms=" + id1 + "_" + id2));
                        features.AddAll(GetConjunction(features2, "_2" + "_ms=" + id1 + "_" + id2));
                    }
                }
            }
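            // new-link examples use only the anaphor's features, and every feature is suffixed with _NEW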
            if (example.IsNewLink())
            {
                features.AddAll(features2);
                features.AddAll(GetConjunction(features2, "_m=" + ids2[0]));
                ICounter <string> newFeatures = new ClassicCounter <string>();
                foreach (KeyValuePair <string, double> e in features.EntrySet())
                {
                    newFeatures.IncrementCount(e.Key + "_NEW", e.Value);
                }
                features = newFeatures;
            }
            return(features);
        }
Code example #3
        public virtual DocumentExamples Extract(int id, Document document, IDictionary <Pair <int, int>, bool> labeledPairs, Compressor <string> compressor)
        {
            IList <Mention> mentionsList = CorefUtils.GetSortedMentions(document);
            IDictionary <int, IList <Mention> > mentionsByHeadIndex = new Dictionary <int, IList <Mention> >();
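            // group the document's mentions by head index; the grouping is passed to single-mention feature extraction below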

            foreach (Mention m in mentionsList)
            {
                // use TryGetValue for the get-or-create pattern: the C# dictionary indexer throws when the key is missing
                IList <Mention> withIndex;
                if (!mentionsByHeadIndex.TryGetValue(m.headIndex, out withIndex))
                {
                    withIndex = new List <Mention>();
                    mentionsByHeadIndex[m.headIndex] = withIndex;
                }
                withIndex.Add(m);
            }
            IDictionary <int, Mention> mentions          = document.predictedMentionsByID;
            IList <Example>            examples          = new List <Example>();
            ICollection <int>          mentionsToExtract = new HashSet <int>();
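            // build one pairwise Example per labeled mention pair and remember both mentions
            // so that their single-mention features are extracted afterwards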

            foreach (KeyValuePair <Pair <int, int>, bool> pair in labeledPairs)
            {
                Mention m1 = mentions[pair.Key.first];
                Mention m2 = mentions[pair.Key.second];
                mentionsToExtract.Add(m1.mentionID);
                mentionsToExtract.Add(m2.mentionID);
                CompressedFeatureVector features = compressor.Compress(GetFeatures(document, m1, m2));
                examples.Add(new Example(id, m1, m2, pair.Value ? 1.0 : 0.0, features));
            }
            IDictionary <int, CompressedFeatureVector> mentionFeatures = new Dictionary <int, CompressedFeatureVector>();

            foreach (int mentionID in mentionsToExtract)
            {
                mentionFeatures[mentionID] = compressor.Compress(GetFeatures(document, document.predictedMentionsByID[mentionID], mentionsByHeadIndex));
            }
            return(new DocumentExamples(id, examples, mentionFeatures));
        }
Code example #4
 public FeatureExtractor(Properties props, Dictionaries dictionaries, Compressor <string> compressor, string wordCountsPath)
     : this(props, dictionaries, compressor, LoadVocabulary(wordCountsPath))
 {
 }
Code example #5
 public FeatureExtractor(Properties props, Dictionaries dictionaries, Compressor <string> compressor)
     : this(props, dictionaries, compressor, StatisticalCorefTrainer.wordCountsFile)
 {
 }
Code example #6
        public static void WriteScores(IList <Pair <Example, IDictionary <int, CompressedFeatureVector> > > examples, Compressor <string> compressor, PairwiseModel model, PrintWriter writer, IDictionary <int, ICounter <Pair <int, int> > > scores)
        {
            int i = 0;
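            // score every example, write one "docId m1,m2 score label" line, and collect per-document score counters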

            foreach (Pair <Example, IDictionary <int, CompressedFeatureVector> > pair in examples)
            {
                if (i++ % 10000 == 0)
                {
                    Redwood.Log("scoref-train", string.Format("On test example %d/%d = %.2f%%", i, examples.Count, 100.0 * i / examples.Count));
                }
                Example example = pair.first;
                IDictionary <int, CompressedFeatureVector> mentionFeatures = pair.second;
                double p = model.Predict(example, mentionFeatures, compressor);
                writer.Println(example.docId + " " + example.mentionId1 + "," + example.mentionId2 + " " + p + " " + example.label);
                // use TryGetValue since the indexer throws when the document has no score counter yet
                ICounter <Pair <int, int> > docScores;
                if (!scores.TryGetValue(example.docId, out docScores))
                {
                    docScores             = new ClassicCounter <Pair <int, int> >();
                    scores[example.docId] = docScores;
                }
                docScores.IncrementCount(new Pair <int, int>(example.mentionId1, example.mentionId2), p);
            }
        }
Code example #7
        /// <exception cref="System.Exception"/>
        public static void TrainRanking(PairwiseModel model)
        {
            Redwood.Log("scoref-train", "Reading compression...");
            Compressor <string> compressor = IOUtils.ReadObjectFromFile(StatisticalCorefTrainer.compressorFile);

            Redwood.Log("scoref-train", "Reading train data...");
            IList <DocumentExamples> trainDocuments = IOUtils.ReadObjectFromFile(StatisticalCorefTrainer.extractedFeaturesFile);

            Redwood.Log("scoref-train", "Training...");
            for (int i = 0; i < model.GetNumEpochs(); i++)
            {
                Java.Util.Collections.Shuffle(trainDocuments);
                int j = 0;
                foreach (DocumentExamples doc in trainDocuments)
                {
                    j++;
                    Redwood.Log("scoref-train", "On epoch: " + i + " / " + model.GetNumEpochs() + ", document: " + j + " / " + trainDocuments.Count);
                    IDictionary <int, IList <Example> > mentionToPotentialAntecedents = new Dictionary <int, IList <Example> >();
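                    // group the document's examples by anaphor (mentionId2) so each mention's candidate antecedents are ranked together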
                    foreach (Example e in doc.examples)
                    {
                        int             mention = e.mentionId2;
                        IList <Example> potentialAntecedents;
                        if (!mentionToPotentialAntecedents.TryGetValue(mention, out potentialAntecedents))
                        {
                            potentialAntecedents = new List <Example>();
                            mentionToPotentialAntecedents[mention] = potentialAntecedents;
                        }
                        potentialAntecedents.Add(e);
                    }
                    IList <IList <Example> > examples = new List <IList <Example> >(mentionToPotentialAntecedents.Values);
                    Java.Util.Collections.Shuffle(examples);
                    foreach (IList <Example> es in examples)
                    {
                        if (es.Count == 0)
                        {
                            continue;
                        }
                        if (model is MaxMarginMentionRanker)
                        {
                            MaxMarginMentionRanker ranker = (MaxMarginMentionRanker)model;
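                            // max-margin update: compare the highest-scoring correct antecedent against the
                            // highest-scoring cost-augmented incorrect decision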
                            // the anaphor has no antecedent when none of its candidate examples is labeled correct
                            bool noAntecedent = es.All(ex => ex.label != 1);
                            es.Add(new Example(es[0], noAntecedent));
                            double  maxPositiveScore   = -double.MaxValue;
                            Example maxScoringPositive = null;
                            foreach (Example e_1 in es)
                            {
                                double score = model.Predict(e_1, doc.mentionFeatures, compressor);
                                if (e_1.label == 1)
                                {
                                    System.Diagnostics.Debug.Assert((!noAntecedent ^ e_1.IsNewLink()));
                                    if (score > maxPositiveScore)
                                    {
                                        maxPositiveScore   = score;
                                        maxScoringPositive = e_1;
                                    }
                                }
                            }
                            System.Diagnostics.Debug.Assert((maxScoringPositive != null));
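                            // find the highest-scoring incorrect decision, augmenting its score with the cost of the
                            // error type it represents (wrong link, false link, or false new)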
                            double  maxNegativeScore   = -double.MaxValue;
                            Example maxScoringNegative = null;
                            MaxMarginMentionRanker.ErrorType maxScoringEt = null;
                            foreach (Example e_2 in es)
                            {
                                double score = model.Predict(e_2, doc.mentionFeatures, compressor);
                                if (e_2.label != 1)
                                {
                                    System.Diagnostics.Debug.Assert((!(noAntecedent && e_2.IsNewLink())));
                                    MaxMarginMentionRanker.ErrorType et = MaxMarginMentionRanker.ErrorType.Wl;
                                    if (noAntecedent && !e_2.IsNewLink())
                                    {
                                        et = MaxMarginMentionRanker.ErrorType.Fl;
                                    }
                                    else
                                    {
                                        if (!noAntecedent && e_2.IsNewLink())
                                        {
                                            if (e_2.mentionType2 == Dictionaries.MentionType.Pronominal)
                                            {
                                                et = MaxMarginMentionRanker.ErrorType.FnPron;
                                            }
                                            else
                                            {
                                                et = MaxMarginMentionRanker.ErrorType.Fn;
                                            }
                                        }
                                    }
                                    if (ranker.multiplicativeCost)
                                    {
                                        score = ranker.costs[et.id] * (1 - maxPositiveScore + score);
                                    }
                                    else
                                    {
                                        score += ranker.costs[et.id];
                                    }
                                    if (score > maxNegativeScore)
                                    {
                                        maxNegativeScore   = score;
                                        maxScoringNegative = e_2;
                                        maxScoringEt       = et;
                                    }
                                }
                            }
                            System.Diagnostics.Debug.Assert((maxScoringNegative != null));
                            ranker.Learn(maxScoringPositive, maxScoringNegative, doc.mentionFeatures, compressor, maxScoringEt);
                        }
                        else
                        {
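                            // plain ranking model: compare the best-scoring correct antecedent with the best-scoring incorrect one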
                            double  maxPositiveScore   = -double.MaxValue;
                            double  maxNegativeScore   = -double.MaxValue;
                            Example maxScoringPositive = null;
                            Example maxScoringNegative = null;
                            foreach (Example e_1 in es)
                            {
                                double score = model.Predict(e_1, doc.mentionFeatures, compressor);
                                if (e_1.label == 1)
                                {
                                    if (score > maxPositiveScore)
                                    {
                                        maxPositiveScore   = score;
                                        maxScoringPositive = e_1;
                                    }
                                }
                                else
                                {
                                    if (score > maxNegativeScore)
                                    {
                                        maxNegativeScore   = score;
                                        maxScoringNegative = e_1;
                                    }
                                }
                            }
                            model.Learn(maxScoringPositive, maxScoringNegative, doc.mentionFeatures, compressor, 1);
                        }
                    }
                }
            }
            Redwood.Log("scoref-train", "Writing models...");
            model.WriteModel();
        }
Code example #8
        public virtual double Predict(Example example, IDictionary <int, CompressedFeatureVector> mentionFeatures, Compressor <string> compressor)
        {
            ICounter <string> features = meta.GetFeatures(example, mentionFeatures, compressor);

            return(classifier.Label(features));
        }
Code example #9
        public virtual void Learn(Example correct, Example incorrect, IDictionary <int, CompressedFeatureVector> mentionFeatures, Compressor <string> compressor, double weight)
        {
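            // pairwise update; if only one side of the pair is present, apply a singleton update scaled by singletonRatio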
            ICounter <string> cFeatures = null;
            ICounter <string> iFeatures = null;

            if (correct != null)
            {
                cFeatures = meta.GetFeatures(correct, mentionFeatures, compressor);
            }
            if (incorrect != null)
            {
                iFeatures = meta.GetFeatures(incorrect, mentionFeatures, compressor);
            }
            if (correct == null || incorrect == null)
            {
                if (singletonRatio != 0)
                {
                    if (correct != null)
                    {
                        classifier.Learn(cFeatures, 1.0, weight * singletonRatio);
                    }
                    if (incorrect != null)
                    {
                        classifier.Learn(iFeatures, -1.0, weight * singletonRatio);
                    }
                }
            }
            else
            {
                classifier.Learn(cFeatures, 1.0, weight);
                classifier.Learn(iFeatures, -1.0, weight);
            }
        }
Code example #10
        public virtual void Learn(Example example, IDictionary <int, CompressedFeatureVector> mentionFeatures, Compressor <string> compressor, double weight)
        {
            ICounter <string> features = meta.GetFeatures(example, mentionFeatures, compressor);

            classifier.Learn(features, example.label == 1.0 ? 1.0 : -1.0, weight);
        }
Code example #11
        public virtual void Learn(Example correct, Example incorrect, IDictionary <int, CompressedFeatureVector> mentionFeatures, Compressor <string> compressor, MaxMarginMentionRanker.ErrorType errorType)
        {
            ICounter <string> cFeatures = meta.GetFeatures(correct, mentionFeatures, compressor);
            ICounter <string> iFeatures = meta.GetFeatures(incorrect, mentionFeatures, compressor);
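            // the update uses the difference vector: features of the incorrect decision minus those of the correct one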

            foreach (KeyValuePair <string, double> e in cFeatures.EntrySet())
            {
                iFeatures.DecrementCount(e.Key, e.Value);
            }
            if (multiplicativeCost)
            {
                classifier.Learn(iFeatures, 1.0, costs[errorType.id], loss);
            }
            else
            {
                classifier.Learn(iFeatures, 1.0, 1.0, losses[errorType.id]);
            }
        }