Exemplo n.º 1
0
        /// <summary>
        /// Scores one candidate link with the underlying classifier.
        /// </summary>
        /// <param name="example">The candidate link to score.</param>
        /// <param name="mentionFeatures">Compressed per-mention feature vectors, indexed by mention id.</param>
        /// <param name="compressor">Compressor used to decode the compressed feature vectors.</param>
        /// <returns>The value <c>classifier.Label</c> produces for the example's meta-features.</returns>
        public virtual double Predict(Example example, IDictionary <int, CompressedFeatureVector> mentionFeatures, Compressor <string> compressor)
        {
            // Expand the example into its meta-feature counter and hand it straight to the classifier.
            return classifier.Label(meta.GetFeatures(example, mentionFeatures, compressor));
        }
        /// <summary>
        /// Trains <paramref name="model"/> as a mention ranker on the extracted training documents and
        /// then writes the model out.
        /// </summary>
        /// <remarks>
        /// Reads the compressor from <c>StatisticalCorefTrainer.compressorFile</c> and the training
        /// documents from <c>StatisticalCorefTrainer.extractedFeaturesFile</c>. Each epoch shuffles the
        /// documents. Inside a document the examples are grouped by <c>mentionId2</c>, and the groups are
        /// shuffled.
        /// <para>
        /// A <see cref="MaxMarginMentionRanker"/> gets one extra candidate per group,
        /// <c>new Example(es[0], noAntecedent)</c>; judging from the assertions below, this is the
        /// "new link" candidate. The ranker is then updated with its best-scoring positive against its
        /// highest cost-augmented negative.
        /// </para>
        /// <para>
        /// Any other model is updated with its best-scoring positive against its best-scoring negative.
        /// Groups that lack either one are skipped.
        /// </para>
        /// </remarks>
        /// <param name="model">Model to train; <c>GetNumEpochs()</c> gives the number of passes.</param>
        /// <exception cref="System.Exception"/>
        public static void TrainRanking(PairwiseModel model)
        {
            Redwood.Log("scoref-train", "Reading compression...");
            Compressor <string> compressor = IOUtils.ReadObjectFromFile(StatisticalCorefTrainer.compressorFile);

            Redwood.Log("scoref-train", "Reading train data...");
            IList <DocumentExamples> trainDocuments = IOUtils.ReadObjectFromFile(StatisticalCorefTrainer.extractedFeaturesFile);

            Redwood.Log("scoref-train", "Training...");
            for (int i = 0; i < model.GetNumEpochs(); i++)
            {
                Java.Util.Collections.Shuffle(trainDocuments);
                int j = 0;
                foreach (DocumentExamples doc in trainDocuments)
                {
                    j++;
                    Redwood.Log("scoref-train", "On epoch: " + i + " / " + model.GetNumEpochs() + ", document: " + j + " / " + trainDocuments.Count);

                    // Group candidate links by anaphor (mentionId2) so each mention's candidates are ranked together.
                    IDictionary <int, IList <Example> > mentionToPotentialAntecedents = new Dictionary <int, IList <Example> >();
                    foreach (Example e in doc.examples)
                    {
                        int             mention = e.mentionId2;
                        IList <Example> potentialAntecedents;
                        // The IDictionary indexer throws KeyNotFoundException for a missing key (it does not
                        // return null like Java's Map.get), so probe with TryGetValue instead.
                        if (!mentionToPotentialAntecedents.TryGetValue(mention, out potentialAntecedents))
                        {
                            potentialAntecedents = new List <Example>();
                            mentionToPotentialAntecedents[mention] = potentialAntecedents;
                        }
                        potentialAntecedents.Add(e);
                    }
                    IList <IList <Example> > examples = new List <IList <Example> >(mentionToPotentialAntecedents.Values);
                    Java.Util.Collections.Shuffle(examples);
                    foreach (IList <Example> es in examples)
                    {
                        if (es.Count == 0)
                        {
                            continue;
                        }
                        if (model is MaxMarginMentionRanker)
                        {
                            MaxMarginMentionRanker ranker = (MaxMarginMentionRanker)model;
                            // The anaphor has no true antecedent when none of its candidates is labelled positive.
                            // (The translated AllMatch call had lost this predicate.)
                            bool noAntecedent = true;
                            foreach (Example candidate in es)
                            {
                                if (candidate.label == 1)
                                {
                                    noAntecedent = false;
                                    break;
                                }
                            }
                            es.Add(new Example(es[0], noAntecedent));

                            // Pass 1: highest-scoring positive candidate.
                            double  maxPositiveScore   = -double.MaxValue;
                            Example maxScoringPositive = null;
                            foreach (Example e_1 in es)
                            {
                                double score = model.Predict(e_1, doc.mentionFeatures, compressor);
                                if (e_1.label == 1)
                                {
                                    System.Diagnostics.Debug.Assert((!noAntecedent ^ e_1.IsNewLink()));
                                    if (score > maxPositiveScore)
                                    {
                                        maxPositiveScore   = score;
                                        maxScoringPositive = e_1;
                                    }
                                }
                            }
                            System.Diagnostics.Debug.Assert((maxScoringPositive != null));

                            // Pass 2: highest cost-augmented negative candidate, remembering its error type.
                            double  maxNegativeScore   = -double.MaxValue;
                            Example maxScoringNegative = null;
                            MaxMarginMentionRanker.ErrorType maxScoringEt = null;
                            foreach (Example e_2 in es)
                            {
                                double score = model.Predict(e_2, doc.mentionFeatures, compressor);
                                if (e_2.label != 1)
                                {
                                    System.Diagnostics.Debug.Assert((!(noAntecedent && e_2.IsNewLink())));
                                    MaxMarginMentionRanker.ErrorType et = MaxMarginMentionRanker.ErrorType.Wl;
                                    if (noAntecedent && !e_2.IsNewLink())
                                    {
                                        et = MaxMarginMentionRanker.ErrorType.Fl;
                                    }
                                    else
                                    {
                                        if (!noAntecedent && e_2.IsNewLink())
                                        {
                                            if (e_2.mentionType2 == Dictionaries.MentionType.Pronominal)
                                            {
                                                et = MaxMarginMentionRanker.ErrorType.FnPron;
                                            }
                                            else
                                            {
                                                et = MaxMarginMentionRanker.ErrorType.Fn;
                                            }
                                        }
                                    }
                                    if (ranker.multiplicativeCost)
                                    {
                                        score = ranker.costs[et.id] * (1 - maxPositiveScore + score);
                                    }
                                    else
                                    {
                                        score += ranker.costs[et.id];
                                    }
                                    if (score > maxNegativeScore)
                                    {
                                        maxNegativeScore   = score;
                                        maxScoringNegative = e_2;
                                        maxScoringEt       = et;
                                    }
                                }
                            }
                            System.Diagnostics.Debug.Assert((maxScoringNegative != null));
                            ranker.Learn(maxScoringPositive, maxScoringNegative, doc.mentionFeatures, compressor, maxScoringEt);
                        }
                        else
                        {
                            double  maxPositiveScore   = -double.MaxValue;
                            double  maxNegativeScore   = -double.MaxValue;
                            Example maxScoringPositive = null;
                            Example maxScoringNegative = null;
                            foreach (Example e_1 in es)
                            {
                                double score = model.Predict(e_1, doc.mentionFeatures, compressor);
                                if (e_1.label == 1)
                                {
                                    if (score > maxPositiveScore)
                                    {
                                        maxPositiveScore   = score;
                                        maxScoringPositive = e_1;
                                    }
                                }
                                else
                                {
                                    if (score > maxNegativeScore)
                                    {
                                        maxNegativeScore   = score;
                                        maxScoringNegative = e_1;
                                    }
                                }
                            }
                            // A ranking update needs both a positive and a negative candidate; without this
                            // guard Learn would be handed a null example for single-class groups.
                            if (maxScoringPositive == null || maxScoringNegative == null)
                            {
                                continue;
                            }
                            model.Learn(maxScoringPositive, maxScoringNegative, doc.mentionFeatures, compressor, 1);
                        }
                    }
                }
            }
            Redwood.Log("scoref-train", "Writing models...");
            model.WriteModel();
        }
