/// <summary>
/// Trains every inner model, then fills the tag-distribution table with the scores
/// those inner models assign to the training examples.
/// </summary>
/// <param name="dataset">Labeled training examples; must not be null.</param>
/// <remarks>
/// NOTE(review): <c>mTagDistrTable</c> is not recreated here, so repeated calls may
/// accumulate counts from earlier training runs — confirm this is intended.
/// </remarks>
public void Train(ILabeledExampleCollection<LblT, ExT> dataset)
{
    Preconditions.CheckNotNull(dataset);
    // Copy the incoming collection into a LabeledDataset that both stages below use.
    var trainDataset = new LabeledDataset<LblT, ExT>(dataset);

    // Stage 1: create any missing inner model, then train it on the set GetTrainSet selects for it.
    for (int modelIdx = 0; modelIdx < mInnerModels.Length; modelIdx++)
    {
        if (mInnerModels[modelIdx] == null)
        {
            mInnerModels[modelIdx] = CreateModel(modelIdx);
        }
        var innerModel = mInnerModels[modelIdx];
        innerModel.Train(GetTrainSet(modelIdx, innerModel, trainDataset));
    }

    // Stage 2: run all inner models on each training example, turn their predictions into a
    // score vector, and count that vector under the example's true label.
    foreach (LabeledExample<LblT, ExT> labeledExample in trainDataset)
    {
        // Materialized immediately, so capturing the loop variable is safe.
        var innerPredictions = mInnerModels
            .Select(model => model.Predict(labeledExample.Example))
            .ToArray();
        double[] scores = GetPredictionScores(innerPredictions).ToArray();
        mTagDistrTable.AddCount(labeledExample.Label, scores);
    }

    mTagDistrTable.Calculate();
    IsTrained = true;
}
/// <summary>
/// Rebuilds the two-dimensional tag-distribution table from the stored example scores
/// (positive score, negative score). The table is set to null when there are no stored
/// scores or the bin width is zero.
/// </summary>
private void UpdateDistrTable()
{
    if (mExampleScores != null && mBinWidth != 0)
    {
        // Arguments: 2 dimensions, bin width, -5 and 5 (presumably the value range per
        // dimension — confirm against EnumTagDistrTable), and SentimentLabel.Exclude
        // (presumably a label the table ignores — confirm).
        var table = new EnumTagDistrTable<SentimentLabel>(2, mBinWidth, -5, 5, SentimentLabel.Exclude)
        {
            // Laplace (add-one) smoothing: (count of this tag + 1) / (sum of all counts + number of entries).
            CalcDistrFunc = (counts, vals, label) =>
                ((double)counts[label] + 1) / (counts.Values.Sum() + counts.Count)
        };
        // Publish the table before filling it, exactly as the field was assigned first originally.
        mTagDistrTable = table;
        foreach (ExampleScore exampleScore in mExampleScores)
        {
            table.AddCount(exampleScore.Label, exampleScore.PosScore, exampleScore.NegScore);
        }
        table.Calculate();
    }
    else
    {
        mTagDistrTable = null;
    }
}
/// <summary>
/// Trains a binary (positive vs. negative) model on the non-neutral examples. It then builds
/// a one-dimensional tag-distribution table over the model's signed scores for all examples.
/// </summary>
/// <param name="dataset">Labeled training examples; must not be null. Neutral examples are
/// excluded from training the binary model but are still counted in the distribution table.</param>
/// <exception cref="System.ArgumentOutOfRangeException">Thrown (via Preconditions) when an
/// existing <c>TagDistrTable</c> is not one-dimensional.</exception>
public override void Train(ILabeledExampleCollection<SentimentLabel, SparseVector<double>> dataset)
{
    Preconditions.CheckNotNull(dataset);
    // This method always installs a 1-D table (one signed score per example), so any existing
    // table must also be 1-D. The earlier check for 2 dimensions made re-training the same
    // instance throw, because the first call had already installed a 1-D table.
    Preconditions.CheckArgumentRange(TagDistrTable == null || TagDistrTable.NumOfDimensions == 1);

    mBinModel = CreateModel();
    mBinModel.Train(new LabeledDataset<SentimentLabel, SparseVector<double>>(
        dataset.Where(le => le.Label != SentimentLabel.Neutral)));

    TagDistrTable = new EnumTagDistrTable<SentimentLabel>(1, BinWidth, -5, 5, SentimentLabel.Exclude)
    {
        // Laplace (add-one) smoothing: (count of this tag + 1) / (sum of all counts + number of entries).
        CalcDistrFunc = (tagCounts, values, tag) =>
            ((double)tagCounts[tag] + 1) / (tagCounts.Values.Sum() + tagCounts.Count)
    };

    foreach (LabeledExample<SentimentLabel, SparseVector<double>> le in dataset)
    {
        Prediction<SentimentLabel> prediction = mBinModel.Predict(le.Example);
        // Map the binary prediction onto one signed axis: +score for Positive, -score otherwise.
        double signedScore = prediction.BestClassLabel == SentimentLabel.Positive
            ? prediction.BestScore
            : -prediction.BestScore;
        TagDistrTable.AddCount(le.Label, signedScore);
    }

    TagDistrTable.Calculate();
    IsTrained = true;
}