public virtual RVFDatum<L, F> ScaleDatumGaussian(RVFDatum<L, F> datum)
{
    // scale this dataset before scaling the datum
    if (means == null || stdevs == null)
    {
        ScaleFeaturesGaussian();
    }
    ICounter<F> scaledFeatures = new ClassicCounter<F>();
    foreach (F feature in datum.AsFeatures())
    {
        int fID = this.featureIndex.IndexOf(feature);
        if (fID >= 0)
        {
            double oldVal = datum.AsFeaturesCounter().GetCount(feature);
            double newVal;
            if (stdevs[fID] != 0)
            {
                newVal = (oldVal - means[fID]) / stdevs[fID];
            }
            else
            {
                newVal = oldVal;
            }
            scaledFeatures.IncrementCount(feature, newVal);
        }
    }
    return new RVFDatum<L, F>(scaledFeatures, datum.Label());
}
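// Usage sketch (illustrative addition, not from the original source): scale a held-out
// datum with the Gaussian statistics of an already-populated training dataset.
// "trainSet" and "heldOut" are hypothetical names; means/stdevs are computed lazily
// by ScaleFeaturesGaussian on the first call, as shown above.
private static void ScaleDatumGaussianExample(RVFDataset<string, string> trainSet, RVFDatum<string, string> heldOut)
{
    // Known features become z-scores; features unseen at training time are dropped.
    RVFDatum<string, string> scaled = trainSet.ScaleDatumGaussian(heldOut);
    System.Console.Out.WriteLine(scaled.AsFeaturesCounter());
}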
public virtual ClassicCounter<L> ScoresOf(RVFDatum<L, F> example)
{
    ClassicCounter<L> scores = new ClassicCounter<L>();
    Counters.AddInPlace(scores, priors);
    if (addZeroValued)
    {
        Counters.AddInPlace(scores, priorZero);
    }
    foreach (L l in labels)
    {
        double score = 0.0;
        ICounter<F> features = example.AsFeaturesCounter();
        foreach (F f in features.KeySet())
        {
            int value = (int)features.GetCount(f);
            score += Weight(l, f, value);
            if (addZeroValued)
            {
                score -= Weight(l, f, zero);
            }
        }
        scores.IncrementCount(l, score);
    }
    return scores;
}
/// <summary>Method to convert features from counts to L1-normalized TF-IDF based features.</summary>
/// <param name="datum">A datum with a collection of features.</param>
/// <param name="featureDocCounts">A counter of doc-count for each feature.</param>
/// <returns>RVFDatum with L1-normalized TF-IDF features.</returns>
public virtual RVFDatum<L, F> GetL1NormalizedTFIDFDatum(IDatum<L, F> datum, ICounter<F> featureDocCounts)
{
    ICounter<F> tfidfFeatures = new ClassicCounter<F>();
    foreach (F feature in datum.AsFeatures())
    {
        if (featureDocCounts.ContainsKey(feature))
        {
            tfidfFeatures.IncrementCount(feature, 1.0);
        }
    }
    double l1norm = 0;
    foreach (F feature_1 in tfidfFeatures.KeySet())
    {
        double idf = Math.Log(((double)(this.Size() + 1)) / (featureDocCounts.GetCount(feature_1) + 0.5));
        double tf = tfidfFeatures.GetCount(feature_1);
        tfidfFeatures.SetCount(feature_1, tf * idf);
        l1norm += tf * idf;
    }
    foreach (F feature_2 in tfidfFeatures.KeySet())
    {
        double tfidf = tfidfFeatures.GetCount(feature_2);
        tfidfFeatures.SetCount(feature_2, tfidf / l1norm);
    }
    RVFDatum<L, F> rvfDatum = new RVFDatum<L, F>(tfidfFeatures, datum.Label());
    return rvfDatum;
}
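// Usage sketch (illustrative addition, not from the original source): convert one datum's
// count features to L1-normalized TF-IDF. Dataset<L, F> is assumed to be the enclosing type,
// with the GetFeatureCounter() also used by GetL1NormalizedTFIDFDataset later in this file;
// "dataset" and "datum" are hypothetical names.
private static void TfidfDatumExample<L, F>(Dataset<L, F> dataset, IDatum<L, F> datum)
{
    ICounter<F> featureDocCounts = dataset.GetFeatureCounter();  // document count per feature
    RVFDatum<L, F> tfidf = dataset.GetL1NormalizedTFIDFDatum(datum, featureDocCounts);
    // After normalization the feature values of "tfidf" sum to 1.
    System.Console.Out.WriteLine(tfidf.AsFeaturesCounter());
}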
public virtual RVFDatum<L, F> GetRVFDatumWithId(int index)
{
    RVFDatum<L, F> datum = GetRVFDatum(index);
    datum.SetID(GetRVFDatumId(index));
    return datum;
}
/// <summary>Get the sentiment of a sentence.</summary>
/// <param name="sentence">
/// The sentence as a core map.
/// POS tags and lemmas are a prerequisite; see
/// <see cref="Edu.Stanford.Nlp.Ling.CoreAnnotations.PartOfSpeechAnnotation"/>
/// and
/// <see cref="Edu.Stanford.Nlp.Ling.CoreAnnotations.LemmaAnnotation"/>.
/// </param>
/// <returns>The sentiment class of this sentence.</returns>
public virtual SentimentClass Classify(ICoreMap sentence)
{
    ICounter<string> features = Featurize(sentence);
    RVFDatum<SentimentClass, string> datum = new RVFDatum<SentimentClass, string>(features);
    return impl.ClassOf(datum);
}
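// Usage sketch (illustrative addition, not from the original source): classify raw text by
// first running a pipeline that supplies the required POS tags and lemmas. "SimpleSentiment"
// is an assumed name for the enclosing classifier type; this mirrors the Classify(string)
// overload further down in this file.
private static void ClassifyTextExample(SimpleSentiment classifier, StanfordCoreNLP pipeline, string text)
{
    Annotation ann = new Annotation(text);
    pipeline.Annotate(ann);
    ICoreMap sentence = ann.Get(typeof(CoreAnnotations.SentencesAnnotation))[0];
    System.Console.Out.WriteLine(classifier.Classify(sentence));
}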
/// <summary>The examples are assumed to be a list of RVFDatum.</summary>
/// <remarks>
/// The examples are assumed to be a list of RVFDatum.
/// The datums are assumed not to contain the zero values explicitly; zeroes are added to each instance here.
/// </remarks>
public virtual NaiveBayesClassifier<L, F> TrainClassifier(GeneralDataset<L, F> examples, ICollection<F> featureSet)
{
    int numFeatures = featureSet.Count;
    // One row of feature values per example.
    int[][] data = new int[examples.Size()][];
    int[] labels = new int[examples.Size()];
    labelIndex = new HashIndex<L>();
    featureIndex = new HashIndex<F>();
    foreach (F feat in featureSet)
    {
        featureIndex.Add(feat);
    }
    for (int d = 0; d < examples.Size(); d++)
    {
        data[d] = new int[numFeatures];
        RVFDatum<L, F> datum = examples.GetRVFDatum(d);
        ICounter<F> c = datum.AsFeaturesCounter();
        foreach (F feature in c.KeySet())
        {
            int fNo = featureIndex.IndexOf(feature);
            int value = (int)c.GetCount(feature);
            data[d][fNo] = value;
        }
        labelIndex.Add(datum.Label());
        labels[d] = labelIndex.IndexOf(datum.Label());
    }
    int numClasses = labelIndex.Size();
    return TrainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex);
}
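// Usage sketch (illustrative addition, not from the original source): gather the feature
// vocabulary from a dataset and train a Naive Bayes classifier with the method above.
// "NaiveBayesClassifierFactory" is assumed to be the enclosing factory type.
private static NaiveBayesClassifier<string, string> TrainNaiveBayesExample(NaiveBayesClassifierFactory<string, string> factory, GeneralDataset<string, string> examples)
{
    // Collect every feature seen in the training examples.
    ICollection<string> featureSet = new HashSet<string>();
    for (int i = 0; i < examples.Size(); i++)
    {
        foreach (string f in examples.GetRVFDatum(i).AsFeaturesCounter().KeySet())
        {
            featureSet.Add(f);
        }
    }
    return factory.TrainClassifier(examples, featureSet);
}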
public static void Main(string[] args)
{
    Edu.Stanford.Nlp.Classify.RVFDataset<string, string> data = new Edu.Stanford.Nlp.Classify.RVFDataset<string, string>();
    ClassicCounter<string> c1 = new ClassicCounter<string>();
    c1.IncrementCount("fever", 3.5);
    c1.IncrementCount("cough", 1.1);
    c1.IncrementCount("congestion", 4.2);
    ClassicCounter<string> c2 = new ClassicCounter<string>();
    c2.IncrementCount("fever", 1.5);
    c2.IncrementCount("cough", 2.1);
    c2.IncrementCount("nausea", 3.2);
    ClassicCounter<string> c3 = new ClassicCounter<string>();
    c3.IncrementCount("cough", 2.5);
    c3.IncrementCount("congestion", 3.2);
    data.Add(new RVFDatum<string, string>(c1, "cold"));
    data.Add(new RVFDatum<string, string>(c2, "flu"));
    data.Add(new RVFDatum<string, string>(c3, "cold"));
    data.SummaryStatistics();
    LinearClassifierFactory<string, string> factory = new LinearClassifierFactory<string, string>();
    factory.UseQuasiNewton();
    LinearClassifier<string, string> c = factory.TrainClassifier(data);
    ClassicCounter<string> c4 = new ClassicCounter<string>();
    c4.IncrementCount("cough", 2.3);
    c4.IncrementCount("fever", 1.3);
    RVFDatum<string, string> datum = new RVFDatum<string, string>(c4);
    c.JustificationOf((IDatum<string, string>)datum);
}
public virtual Edu.Stanford.Nlp.Classify.RVFDataset<L, F> ScaleDatasetGaussian(Edu.Stanford.Nlp.Classify.RVFDataset<L, F> dataset)
{
    Edu.Stanford.Nlp.Classify.RVFDataset<L, F> newDataset = new Edu.Stanford.Nlp.Classify.RVFDataset<L, F>(this.featureIndex, this.labelIndex);
    for (int i = 0; i < dataset.Size(); i++)
    {
        RVFDatum<L, F> datum = (RVFDatum<L, F>)dataset.GetDatum(i);
        newDataset.Add(ScaleDatumGaussian(datum));
    }
    return newDataset;
}
private ICounter<L> ScoresOfRVFDatum(RVFDatum<L, F> example)
{
    ICounter<F> features = example.AsFeaturesCounter();
    double sum = ScoreOf(features);
    ICounter<L> c = new ClassicCounter<L>();
    c.SetCount(classes[0], -sum);
    c.SetCount(classes[1], sum);
    return c;
}
/// <seealso cref="Classify(Edu.Stanford.Nlp.Util.ICoreMap)"/>
public virtual SentimentClass Classify(string text)
{
    Annotation ann = new Annotation(text);
    pipeline.Get().Annotate(ann);
    ICoreMap sentence = ann.Get(typeof(CoreAnnotations.SentencesAnnotation))[0];
    ICounter<string> features = Featurize(sentence);
    RVFDatum<SentimentClass, string> datum = new RVFDatum<SentimentClass, string>(features);
    return impl.ClassOf(datum);
}
public virtual void TestBackwardsCompatibility()
{
    RVFDataset<string, string> dataset = new WeightedRVFDataset<string, string>();
    RVFDatum<string, string> datum1 = NewRVFDatum(null, "a", "b", "a");
    dataset.Add(datum1);
    RVFDatum<string, string> datum2 = NewRVFDatum(null, "a", "b", "a");
    dataset.Add(datum2);
    // Datums added through the plain RVFDataset API should default to weight 1.
    NUnit.Framework.Assert.AreEqual(1.0f, ((WeightedRVFDataset<string, string>)dataset).GetWeights()[0], 1e-10);
    NUnit.Framework.Assert.AreEqual(1.0f, ((WeightedRVFDataset<string, string>)dataset).GetWeights()[1], 1e-10);
}
public virtual void TestWeightingWorks()
{
    WeightedRVFDataset<string, string> dataset = new WeightedRVFDataset<string, string>();
    RVFDatum<string, string> datum1 = NewRVFDatum(null, "a", "b", "a");
    dataset.Add(datum1, 42.0f);
    RVFDatum<string, string> datum2 = NewRVFDatum(null, "a", "b", "a");
    dataset.Add(datum2, 7.3f);
    NUnit.Framework.Assert.AreEqual(42.0f, dataset.GetWeights()[0], 1e-10);
    NUnit.Framework.Assert.AreEqual(7.3f, dataset.GetWeights()[1], 1e-10);
}
public virtual double ProbabilityOf(Mention p, ICollection<Mention> shares, ICollection<string> neStrings, Dictionaries dict, Properties props)
{
    try
    {
        bool dummyLabel = false;
        RVFDatum<bool, string> datum = new RVFDatum<bool, string>(ExtractFeatures(p, shares, neStrings, dict, props), dummyLabel);
        return rf.ProbabilityOfTrue(datum);
    }
    catch (Exception e)
    {
        // Wrap and rethrow, preserving the original exception as InnerException.
        throw new Exception(e.Message, e);
    }
}
/// <summary>Method to convert this dataset to RVFDataset using L1-normalized TF-IDF features.</summary>
/// <returns>RVFDataset with L1-normalized TF-IDF features.</returns>
public virtual RVFDataset<L, F> GetL1NormalizedTFIDFDataset()
{
    RVFDataset<L, F> rvfDataset = new RVFDataset<L, F>(this.Size(), this.featureIndex, this.labelIndex);
    ICounter<F> featureDocCounts = GetFeatureCounter();
    for (int i = 0; i < this.Size(); i++)
    {
        IDatum<L, F> datum = this.GetDatum(i);
        RVFDatum<L, F> rvfDatum = GetL1NormalizedTFIDFDatum(datum, featureDocCounts);
        rvfDataset.Add(rvfDatum);
    }
    return rvfDataset;
}
/// <summary>
/// Returns a counter for the log probability of each of the classes.
/// Looking at the sum of e^v for each count v, this should be 1.
/// Note: uses SloppyMath.logSum, which isn't exact but isn't as
/// offensively slow as doing a series of exponentials.
/// </summary>
public override ICounter<L> LogProbabilityOf(RVFDatum<L, F> example)
{
    if (platt == null)
    {
        throw new NotSupportedException("If you want to ask for the probability, you must train a Platt model!");
    }
    ICounter<L> scores = ScoresOf(example);
    // Add the constant (null-keyed) bias feature that the Platt model was fit with.
    scores.IncrementCount(null);
    ICounter<L> probs = platt.LogProbabilityOf(new RVFDatum<L, L>(scores));
    // System.out.println(scores + " " + probs);
    return probs;
}
/// <summary>Builds a sigmoid model to turn the classifier outputs into probabilities.</summary>
private LinearClassifier<L, L> FitSigmoid(SVMLightClassifier<L, F> classifier, GeneralDataset<L, F> dataset)
{
    RVFDataset<L, L> plattDataset = new RVFDataset<L, L>();
    for (int i = 0; i < dataset.Size(); i++)
    {
        RVFDatum<L, F> d = dataset.GetRVFDatum(i);
        ICounter<L> scores = classifier.ScoresOf((IDatum<L, F>)d);
        // Add a constant (null-keyed) bias feature so the sigmoid has an intercept term.
        scores.IncrementCount(null);
        plattDataset.Add(new RVFDatum<L, L>(scores, d.Label()));
    }
    LinearClassifierFactory<L, L> factory = new LinearClassifierFactory<L, L>();
    factory.SetPrior(new LogPrior(LogPrior.LogPriorType.Null));
    return factory.TrainClassifier(plattDataset);
}
public virtual void ScoreBestMentionNew(SupervisedSieveTraining.FeaturesData fd, Annotation doc)
{
    IList<ICoreMap> quotes = doc.Get(typeof(CoreAnnotations.QuotationsAnnotation));
    for (int i = 0; i < quotes.Count; i++)
    {
        ICoreMap quote = quotes[i];
        if (quote.Get(typeof(QuoteAttributionAnnotator.MentionAnnotation)) != null)
        {
            continue;
        }
        double maxConfidence = 0;
        int maxDataIdx = -1;
        int goldDataIdx = -1;
        Pair<int, int> dataRange = fd.mapQuoteToDataRange[i];
        if (dataRange == null)
        {
            continue;
        }
        for (int dataIdx = dataRange.first; dataIdx <= dataRange.second; dataIdx++)
        {
            RVFDatum<string, string> datum = fd.dataset.GetRVFDatum(dataIdx);
            double isMentionConfidence = quoteToMentionClassifier.ScoresOf(datum).GetCount("isMention");
            if (isMentionConfidence > maxConfidence)
            {
                maxConfidence = isMentionConfidence;
                maxDataIdx = dataIdx;
            }
        }
        if (maxDataIdx != -1)
        {
            Sieve.MentionData mentionData = fd.mapDatumToMention[maxDataIdx];
            if (mentionData.type.Equals("animate noun"))
            {
                continue;
            }
            quote.Set(typeof(QuoteAttributionAnnotator.MentionAnnotation), mentionData.text);
            quote.Set(typeof(QuoteAttributionAnnotator.MentionBeginAnnotation), mentionData.begin);
            quote.Set(typeof(QuoteAttributionAnnotator.MentionEndAnnotation), mentionData.end);
            quote.Set(typeof(QuoteAttributionAnnotator.MentionTypeAnnotation), mentionData.type);
            quote.Set(typeof(QuoteAttributionAnnotator.MentionSieveAnnotation), "supervised");
        }
    }
}
public virtual float Accuracy(IEnumerator<RVFDatum<L, F>> exampleIterator)
{
    int correct = 0;
    int total = 0;
    while (exampleIterator.MoveNext())
    {
        RVFDatum<L, F> next = exampleIterator.Current;
        L guess = ClassOf(next);
        if (guess.Equals(next.Label()))
        {
            correct++;
        }
        total++;
    }
    logger.Info("correct " + correct + " out of " + total);
    return correct / (float)total;
}
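// Usage sketch (illustrative addition, not from the original source): evaluate a trained
// classifier on a held-out RVFDataset via its datum enumerator, assuming the enclosing
// classifier type (NaiveBayesClassifier here) exposes Accuracy as above. "classifier" and
// "testSet" are hypothetical, already-constructed objects.
private static void AccuracyExample(NaiveBayesClassifier<string, string> classifier, RVFDataset<string, string> testSet)
{
    float acc = classifier.Accuracy(testSet.GetEnumerator());
    System.Console.Out.WriteLine("held-out accuracy: " + acc);
}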
/// <summary>
/// Score the given input, returning both the classification decision and the
/// probability of that decision.
/// </summary>
/// <remarks>
/// Score the given input, returning both the classification decision and the
/// probability of that decision.
/// Note that this method will not return a relation which does not type check.
/// </remarks>
/// <param name="input">The input to classify.</param>
/// <returns>A pair with the relation we classified into, along with its confidence.</returns>
public virtual Pair<string, double> Classify(KBPRelationExtractor.KBPInput input)
{
    RVFDatum<string, string> datum = new RVFDatum<string, string>(Features(input));
    ICounter<string> scores = classifier.ScoresOf(datum);
    Counters.ExpInPlace(scores);
    Counters.Normalize(scores);
    string best = Counters.Argmax(scores);
    // While the prediction doesn't type check, continue going down the list.
    // NO_RELATION is always an option somewhere in there, so it's safe to keep going...
    while (!KBPRelationExtractorConstants.NoRelation.Equals(best) && scores.Size() > 1 &&
           (!KBPRelationExtractor.RelationType.FromString(best).Get().validNamedEntityLabels.Contains(input.objectType) ||
            KBPRelationExtractor.RelationType.FromString(best).Get().entityType != input.subjectType))
    {
        scores.Remove(best);
        Counters.Normalize(scores);
        best = Counters.Argmax(scores);
    }
    return Pair.MakePair(best, scores.GetCount(best));
}
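// Usage sketch (illustrative addition, not from the original source): classify a single KBP
// input and report the type-checked decision with its confidence. "KBPStatisticalExtractor"
// is an assumed name for the enclosing extractor type; "extractor" and "input" are
// hypothetical, already-constructed objects.
private static void KbpClassifyExample(KBPStatisticalExtractor extractor, KBPRelationExtractor.KBPInput input)
{
    Pair<string, double> prediction = extractor.Classify(input);
    System.Console.Out.WriteLine("relation: " + prediction.first + ", confidence: " + prediction.second);
}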
// Quick little sanity check
public static void Main(string[] args)
{
    ICollection<RVFDatum<string, string>> trainingInstances = new List<RVFDatum<string, string>>();
    {
        ClassicCounter<string> f1 = new ClassicCounter<string>();
        f1.SetCount("humidity", 5.0);
        f1.SetCount("temperature", 35.0);
        trainingInstances.Add(new RVFDatum<string, string>(f1, "rain"));
    }
    {
        ClassicCounter<string> f1 = new ClassicCounter<string>();
        f1.SetCount("humidity", 4.0);
        f1.SetCount("temperature", 32.0);
        trainingInstances.Add(new RVFDatum<string, string>(f1, "rain"));
    }
    {
        ClassicCounter<string> f1 = new ClassicCounter<string>();
        f1.SetCount("humidity", 6.0);
        f1.SetCount("temperature", 30.0);
        trainingInstances.Add(new RVFDatum<string, string>(f1, "rain"));
    }
    {
        ClassicCounter<string> f1 = new ClassicCounter<string>();
        f1.SetCount("humidity", 2.0);
        f1.SetCount("temperature", 33.0);
        trainingInstances.Add(new RVFDatum<string, string>(f1, "dry"));
    }
    {
        ClassicCounter<string> f1 = new ClassicCounter<string>();
        f1.SetCount("humidity", 1.0);
        f1.SetCount("temperature", 34.0);
        trainingInstances.Add(new RVFDatum<string, string>(f1, "dry"));
    }
    Edu.Stanford.Nlp.Classify.KNNClassifier<string, string> classifier = new KNNClassifierFactory<string, string>(3, false, true).Train(trainingInstances);
    {
        ClassicCounter<string> f1 = new ClassicCounter<string>();
        f1.SetCount("humidity", 2.0);
        f1.SetCount("temperature", 33.0);
        RVFDatum<string, string> testVec = new RVFDatum<string, string>(f1);
        System.Console.Out.WriteLine(classifier.ScoresOf(testVec));
        System.Console.Out.WriteLine(classifier.ClassOf(testVec));
    }
}
/// <summary>Read the data as a list of RVFDatum objects.</summary>
/// <remarks>Read the data as a list of RVFDatum objects. For the test set we must reuse the indices from the training set.</remarks>
internal static List<RVFDatum<string, int>> ReadData(string filename, IDictionary<int, IIndex<string>> indices)
{
    try
    {
        string sep = ", ";
        List<RVFDatum<string, int>> examples = new List<RVFDatum<string, int>>();
        foreach (string line in ObjectBank.GetLineIterator(new File(filename)))
        {
            RVFDatum<string, int> next = ReadDatum(line, sep, indices);
            examples.Add(next);
        }
        return examples;
    }
    catch (Exception e)
    {
        Sharpen.Runtime.PrintStackTrace(e);
    }
    return null;
}
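// Usage sketch (illustrative addition, not from the original source): read train and test
// files with a shared index map, so test-set attribute values are resolved against the
// indices built from the training set, as the remark above requires. File names are hypothetical.
private static void ReadDataExample()
{
    IDictionary<int, IIndex<string>> indices = new Dictionary<int, IIndex<string>>();
    List<RVFDatum<string, int>> train = ReadData("train.csv", indices);
    List<RVFDatum<string, int>> test = ReadData("test.csv", indices);  // reuses the training indices
}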
// public void getDecisionTree(Map<String, List<CoreLabel>> sents,
//     List<Pair<String, Integer>> chosen, Counter<String> weights, String wekaOptions) {
//   RVFDataset<String, String> dataset = new RVFDataset<String, String>();
//   for (Pair<String, Integer> d : chosen) {
//     CoreLabel l = sents.get(d.first).get(d.second());
//     String w = l.word();
//     Integer num = this.clusterIds.get(w);
//     if (num == null)
//       num = -1;
//     double wt = weights.getCount("Cluster-" + num);
//     String label;
//     if (l.get(answerClass).toString().equals(answerLabel))
//       label = answerLabel;
//     else
//       label = "O";
//     Counter<String> feat = new ClassicCounter<String>();
//     feat.setCount("DIST", wt);
//     dataset.add(new RVFDatum<String, String>(feat, label));
//   }
//   WekaDatumClassifierFactory wekaFactory = new WekaDatumClassifierFactory("weka.classifiers.trees.J48", wekaOptions);
//   WekaDatumClassifier classifier = wekaFactory.trainClassifier(dataset);
//   Classifier cls = classifier.getClassifier();
//   J48 j48decisiontree = (J48) cls;
//   System.out.println(j48decisiontree.toSummaryString());
//   System.out.println(j48decisiontree.toString());
// }
private int Sample(IDictionary<string, DataInstance> sents, Random r, Random rneg, double perSelectNeg, double perSelectRand, int numrand, IList<Pair<string, int>> chosen, RVFDataset<string, string> dataset)
{
    foreach (KeyValuePair<string, DataInstance> en in sents)
    {
        CoreLabel[] sent = Sharpen.Collections.ToArray(en.Value.GetTokens(), new CoreLabel[0]);
        for (int i = 0; i < sent.Length; i++)
        {
            CoreLabel l = sent[i];
            bool chooseThis = false;
            if (l.Get(answerClass).Equals(answerLabel))
            {
                chooseThis = true;
            }
            else if ((!l.Get(answerClass).Equals("O") || negativeWords.Contains(l.Word().ToLower())) && GetRandomBoolean(r, perSelectNeg))
            {
                chooseThis = true;
            }
            else if (GetRandomBoolean(r, perSelectRand))
            {
                numrand++;
                chooseThis = true;
            }
            if (chooseThis)
            {
                chosen.Add(new Pair<string, int>(en.Key, i));
                RVFDatum<string, string> d = GetDatum(sent, i);
                dataset.Add(d, en.Key, i.ToString());
            }
        }
    }
    return numrand;
}
public virtual void TestCombiningDatasets()
{
    RVFDatum<string, string> datum1 = NewRVFDatum(null, "a", "b", "a");
    RVFDatum<string, string> datum2 = NewRVFDatum(null, "c", "c", "b");
    RVFDataset<string, string> data1 = new RVFDataset<string, string>();
    data1.Add(datum1);
    RVFDataset<string, string> data2 = new RVFDataset<string, string>();
    data2.Add(datum2);
    RVFDataset<string, string> data = new RVFDataset<string, string>();
    data.AddAll(data1);
    data.AddAll(data2);
    IEnumerator<RVFDatum<string, string>> iterator = data.GetEnumerator();
    NUnit.Framework.Assert.IsTrue(iterator.MoveNext());
    NUnit.Framework.Assert.AreEqual(datum1, iterator.Current);
    NUnit.Framework.Assert.IsTrue(iterator.MoveNext());
    NUnit.Framework.Assert.AreEqual(datum2, iterator.Current);
    NUnit.Framework.Assert.IsFalse(iterator.MoveNext());
}
// todo: Fix javadoc, have unit tested
/// <summary>Print SVM Light Format file.</summary>
/// <remarks>
/// Print SVM Light Format file.
/// The following comments are no longer applicable because I am
/// now printing out the exact labelID for each example. -Ramesh ([email protected]) 12/17/2009.
/// If the Dataset has more than 2 classes, then it
/// prints using the label index (+1) (for svm_struct). If it is 2 classes, then the labelIndex.get(0)
/// is mapped to +1 and labelIndex.get(1) is mapped to -1 (for svm_light).
/// </remarks>
public virtual void PrintSVMLightFormat(PrintWriter pw)
{
    // Assumes each data item has a few features on, and sorts the feature keys
    // while collecting the values in a counter.
    // old comment:
    // the following code commented out by Ramesh ([email protected]) 12/17/2009.
    // why not simply print the exact id of the label instead of mapping to some values??
    // new comment:
    // mihai: we NEED this, because svm_light has special conventions not supported by default by our labels,
    // e.g., in a multiclass setting it assumes that labels start at 1 whereas our labels start at 0 (08/31/2010)
    string[] labelMap = MakeSvmLabelMap();
    for (int i = 0; i < size; i++)
    {
        RVFDatum<L, F> d = GetRVFDatum(i);
        ICounter<F> c = d.AsFeaturesCounter();
        ClassicCounter<int> printC = new ClassicCounter<int>();
        foreach (F f in c.KeySet())
        {
            printC.SetCount(featureIndex.IndexOf(f), c.GetCount(f));
        }
        int[] features = Sharpen.Collections.ToArray(printC.KeySet(), new int[printC.KeySet().Count]);
        Arrays.Sort(features);
        StringBuilder sb = new StringBuilder();
        sb.Append(labelMap[labels[i]]).Append(' ');
        // sb.append(labels[i]).append(' '); // commented out by mihai: labels[i] breaks svm_light conventions!
        /* Old code: assumes that F is Integer....
         *
         * for (int f : features) {
         *   sb.append((f + 1)).append(":").append(c.getCount(f)).append(" ");
         * }
         */
        // I think this is what was meant (using printC rather than c), but not sure
        // ~Sarah Spikes ([email protected])
        foreach (int f1 in features)
        {
            sb.Append(f1 + 1).Append(':').Append(printC.GetCount(f1)).Append(' ');
        }
        pw.Println(sb.ToString());
    }
}
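// Usage sketch (illustrative addition, not from the original source): dump a dataset to a
// file in SVM Light format, assuming the enclosing dataset type (RVFDataset here) exposes
// PrintSVMLightFormat as above. PrintWriter mirrors the java.io type used by this port, and
// the output path is hypothetical.
private static void PrintSvmLightExample<L, F>(RVFDataset<L, F> dataset)
{
    PrintWriter pw = new PrintWriter("train.svmlight");
    dataset.PrintSVMLightFormat(pw);
    pw.Close();
}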
internal static RVFDatum<string, int> ReadDatum(string[] values, int classColumn, ICollection<int> skip, IDictionary<int, IIndex<string>> indices)
{
    ClassicCounter<int> c = new ClassicCounter<int>();
    RVFDatum<string, int> d = new RVFDatum<string, int>(c);
    int attrNo = 0;
    for (int index = 0; index < values.Length; index++)
    {
        if (index == classColumn)
        {
            d.SetLabel(values[index]);
            continue;
        }
        if (skip.Contains(index))
        {
            continue;
        }
        int featKey = attrNo;
        IIndex<string> ind = indices[featKey];
        if (ind == null)
        {
            ind = new HashIndex<string>();
            indices[featKey] = ind;
        }
        // MG: condition on isLocked is useless, since add(E) contains such a condition:
        // if (!ind.isLocked()) {
        ind.Add(values[index]);
        // }
        int valInd = ind.IndexOf(values[index]);
        if (valInd == -1)
        {
            valInd = 0;
            logger.Info("unknown attribute value " + values[index] + " of attribute " + attrNo);
        }
        c.IncrementCount(featKey, valInd);
        attrNo++;
    }
    return d;
}
/// <summary>
/// Given a set of vectors, and a mapping from each vector to its class label,
/// generates the sets of instances used to perform classifications and returns
/// the corresponding K-NN classifier.
/// </summary>
/// <remarks>
/// Given a set of vectors, and a mapping from each vector to its class label,
/// generates the sets of instances used to perform classifications and returns
/// the corresponding K-NN classifier.
/// NOTE: if l2NormalizeVectors is true, creates a copy and applies L2Normalize to it.
/// </remarks>
public virtual KNNClassifier<K, V> Train(ICollection<ICounter<V>> vectors, IDictionary<ICounter<V>, K> labelMap)
{
    KNNClassifier<K, V> classifier = new KNNClassifier<K, V>(k, weightedVotes, l2NormalizeVectors);
    ICollection<RVFDatum<K, V>> instances = new List<RVFDatum<K, V>>();
    foreach (ICounter<V> vector in vectors)
    {
        K label = labelMap[vector];
        RVFDatum<K, V> datum;
        if (l2NormalizeVectors)
        {
            datum = new RVFDatum<K, V>(Counters.L2Normalize(new ClassicCounter<V>(vector)), label);
        }
        else
        {
            datum = new RVFDatum<K, V>(vector, label);
        }
        instances.Add(datum);
    }
    classifier.AddInstances(instances);
    return classifier;
}
/// <summary>
/// Given a CollectionValued Map of vectors, treats outer key as label for each
/// set of inner vectors.
/// </summary>
/// <remarks>
/// Given a CollectionValued Map of vectors, treats outer key as label for each
/// set of inner vectors.
/// NOTE: if l2NormalizeVectors is true, creates a copy of each vector and applies
/// l2Normalize to it.
/// </remarks>
public virtual KNNClassifier<K, V> Train(CollectionValuedMap<K, ICounter<V>> vecBag)
{
    KNNClassifier<K, V> classifier = new KNNClassifier<K, V>(k, weightedVotes, l2NormalizeVectors);
    ICollection<RVFDatum<K, V>> instances = new List<RVFDatum<K, V>>();
    foreach (K label in vecBag.Keys)
    {
        foreach (ICounter<V> vector in vecBag[label])
        {
            RVFDatum<K, V> datum;
            if (l2NormalizeVectors)
            {
                datum = new RVFDatum<K, V>(Counters.L2Normalize(new ClassicCounter<V>(vector)), label);
            }
            else
            {
                datum = new RVFDatum<K, V>(vector, label);
            }
            instances.Add(datum);
        }
    }
    classifier.AddInstances(instances);
    return classifier;
}
/// <summary>
/// Given an instance to classify, scores and returns
/// score by class.
/// </summary>
/// <remarks>
/// Given an instance to classify, scores and returns
/// score by class.
/// NOTE: supports only RVFDatums
/// </remarks>
public virtual ClassicCounter<K> ScoresOf(IDatum<K, V> datum)
{
    if (datum is RVFDatum<K, V>)
    {
        RVFDatum<K, V> vec = (RVFDatum<K, V>)datum;
        if (l2Normalize)
        {
            ClassicCounter<V> featVec = new ClassicCounter<V>(vec.AsFeaturesCounter());
            Counters.Normalize(featVec);
            vec = new RVFDatum<K, V>(featVec);
        }
        // Score every stored instance by its cosine similarity to the query vector.
        ClassicCounter<ICounter<V>> scores = new ClassicCounter<ICounter<V>>();
        foreach (ICounter<V> instance in instances.AllValues())
        {
            scores.SetCount(instance, Counters.Cosine(vec.AsFeaturesCounter(), instance));
        }
        // Vote among the k nearest neighbors, optionally weighting each vote by similarity.
        IList<ICounter<V>> sorted = Counters.ToSortedList(scores);
        ClassicCounter<K> classScores = new ClassicCounter<K>();
        for (int i = 0; i < k && i < sorted.Count; i++)
        {
            K label = classLookup[sorted[i]];
            double count = 1.0;
            if (weightedVotes)
            {
                count = scores.GetCount(sorted[i]);
            }
            classScores.IncrementCount(label, count);
        }
        return classScores;
    }
    return null;
}
public virtual L ClassOf(RVFDatum<L, F> example)
{
    ICounter<L> scores = ScoresOf(example);
    return Counters.Argmax(scores);
}
public virtual ClassicCounter<L> ScoresOf(IDatum<L, F> example)
{
    RVFDatum<L, F> rvf = new RVFDatum<L, F>(example);
    return ScoresOf(rvf);
}