Exemplo n.º 3
0
        /// <summary>
        /// Updates the classifier with a single example.
        /// </summary>
        /// <param name="example">The example to learn from; a label of exactly 1.0 is the positive class.</param>
        /// <param name="mentionFeatures">Compressed per-mention feature vectors, indexed by mention id.</param>
        /// <param name="compressor">Compressor used to decode the compressed feature vectors.</param>
        /// <param name="weight">Passed through unchanged to <c>classifier.Learn</c>.</param>
        public virtual void Learn(Example example, IDictionary <int, CompressedFeatureVector> mentionFeatures, Compressor <string> compressor, double weight)
        {
            ICounter <string> metaFeatures = meta.GetFeatures(example, mentionFeatures, compressor);

            // Map the example label onto a +1 / -1 training target.
            double target = -1.0;
            if (example.label == 1.0)
            {
                target = 1.0;
            }
            classifier.Learn(metaFeatures, target, weight);
        }
        /// <summary>
        /// Builds the meta-feature counter for one candidate link.
        /// </summary>
        /// <remarks>
        /// Feature sources:
        /// <list type="bullet">
        /// <item>The mention features of <c>mentionId2</c> are always decoded.</item>
        /// <item>For a regular link, the pairwise features and the mention features of <c>mentionId1</c>
        /// are decoded too.</item>
        /// <item>A new link instead gets a <c>"bias"</c> count on the <c>mentionId2</c> features.</item>
        /// </list>
        /// Assembly: when <c>disallowedPrefixes</c> is non-empty, every source is passed through
        /// <c>FilterOut</c>. The pairwise features are then copied into the result. For each
        /// (<c>mentionId1</c> id, <c>mentionId2</c> id) pair from <c>Identifiers</c>, the conjunctions
        /// enabled in <c>pairConjunctions</c> and <c>singleConjunctions</c> are added. For a new link, the
        /// <c>mentionId2</c> features (plain and conjoined with its first identifier) are added, and every
        /// key is suffixed with <c>"_NEW"</c>.
        /// </remarks>
        public virtual ICounter <string> GetFeatures(Example example, IDictionary <int, CompressedFeatureVector> mentionFeatures, Compressor <string> compressor)
        {
            ICounter <string> anaphorFeats = compressor.Uncompress(mentionFeatures[example.mentionId2]);
            bool              isNewLink    = example.IsNewLink();

            ICounter <string> linkFeats;
            ICounter <string> antecedentFeats;
            if (isNewLink)
            {
                // No antecedent: nothing pairwise and nothing for mention 1, just a bias on mention 2.
                linkFeats       = new ClassicCounter <string>();
                antecedentFeats = new ClassicCounter <string>();
                anaphorFeats.IncrementCount("bias");
            }
            else
            {
                System.Diagnostics.Debug.Assert(!anaphoricityClassifier);
                linkFeats       = compressor.Uncompress(example.pairwiseFeatures);
                antecedentFeats = compressor.Uncompress(mentionFeatures[example.mentionId1]);
            }

            if (!disallowedPrefixes.IsEmpty())
            {
                antecedentFeats = FilterOut(antecedentFeats, disallowedPrefixes);
                anaphorFeats    = FilterOut(anaphorFeats, disallowedPrefixes);
                linkFeats       = FilterOut(linkFeats, disallowedPrefixes);
            }

            IList <string> antecedentIds = isNewLink ? new List <string>() : Identifiers(antecedentFeats, example.mentionType1);
            IList <string> anaphorIds    = Identifiers(anaphorFeats, example.mentionType2);

            ICounter <string> result = new ClassicCounter <string>();
            result.AddAll(linkFeats);

            // The additions below keep the original order, since counts accumulate into the same keys.
            foreach (string first in antecedentIds)
            {
                foreach (string second in anaphorIds)
                {
                    string pairTag = first + "_" + second;

                    if (pairConjunctions.Contains(MetaFeatureExtractor.PairConjunction.First))
                    {
                        result.AddAll(GetConjunction(linkFeats, "_m1=" + first));
                    }
                    if (pairConjunctions.Contains(MetaFeatureExtractor.PairConjunction.Last))
                    {
                        result.AddAll(GetConjunction(linkFeats, "_m2=" + second));
                    }
                    if (pairConjunctions.Contains(MetaFeatureExtractor.PairConjunction.Both))
                    {
                        result.AddAll(GetConjunction(linkFeats, "_ms=" + pairTag));
                    }

                    if (singleConjunctions.Contains(MetaFeatureExtractor.SingleConjunction.Index))
                    {
                        result.AddAll(GetConjunction(antecedentFeats, "_1"));
                        result.AddAll(GetConjunction(anaphorFeats, "_2"));
                    }
                    if (singleConjunctions.Contains(MetaFeatureExtractor.SingleConjunction.IndexCurrent))
                    {
                        result.AddAll(GetConjunction(antecedentFeats, "_1" + "_m=" + first));
                        result.AddAll(GetConjunction(anaphorFeats, "_2" + "_m=" + second));
                    }
                    if (singleConjunctions.Contains(MetaFeatureExtractor.SingleConjunction.IndexLast))
                    {
                        result.AddAll(GetConjunction(antecedentFeats, "_1" + "_m2=" + second));
                        result.AddAll(GetConjunction(anaphorFeats, "_2" + "_m2=" + second));
                    }
                    if (singleConjunctions.Contains(MetaFeatureExtractor.SingleConjunction.IndexOther))
                    {
                        result.AddAll(GetConjunction(antecedentFeats, "_1" + "_m=" + second));
                        result.AddAll(GetConjunction(anaphorFeats, "_2" + "_m=" + first));
                    }
                    if (singleConjunctions.Contains(MetaFeatureExtractor.SingleConjunction.IndexBoth))
                    {
                        result.AddAll(GetConjunction(antecedentFeats, "_1" + "_ms=" + pairTag));
                        result.AddAll(GetConjunction(anaphorFeats, "_2" + "_ms=" + pairTag));
                    }
                }
            }

            if (!isNewLink)
            {
                return result;
            }

            // New link: add mention-2 features (plain and conjoined with its first id), then tag every key.
            result.AddAll(anaphorFeats);
            result.AddAll(GetConjunction(anaphorFeats, "_m=" + anaphorIds[0]));
            ICounter <string> tagged = new ClassicCounter <string>();
            foreach (KeyValuePair <string, double> entry in result.EntrySet())
            {
                tagged.IncrementCount(entry.Key + "_NEW", entry.Value);
            }
            return tagged;
        }