예제 #1
0
        /// <summary>
        /// Builds a dataset containing one datum per relation mention found in the sentences of <paramref name="corpus"/>.
        /// </summary>
        /// <remarks>
        /// Relations for each sentence come from <c>AnnotationUtils.GetAllRelations</c> (called with the
        /// <c>relationMentionFactory</c> and <c>createUnrelatedRelations</c> fields) and each one is turned into
        /// a datum by <c>CreateDatum</c>. Afterwards <c>ApplyFeatureCountThreshold(featureCountThreshold)</c> is
        /// applied to the whole dataset (presumably pruning rarely-seen features — confirm against GeneralDataset).
        /// </remarks>
        /// <param name="corpus">Annotated corpus whose <c>SentencesAnnotation</c> is read.</param>
        /// <returns>The populated dataset (an <c>RVFDataset</c> instance).</returns>
        protected internal virtual GeneralDataset <string, string> CreateDataset(Annotation corpus)
        {
            // Keep the GeneralDataset static type so Add/threshold calls bind to the same overloads.
            GeneralDataset <string, string> result = new RVFDataset <string, string>();
            var allSentences = corpus.Get(typeof(CoreAnnotations.SentencesAnnotation));

            foreach (ICoreMap sent in allSentences)
            {
                var relations = AnnotationUtils.GetAllRelations(relationMentionFactory, sent, createUnrelatedRelations);
                foreach (RelationMention relation in relations)
                {
                    result.Add(CreateDatum(relation));
                }
            }

            result.ApplyFeatureCountThreshold(featureCountThreshold);
            return result;
        }
        /// <summary>
        /// Samples token datums from every sentence map yielded by <paramref name="sentsf"/>, trains a logistic
        /// classifier on them, and writes the classifier's feature weights to a file.
        /// </summary>
        /// <remarks>
        /// If the classifier's internal positive class is not <c>answerLabel</c>, the weights are multiplied by -1,
        /// so they are always oriented with <c>answerLabel</c> as the positive class. When the <c>thresholdWeight</c>
        /// field is non-null, every feature whose absolute weight is &lt;= that threshold is removed before writing.
        /// NOTE(review): the returned counter is created here but never filled in. Callers therefore always get an
        /// empty counter; the useful output is the weights file.
        /// </remarks>
        /// <param name="sentsf">Iterator of (sentence map, file) pairs; only the sentence map (<c>First()</c>) is used.</param>
        /// <param name="perSelectRand">Value passed to <c>Sample</c> for the purely random draw (presumably a probability).</param>
        /// <param name="perSelectNeg">Value passed to <c>Sample</c> for the negative-token draw (presumably a probability).</param>
        /// <param name="externalFeatureWeightsFileLabel">Path of the file the sorted weights are written to, as "utf8".</param>
        /// <returns>An empty counter (see remarks).</returns>
        /// <exception cref="System.IO.IOException"/>
        /// <exception cref="System.TypeLoadException"/>
        public virtual ICounter <string> GetTopFeatures(IEnumerator <Pair <IDictionary <string, DataInstance>, File> > sentsf, double perSelectRand, double perSelectNeg, string externalFeatureWeightsFileLabel)
        {
            ICounter <string>           features = new ClassicCounter <string>();
            RVFDataset <string, string> sampled  = new RVFDataset <string, string>();
            // Fixed seeds make the sampling reproducible from run to run.
            Random random    = new Random(10);
            Random negRandom = new Random(10);
            int    randomPicks = 0;
            IList <Pair <string, int> > picked = new List <Pair <string, int> >();

            for (; sentsf.MoveNext(); )
            {
                IDictionary <string, DataInstance> batch = sentsf.Current.First();
                randomPicks = this.Sample(batch, random, negRandom, perSelectNeg, perSelectRand, randomPicks, picked, sampled);
            }

            System.Console.Out.WriteLine("num random chosen: " + randomPicks);
            System.Console.Out.WriteLine("Number of datums per label: " + sampled.NumDatumsPerLabel());

            LogisticClassifierFactory <string, string> factory    = new LogisticClassifierFactory <string, string>();
            LogisticClassifier <string, string>        classifier = factory.TrainClassifier(sampled);
            ICounter <string> weights = classifier.WeightsAsCounter();

            bool answerIsPositive = classifier.GetLabelForInternalPositiveClass().Equals(answerLabel);
            if (!answerIsPositive)
            {
                weights = Counters.Scale(weights, -1);
            }

            if (thresholdWeight != null)
            {
                HashSet <string> toDrop = new HashSet <string>();
                foreach (KeyValuePair <string, double> entry in weights.EntrySet())
                {
                    if (Math.Abs(entry.Value) <= thresholdWeight)
                    {
                        toDrop.Add(entry.Key);
                    }
                }
                Counters.RemoveKeys(weights, toDrop);
                System.Console.Out.WriteLine("Removing " + toDrop);
            }

            string serialized = Counters.ToSortedString(weights, weights.Size(), "%1$s:%2$f", "\n");
            IOUtils.WriteStringToFile(serialized, externalFeatureWeightsFileLabel, "utf8");
            return features;
        }
        /// <summary>
        /// Builds one feature datum for every (quote, candidate speaker mention) pair in <c>sd.doc</c>.
        /// </summary>
        /// <remarks>
        /// Backward candidates come from <c>FindClosestMentionsInSpanBackward</c>, searched from the token before
        /// the quote back to the first token of the previous paragraph (when such a paragraph is found). Duplicates
        /// among them are removed with <c>EliminateDuplicates</c>. Forward candidates come from
        /// <c>FindClosestMentionsInSpanForward</c>, searched from just after the quote to the end of the quote's
        /// paragraph.
        /// Features include:
        /// <list type="bullet">
        /// <item>word, ranked and paragraph distance;</item>
        /// <item>punctuation between the mention and the quote;</item>
        /// <item>statistics of the mention's paragraph;</item>
        /// <item>POS tags and punctuation around the mention;</item>
        /// <item>statistics of the quote's paragraph;</item>
        /// <item>quote length, position and characters named in the quote;</item>
        /// <item>for "name" mentions, whether one of the character's aliases appears in this quote or in the
        /// previous paragraph.</item>
        /// </list>
        /// Side effect: assigns the shared <c>sieve</c> field.
        /// </remarks>
        /// <param name="sd">Document plus character map, pronoun coref map and animacy list.</param>
        /// <param name="goldList">Gold quote info, aligned by index with the document's quotes; null when not training.</param>
        /// <param name="isTraining">
        /// When true, quotes whose gold speaker is empty are skipped. Datums are labelled "isMention" or
        /// "isNotMention" depending on whether the gold mention range contains the candidate. When false, every
        /// datum is labelled "none".
        /// </param>
        /// <returns>
        /// A <c>FeaturesData</c> holding three things:
        /// <list type="bullet">
        /// <item>a map from quote index to the inclusive [first, last] range of its datum indices;</item>
        /// <item>a map from datum index to the mention it was built from;</item>
        /// <item>the dataset itself.</item>
        /// </list>
        /// </returns>
        /// <exception cref="ArgumentException">Training and <paramref name="goldList"/> has a different size than the quote list.</exception>
        /// <exception cref="InvalidOperationException">A quote could not be found among the quotes of its own paragraph.</exception>
        public static SupervisedSieveTraining.FeaturesData Featurize(SupervisedSieveTraining.SieveData sd, IList <XMLToAnnotation.GoldQuoteInfo> goldList, bool isTraining)
        {
            Annotation doc = sd.doc;

            sieve = new Sieve(doc, sd.characterMap, sd.pronounCorefMap, sd.animacyList);
            IList <ICoreMap>  quotes    = doc.Get(typeof(CoreAnnotations.QuotationsAnnotation));
            IList <ICoreMap>  sentences = doc.Get(typeof(CoreAnnotations.SentencesAnnotation));
            IList <CoreLabel> tokens    = doc.Get(typeof(CoreAnnotations.TokensAnnotation));
            IDictionary <int, IList <ICoreMap> > paragraphToQuotes = GetQuotesInParagraph(doc);
            GeneralDataset <string, string>      dataset           = new RVFDataset <string, string>();
            //maps quote index to the inclusive [first, last] range of its datum indices in the dataset
            IDictionary <int, Pair <int, int> > mapQuoteToDataRange = new Dictionary <int, Pair <int, int> >();
            //maps datum index to the candidate mention the datum was built from
            IDictionary <int, Sieve.MentionData> mapDatumToMention = new Dictionary <int, Sieve.MentionData>();

            if (isTraining && goldList.Count != quotes.Count)
            {
                throw new ArgumentException("Gold Quote List size doesn't match quote list size!");
            }
            for (int quoteIdx = 0; quoteIdx < quotes.Count; quoteIdx++)
            {
                int      initialSize = dataset.Size();
                ICoreMap quote       = quotes[quoteIdx];
                XMLToAnnotation.GoldQuoteInfo gold = null;
                if (isTraining)
                {
                    gold = goldList[quoteIdx];
                    //quotes without an annotated speaker are skipped entirely (they get no data range)
                    if (gold.speaker == string.Empty)
                    {
                        continue;
                    }
                }
                ICoreMap        quoteFirstSentence = sentences[quote.Get(typeof(CoreAnnotations.SentenceBeginAnnotation))];
                Pair <int, int> quoteRun           = new Pair <int, int>(quote.Get(typeof(CoreAnnotations.TokenBeginAnnotation)), quote.Get(typeof(CoreAnnotations.TokenEndAnnotation)));
                int quoteParagraphIdx = quoteFirstSentence.Get(typeof(CoreAnnotations.ParagraphIndexAnnotation));
                //token bounds of the quote's paragraph; identical for every candidate mention of this quote
                int quoteParagraphBeginToken = GetParagraphBeginToken(quoteFirstSentence, sentences);
                int quoteParagraphEndToken   = GetParagraphEndToken(quoteFirstSentence, sentences);

                //add mentions before quote up to the previous paragraph
                int rightValue = quoteRun.first - 1;
                int leftValue  = quoteRun.first - 1;
                //move left value to be the first token idx of the previous paragraph
                for (int sentIdx = quote.Get(typeof(CoreAnnotations.SentenceBeginAnnotation)); sentIdx >= 0; sentIdx--)
                {
                    ICoreMap sentence = sentences[sentIdx];
                    if (sentence.Get(typeof(CoreAnnotations.ParagraphIndexAnnotation)) == quoteParagraphIdx)
                    {
                        continue;
                    }
                    if (sentence.Get(typeof(CoreAnnotations.ParagraphIndexAnnotation)) == quoteParagraphIdx - 1)
                    {
                        leftValue = sentence.Get(typeof(CoreAnnotations.TokenBeginAnnotation));
                    }
                    else
                    {
                        break;
                    }
                }
                IList <Sieve.MentionData> mentionsInPreviousParagraph = new List <Sieve.MentionData>();
                if (leftValue > -1 && rightValue > -1)
                {
                    mentionsInPreviousParagraph = EliminateDuplicates(sieve.FindClosestMentionsInSpanBackward(new Pair <int, int>(leftValue, rightValue)));
                }

                //mentions after the quote, up to the end of the quote's paragraph
                leftValue  = quoteRun.second + 1;
                rightValue = quoteRun.second + 1;
                for (int sentIdx_1 = quote.Get(typeof(CoreAnnotations.SentenceEndAnnotation)); sentIdx_1 < sentences.Count; sentIdx_1++)
                {
                    ICoreMap sentence = sentences[sentIdx_1];
                    if (sentence.Get(typeof(CoreAnnotations.ParagraphIndexAnnotation)) == quoteParagraphIdx)
                    {
                        rightValue = sentence.Get(typeof(CoreAnnotations.TokenEndAnnotation)) - 1;
                    }
                    else
                    {
                        break;
                    }
                }
                IList <Sieve.MentionData> mentionsInNextParagraph = new List <Sieve.MentionData>();
                if (leftValue < tokens.Count && rightValue < tokens.Count)
                {
                    mentionsInNextParagraph = sieve.FindClosestMentionsInSpanForward(new Pair <int, int>(leftValue, rightValue));
                }

                IList <Sieve.MentionData> candidateMentions = new List <Sieve.MentionData>();
                Sharpen.Collections.AddAll(candidateMentions, mentionsInPreviousParagraph);
                Sharpen.Collections.AddAll(candidateMentions, mentionsInNextParagraph);
                int rankedDistance = 1;
                int numBackwards   = mentionsInPreviousParagraph.Count;
                foreach (Sieve.MentionData mention in candidateMentions)
                {
                    ICounter <string> features = new ClassicCounter <string>();
                    bool isLeft   = true;
                    int  distance = quoteRun.first - mention.end;
                    if (distance < 0)
                    {
                        isLeft   = false;
                        distance = mention.begin - quoteRun.second;
                    }
                    //disregard mention-in-quote cases.
                    if (distance < 0)
                    {
                        continue;
                    }
                    features.SetCount("wordDistance", distance);

                    IList <CoreLabel> betweenTokens;
                    if (isLeft)
                    {
                        betweenTokens = tokens.SubList(mention.end + 1, quoteRun.first);
                    }
                    else
                    {
                        betweenTokens = tokens.SubList(quoteRun.second + 1, mention.begin);
                    }
                    //Punctuation in between
                    foreach (CoreLabel token in betweenTokens)
                    {
                        if (punctuation.Contains(token.Word()))
                        {
                            features.SetCount("punctuationPresence:" + token.Word(), 1);
                        }
                    }

                    // number of mentions away
                    features.SetCount("rankedDistance", rankedDistance);
                    rankedDistance++;
                    //NOTE(review): this resets once the counter reaches numBackwards, i.e. after numBackwards - 1
                    //backward mentions, and skipped mentions (distance < 0) do not advance it. That looks off by one.
                    //It is kept as-is because it changes feature values (presumably baked into trained models) — confirm.
                    if (rankedDistance == numBackwards)
                    {
                        //reset for the forward
                        rankedDistance = 1;
                    }

                    //third distance: # of paragraphs away
                    int      mentionParagraphIdx        = -1;
                    ICoreMap sentenceInMentionParagraph = null;
                    if (isLeft)
                    {
                        if (quoteParagraphBeginToken <= mention.begin && mention.end <= quoteParagraphEndToken)
                        {
                            features.SetCount("leftParagraphDistance", 0);
                            mentionParagraphIdx        = quoteParagraphIdx;
                            sentenceInMentionParagraph = quoteFirstSentence;
                        }
                        else
                        {
                            int      paragraphDistance = 1;
                            int      currParagraphIdx  = quoteParagraphIdx - paragraphDistance;
                            ICoreMap currSentence      = quoteFirstSentence;
                            int      currSentenceIdx   = currSentence.Get(typeof(CoreAnnotations.SentenceIndexAnnotation));
                            while (currParagraphIdx >= 0)
                            {
                                //walk back to a sentence of paragraph currParagraphIdx, never going before sentence 0
                                while (currSentence.Get(typeof(CoreAnnotations.ParagraphIndexAnnotation)) != currParagraphIdx && currSentenceIdx > 0)
                                {
                                    currSentenceIdx--;
                                    currSentence = sentences[currSentenceIdx];
                                }
                                if (currSentence.Get(typeof(CoreAnnotations.ParagraphIndexAnnotation)) != currParagraphIdx)
                                {
                                    //no sentence carries this paragraph index: stop instead of indexing out of range
                                    break;
                                }
                                int prevParagraphBegin = GetParagraphBeginToken(currSentence, sentences);
                                int prevParagraphEnd   = GetParagraphEndToken(currSentence, sentences);
                                if (prevParagraphBegin <= mention.begin && mention.end <= prevParagraphEnd)
                                {
                                    mentionParagraphIdx        = currParagraphIdx;
                                    sentenceInMentionParagraph = currSentence;
                                    features.SetCount("leftParagraphDistance", paragraphDistance);
                                    if (paragraphDistance % 2 == 0)
                                    {
                                        features.SetCount("leftParagraphDistanceEven", 1);
                                    }
                                    break;
                                }
                                paragraphDistance++;
                                currParagraphIdx--;
                            }
                        }
                    }
                    else
                    {
                        //right
                        if (quoteParagraphBeginToken <= mention.begin && mention.end <= quoteParagraphEndToken)
                        {
                            features.SetCount("rightParagraphDistance", 0);
                            sentenceInMentionParagraph = quoteFirstSentence;
                            mentionParagraphIdx        = quoteParagraphIdx;
                        }
                        else
                        {
                            int      paragraphDistance  = 1;
                            int      nextParagraphIndex = quoteParagraphIdx + paragraphDistance;
                            ICoreMap currSentence       = quoteFirstSentence;
                            int      currSentenceIdx    = currSentence.Get(typeof(CoreAnnotations.SentenceIndexAnnotation));
                            while (currSentenceIdx < sentences.Count)
                            {
                                //walk forward to a sentence of paragraph nextParagraphIndex, never past the last sentence
                                while (currSentence.Get(typeof(CoreAnnotations.ParagraphIndexAnnotation)) != nextParagraphIndex && currSentenceIdx + 1 < sentences.Count)
                                {
                                    currSentenceIdx++;
                                    currSentence = sentences[currSentenceIdx];
                                }
                                if (currSentence.Get(typeof(CoreAnnotations.ParagraphIndexAnnotation)) != nextParagraphIndex)
                                {
                                    //no sentence carries this paragraph index: stop instead of indexing out of range
                                    break;
                                }
                                int nextParagraphBegin = GetParagraphBeginToken(currSentence, sentences);
                                int nextParagraphEnd   = GetParagraphEndToken(currSentence, sentences);
                                if (nextParagraphBegin <= mention.begin && mention.end <= nextParagraphEnd)
                                {
                                    //NOTE(review): unlike the left branch, mentionParagraphIdx is not updated here.
                                    //As a result the quotesInMentionParagraph lookup below falls back to its default
                                    //(empty) list. Left unchanged to keep feature values stable — confirm intent.
                                    sentenceInMentionParagraph = currSentence;
                                    features.SetCount("rightParagraphDistance", paragraphDistance);
                                    break;
                                }
                                paragraphDistance++;
                                nextParagraphIndex++;
                            }
                        }
                    }

                    //2. mention features
                    if (sentenceInMentionParagraph != null)
                    {
                        int mentionParagraphBegin = GetParagraphBeginToken(sentenceInMentionParagraph, sentences);
                        int mentionParagraphEnd   = GetParagraphEndToken(sentenceInMentionParagraph, sentences);
                        if (!(mentionParagraphBegin == quoteParagraphBeginToken && mentionParagraphEnd == quoteParagraphEndToken))
                        {
                            IList <ICoreMap> quotesInMentionParagraph = paragraphToQuotes.GetOrDefault(mentionParagraphIdx, new List <ICoreMap>());
                            Pair <List <string>, List <Pair <int, int> > > namesInMentionParagraph = sieve.ScanForNames(new Pair <int, int>(mentionParagraphBegin, mentionParagraphEnd));
                            features.SetCount("quotesInMentionParagraph", quotesInMentionParagraph.Count);
                            features.SetCount("wordsInMentionParagraph", mentionParagraphEnd - mentionParagraphBegin + 1);
                            features.SetCount("namesInMentionParagraph", namesInMentionParagraph.first.Count);
                            //mention ordering in paragraph it is in
                            for (int i = 0; i < namesInMentionParagraph.second.Count; i++)
                            {
                                if (ExtractQuotesUtil.RangeContains(new Pair <int, int>(mention.begin, mention.end), namesInMentionParagraph.second[i]))
                                {
                                    features.SetCount("orderInParagraph", i);
                                }
                            }
                            //if mention paragraph is all one quote
                            if (quotesInMentionParagraph.Count == 1)
                            {
                                ICoreMap qInMentionParagraph = quotesInMentionParagraph[0];
                                if (qInMentionParagraph.Get(typeof(CoreAnnotations.TokenBeginAnnotation)) == mentionParagraphBegin && qInMentionParagraph.Get(typeof(CoreAnnotations.TokenEndAnnotation)) - 1 == mentionParagraphEnd)
                                {
                                    features.SetCount("mentionParagraphIsInConversation", 1);
                                }
                                else
                                {
                                    features.SetCount("mentionParagraphIsInConversation", -1);
                                }
                            }
                            foreach (ICoreMap quoteIMP in quotesInMentionParagraph)
                            {
                                if (ExtractQuotesUtil.RangeContains(new Pair <int, int>(quoteIMP.Get(typeof(CoreAnnotations.TokenBeginAnnotation)), quoteIMP.Get(typeof(CoreAnnotations.TokenEndAnnotation)) - 1), new Pair <int, int>(mention.begin, mention.end)))
                                {
                                    features.SetCount("mentionInQuote", 1);
                                }
                            }
                            if (features.GetCount("mentionInQuote") != 1)
                            {
                                features.SetCount("mentionNotInQuote", 1);
                            }
                        }
                    }

                    // nearby word syntax types...make sure to check if there are previous or next words
                    // or there will be an array index crash
                    if (mention.begin > 0)
                    {
                        CoreLabel prevWord = tokens[mention.begin - 1];
                        features.SetCount("prevWordType:" + prevWord.Tag(), 1);
                        if (punctuationForFeatures.Contains(prevWord.Lemma()))
                        {
                            features.SetCount("prevWordPunct:" + prevWord.Lemma(), 1);
                        }
                    }
                    if (mention.end + 1 < tokens.Count)
                    {
                        CoreLabel nextWord = tokens[mention.end + 1];
                        features.SetCount("nextWordType:" + nextWord.Tag(), 1);
                        if (punctuationForFeatures.Contains(nextWord.Lemma()))
                        {
                            features.SetCount("nextWordPunct:" + nextWord.Lemma(), 1);
                        }
                    }

                    //quote paragraph features
                    IList <ICoreMap> quotesInQuoteParagraph = paragraphToQuotes[quoteParagraphIdx];
                    features.SetCount("QuotesInQuoteParagraph", quotesInQuoteParagraph.Count);
                    features.SetCount("WordsInQuoteParagraph", quoteParagraphEndToken - quoteParagraphBeginToken + 1);
                    features.SetCount("NamesInQuoteParagraph", sieve.ScanForNames(new Pair <int, int>(quoteParagraphBeginToken, quoteParagraphEndToken)).first.Count);

                    //quote features
                    features.SetCount("quoteLength", quote.Get(typeof(CoreAnnotations.TokenEndAnnotation)) - quote.Get(typeof(CoreAnnotations.TokenBeginAnnotation)) + 1);
                    for (int i_1 = 0; i_1 < quotesInQuoteParagraph.Count; i_1++)
                    {
                        if (quotesInQuoteParagraph[i_1].Equals(quote))
                        {
                            features.SetCount("quotePosition", i_1 + 1);
                        }
                    }
                    if (features.GetCount("quotePosition") == 0)
                    {
                        throw new InvalidOperationException("Check this (equality not working)");
                    }
                    Pair <List <string>, List <Pair <int, int> > > namesData = sieve.ScanForNames(quoteRun);
                    foreach (string name in namesData.first)
                    {
                        features.SetCount("charactersInQuote:" + sd.characterMap[name][0].name, 1);
                    }
                    //if quote encompasses entire paragraph
                    if (quote.Get(typeof(CoreAnnotations.TokenBeginAnnotation)) == quoteParagraphBeginToken && quote.Get(typeof(CoreAnnotations.TokenEndAnnotation)) == quoteParagraphEndToken)
                    {
                        features.SetCount("isImplicitSpeaker", 1);
                    }
                    else
                    {
                        features.SetCount("isImplicitSpeaker", -1);
                    }

                    //Vocative detection
                    if (mention.type.Equals("name"))
                    {
                        //TryGetValue instead of the indexer: the C# indexer throws on a missing key, which made the
                        //scan-for-names fallback below unreachable (the Java original relied on Map.get returning null)
                        string mentionText = sieve.TokenRangeToString(new Pair <int, int>(mention.begin, mention.end));
                        Person p           = null;
                        if (sd.characterMap.TryGetValue(mentionText, out var pList) && pList != null)
                        {
                            p = pList[0];
                        }
                        else
                        {
                            Pair <List <string>, List <Pair <int, int> > > scanForNamesResultPair = sieve.ScanForNames(new Pair <int, int>(mention.begin, mention.end));
                            if (scanForNamesResultPair.first.Count != 0)
                            {
                                string scanForNamesResultString = scanForNamesResultPair.first[0];
                                if (scanForNamesResultString != null && sd.characterMap.TryGetValue(scanForNamesResultString, out var scannedList))
                                {
                                    p = scannedList[0];
                                }
                            }
                        }
                        if (p != null)
                        {
                            foreach (string name_1 in namesData.first)
                            {
                                if (p.aliases.Contains(name_1))
                                {
                                    features.SetCount("nameInQuote", 1);
                                }
                            }
                            if (quoteParagraphIdx > 0)
                            {
                                IList <ICoreMap>         quotesInPrevParagraph = paragraphToQuotes.GetOrDefault(quoteParagraphIdx - 1, new List <ICoreMap>());
                                IList <Pair <int, int> > exclusionList         = new List <Pair <int, int> >();
                                foreach (ICoreMap quoteIPP in quotesInPrevParagraph)
                                {
                                    Pair <int, int> quoteRange = new Pair <int, int>(quoteIPP.Get(typeof(CoreAnnotations.TokenBeginAnnotation)), quoteIPP.Get(typeof(CoreAnnotations.TokenEndAnnotation)));
                                    exclusionList.Add(quoteRange);
                                    foreach (string name_2 in sieve.ScanForNames(quoteRange).first)
                                    {
                                        if (p.aliases.Contains(name_2))
                                        {
                                            features.SetCount("nameInPrevParagraphQuote", 1);
                                        }
                                    }
                                }
                                //find the nearest earlier sentence belonging to the previous paragraph
                                int      sentenceIdx             = quoteFirstSentence.Get(typeof(CoreAnnotations.SentenceIndexAnnotation));
                                ICoreMap sentenceInPrevParagraph = null;
                                for (int prevSentIdx = sentenceIdx - 1; prevSentIdx >= 0; prevSentIdx--)
                                {
                                    ICoreMap candidateSentence = sentences[prevSentIdx];
                                    if (candidateSentence.Get(typeof(CoreAnnotations.ParagraphIndexAnnotation)) == quoteParagraphIdx - 1)
                                    {
                                        sentenceInPrevParagraph = candidateSentence;
                                        break;
                                    }
                                }
                                if (sentenceInPrevParagraph != null)
                                {
                                    int prevParagraphBegin = GetParagraphBeginToken(sentenceInPrevParagraph, sentences);
                                    int prevParagraphEnd   = GetParagraphEndToken(sentenceInPrevParagraph, sentences);
                                    IList <Pair <int, int> > prevParagraphNonQuoteRuns = GetRangeExclusion(new Pair <int, int>(prevParagraphBegin, prevParagraphEnd), exclusionList);
                                    foreach (Pair <int, int> nonQuoteRange in prevParagraphNonQuoteRuns)
                                    {
                                        foreach (string name_2 in sieve.ScanForNames(nonQuoteRange).first)
                                        {
                                            if (p.aliases.Contains(name_2))
                                            {
                                                features.SetCount("nameInPrevParagraphNonQuote", 1);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }

                    string label;
                    if (!isTraining)
                    {
                        label = "none";
                    }
                    else if (QuoteAttributionUtils.RangeContains(new Pair <int, int>(gold.mentionStartTokenIndex, gold.mentionEndTokenIndex), new Pair <int, int>(mention.begin, mention.end)))
                    {
                        label = "isMention";
                    }
                    else
                    {
                        label = "isNotMention";
                    }
                    //key the mention by the index this datum is about to occupy. Previously the "isNotMention" path
                    //added the datum first and keyed by the *next* index, misaligning mapDatumToMention.
                    int datumIdx = dataset.Size();
                    RVFDatum <string, string> datum = new RVFDatum <string, string>(features, label);
                    datum.SetID(datumIdx.ToString(System.Globalization.CultureInfo.InvariantCulture));
                    mapDatumToMention[datumIdx] = mention;
                    dataset.Add(datum);
                }
                mapQuoteToDataRange[quoteIdx] = new Pair <int, int>(initialSize, dataset.Size() - 1);
            }
            return(new SupervisedSieveTraining.FeaturesData(mapQuoteToDataRange, mapDatumToMention, dataset));
        }
 // public void getDecisionTree(Map<String, List<CoreLabel>> sents,
 // List<Pair<String, Integer>> chosen, Counter<String> weights, String
 // wekaOptions) {
 // RVFDataset<String, String> dataset = new RVFDataset<String, String>();
 // for (Pair<String, Integer> d : chosen) {
 // CoreLabel l = sents.get(d.first).get(d.second());
 // String w = l.word();
 // Integer num = this.clusterIds.get(w);
 // if (num == null)
 // num = -1;
 // double wt = weights.getCount("Cluster-" + num);
 // String label;
 // if (l.get(answerClass).toString().equals(answerLabel))
 // label = answerLabel;
 // else
 // label = "O";
 // Counter<String> feat = new ClassicCounter<String>();
 // feat.setCount("DIST", wt);
 // dataset.add(new RVFDatum<String, String>(feat, label));
 // }
 // WekaDatumClassifierFactory wekaFactory = new
 // WekaDatumClassifierFactory("weka.classifiers.trees.J48", wekaOptions);
 // WekaDatumClassifier classifier = wekaFactory.trainClassifier(dataset);
 // Classifier cls = classifier.getClassifier();
 // J48 j48decisiontree = (J48) cls;
 // System.out.println(j48decisiontree.toSummaryString());
 // System.out.println(j48decisiontree.toString());
 //
 // }
 /// <summary>
 /// Walks every token of every sentence in <paramref name="sents"/> and adds a datum to
 /// <paramref name="dataset"/> for each token selected by the sampling rule below.
 /// </summary>
 /// <remarks>
 /// Selection rule, tried in order:
 /// <list type="number">
 /// <item>The token's <c>answerClass</c> value equals <c>answerLabel</c>: always chosen.</item>
 /// <item>Its value is not "O", or its lower-cased word is in <c>negativeWords</c>: chosen when
 /// <c>GetRandomBoolean(r, perSelectNeg)</c> is true.</item>
 /// <item>Otherwise: chosen when <c>GetRandomBoolean(r, perSelectRand)</c> is true. This case also
 /// increments the random-pick count.</item>
 /// </list>
 /// NOTE(review): <paramref name="rneg"/> is accepted but never used; both draws consume <paramref name="r"/>.
 /// Switching the negative draw to <c>rneg</c> would change which tokens get sampled, so it is left as is.
 /// </remarks>
 /// <param name="numrand">Running count of tokens chosen by the purely random draw.</param>
 /// <param name="chosen">Receives a (sentence id, token index) pair for each chosen token.</param>
 /// <param name="dataset">Receives the datum from <c>GetDatum</c> for each chosen token, keyed by sentence id and token index.</param>
 /// <returns>The updated <paramref name="numrand"/>.</returns>
 private int Sample(IDictionary <string, DataInstance> sents, Random r, Random rneg, double perSelectNeg, double perSelectRand, int numrand, IList <Pair <string, int> > chosen, RVFDataset <string, string> dataset)
 {
     foreach (KeyValuePair <string, DataInstance> en in sents)
     {
         CoreLabel[] sent = Sharpen.Collections.ToArray(en.Value.GetTokens(), new CoreLabel[0]);
         for (int i = 0; i < sent.Length; i++)
         {
             CoreLabel l = sent[i];
             bool      chooseThis;
             if (l.Get(answerClass).Equals(answerLabel))
             {
                 chooseThis = true;
             }
             // ToLowerInvariant: the lookup must not depend on the current thread culture (e.g. Turkish 'I').
             else if ((!l.Get(answerClass).Equals("O") || negativeWords.Contains(l.Word().ToLowerInvariant())) && GetRandomBoolean(r, perSelectNeg))
             {
                 chooseThis = true;
             }
             else if (GetRandomBoolean(r, perSelectRand))
             {
                 numrand++;
                 chooseThis = true;
             }
             else
             {
                 chooseThis = false;
             }
             if (chooseThis)
             {
                 // Generic Pair: the raw 'new Pair(...)' carried over from Java does not bind to Pair<string, int>.
                 chosen.Add(new Pair <string, int>(en.Key, i));
                 RVFDatum <string, string> d = GetDatum(sent, i);
                 // int has no static ToString(int); format the index explicitly and culture-invariantly.
                 dataset.Add(d, en.Key, i.ToString(System.Globalization.CultureInfo.InvariantCulture));
             }
         }
     }
     return(numrand);
 }
예제 #5
0
        /// <summary>Train a sentiment model from a set of data.</summary>
        /// <remarks>
        /// Besides training, this logs the label distribution and then a 4-fold
        /// cross-validated accuracy together with per-class precision, recall and F1.
        /// </remarks>
        /// <param name="data">The data to train the model from.</param>
        /// <param name="modelLocation">
        /// An optional location to save the model.
        /// Note that this stream will be closed in this method,
        /// and should not be written to thereafter.
        /// </param>
        /// <returns>A sentiment classifier, ready to use.</returns>
        public static SimpleSentiment Train(IStream <SimpleSentiment.SentimentDatum> data, Optional <OutputStream> modelLocation)
        {
            // Some useful variables configuring how we train.
            // NOTE(review): useL1 and datasize are only referenced by callbacks that the
            // Java-to-C# conversion replaced with null below; they are kept so those
            // callbacks can be restored.
            bool   useL1 = true;
            double sigma = 1.0;
            int    featureCountThreshold = 5;

            // Featurize the data
            Redwood.Util.ForceTrack("Featurizing");
            RVFDataset <SentimentClass, string> dataset = new RVFDataset <SentimentClass, string>();
            AtomicInteger             datasize          = new AtomicInteger(0);
            ICounter <SentimentClass> distribution      = new ClassicCounter <SentimentClass>();

            // NOTE(review): the map/forEach callbacks that should featurize each datum into
            // `dataset` and count labels into `distribution` are null here (lost in conversion).
            // TODO: restore them from the original implementation.
            data.Unordered().Parallel().Map(null).ForEach(null);
            Redwood.Util.EndTrack("Featurizing");

            // Print label distribution
            Redwood.Util.StartTrack("Distribution");
            foreach (SentimentClass label in SentimentClass.Values())
            {
                // {0,7} right-aligns in a 7-wide field (the .NET form of Java's "%7d").
                Redwood.Util.Log(string.Format("{0,7}", (int)distribution.GetCount(label)) + "   " + label);
            }
            Redwood.Util.EndTrack("Distribution");

            // Train the classifier
            Redwood.Util.ForceTrack("Training");
            if (featureCountThreshold > 1)
            {
                dataset.ApplyFeatureCountThreshold(featureCountThreshold);
            }
            // Fixed seed so the shuffle (and hence the folds below) are reproducible.
            dataset.Randomize(42L);
            LinearClassifierFactory <SentimentClass, string> factory = new LinearClassifierFactory <SentimentClass, string>();

            factory.SetVerbose(true);
            try
            {
                // NOTE(review): the minimizer-creator callback was lost in conversion (null).
                factory.SetMinimizerCreator(null);
            }
            catch (Exception e)
            {
                // Best-effort: fall back to the factory's default minimizer, but do not hide the failure.
                log.Info("Could not set minimizer creator; using factory default: " + e.Message);
            }
            factory.SetSigma(sigma);
            LinearClassifier <SentimentClass, string> classifier = factory.TrainClassifier(dataset);

            // Optionally save the model
            // NOTE(review): the save callback was lost in conversion (null); TODO restore.
            modelLocation.IfPresent(null);
            Redwood.Util.EndTrack("Training");

            // Evaluate the model
            Redwood.Util.ForceTrack("Evaluating");
            factory.SetVerbose(false);
            double sumAccuracy             = 0.0;
            ICounter <SentimentClass> sumP = new ClassicCounter <SentimentClass>();
            ICounter <SentimentClass> sumR = new ClassicCounter <SentimentClass>();
            int numFolds = 4;

            for (int fold = 0; fold < numFolds; ++fold)
            {
                Pair <GeneralDataset <SentimentClass, string>, GeneralDataset <SentimentClass, string> > trainTest = dataset.SplitOutFold(fold, numFolds);
                // convex objective, so warm-starting from the full-data model should be OK
                LinearClassifier <SentimentClass, string> foldClassifier = factory.TrainClassifierWithInitialWeights(trainTest.first, classifier);
                sumAccuracy += foldClassifier.EvaluateAccuracy(trainTest.second);
                foreach (SentimentClass label in SentimentClass.Values())
                {
                    Pair <double, double> pr = foldClassifier.EvaluatePrecisionAndRecall(trainTest.second, label);
                    sumP.IncrementCount(label, pr.first);
                    // Recall goes into its own counter (previously added to sumP by mistake).
                    sumR.IncrementCount(label, pr.second);
                }
            }
            DecimalFormat df = new DecimalFormat("0.000%");

            log.Info("----------");
            double aveAccuracy = sumAccuracy / ((double)numFolds);

            log.Info(string.Empty + numFolds + "-fold accuracy: " + df.Format(aveAccuracy));
            log.Info(string.Empty);
            foreach (SentimentClass label in SentimentClass.Values())
            {
                double p  = sumP.GetCount(label) / numFolds;
                double r  = sumR.GetCount(label) / numFolds;
                // Guard against 0/0 (NaN) when a class has neither precision nor recall.
                double f1 = (p + r) > 0.0 ? 2 * p * r / (p + r) : 0.0;
                log.Info(label + " (P)  = " + df.Format(p));
                log.Info(label + " (R)  = " + df.Format(r));
                log.Info(label + " (F1) = " + df.Format(f1));
                log.Info(string.Empty);
            }
            log.Info("----------");
            Redwood.Util.EndTrack("Evaluating");
            // Return
            return(new SimpleSentiment(classifier));
        }