/// <summary>
/// Fits a text-featurization + linear classification pipeline on the sentiment training data,
/// then reads the sentiment test set through the fitted model and evaluates it with
/// <c>MyBinaryClassifierEvaluator</c>.
/// </summary>
public void New_Evaluation()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);

    // seed: 1 / conc: 1 together with NumThreads = 1 below presumably keep the run deterministic.
    using (var env = new LocalEnvironment(seed: 1, conc: 1))
    {
        var reader = new TextLoader(env, MakeSentimentTextLoaderArgs());

        // Pipeline: reuse the reader as the first stage (instead of constructing a second, identical
        // TextLoader), featurize "SentimentText" into "Features", then train on "Features"/"Label".
        var pipeline = reader
            .Append(new TextTransform(env, "SentimentText", "Features"))
            .Append(new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label"));

        // Train.
        var readerModel = pipeline.Fit(new MultiFileSource(dataPath));

        // Evaluate on the test set: read it through the fitted reader model, then score the metrics.
        var dataEval = readerModel.Read(new MultiFileSource(testDataPath));
        var evaluator = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments());

        // NOTE(review): metrics are computed but not asserted on; add checks if this is meant as a regression test.
        var metrics = evaluator.Evaluate(dataEval);
    }
}
/// <summary>
/// Runs <c>NumFolds</c>-fold cross-validation of <paramref name="estimator"/> over <paramref name="trainData"/>.
/// </summary>
/// <remarks>
/// A numeric column is added to the data with <c>GenerateNumberTransform</c>. For fold <c>k</c>, rows whose value lies
/// in the sub-range from <c>k / NumFolds</c> to <c>(k + 1) / NumFolds</c> form the test split, and all other rows
/// (the complement) form the training split. The estimator is fit on the training split. The resulting model scores
/// the test split, which is then evaluated using <c>LabelColumn</c> and the <c>"Probability"</c> column.
/// </remarks>
/// <param name="trainData">The data to split into folds.</param>
/// <param name="estimator">The estimator to fit once per fold.</param>
/// <returns>The per-fold models and their binary classification metrics.</returns>
/// <exception cref="ArgumentNullException"><paramref name="trainData"/> or <paramref name="estimator"/> is null.</exception>
/// <exception cref="NotImplementedException"><c>StratificationColumn</c> is set; only generated fold assignment is supported.</exception>
public BinaryCrossValidationMetrics CrossValidate(IDataView trainData, IEstimator<ITransformer> estimator)
{
    if (trainData == null)
    {
        throw new ArgumentNullException(nameof(trainData));
    }
    if (estimator == null)
    {
        throw new ArgumentNullException(nameof(estimator));
    }

    // A caller-supplied column is not handled: the fold filters below assume values spread over [0, 1].
    if (StratificationColumn != null)
    {
        throw new NotImplementedException("Cross-validation with a user-provided stratification column is not supported.");
    }

    // Keep the generated column name local. Assigning it to the StratificationColumn member (as this method used to)
    // made every later call on the same instance hit the NotImplementedException above.
    const string foldColumn = "StratificationColumn";
    IDataView data = new GenerateNumberTransform(_env, trainData, foldColumn);

    var models = new ITransformer[NumFolds];
    var metrics = new BinaryClassificationMetrics[NumFolds];
    var evaluator = new MyBinaryClassifierEvaluator(_env, new BinaryClassifierEvaluator.Arguments());

    for (int fold = 0; fold < NumFolds; fold++)
    {
        double min = (double)fold / NumFolds;
        double max = (double)(fold + 1) / NumFolds;

        // Training rows lie outside this fold's range (Complement = true); test rows lie inside it.
        var trainFilter = new RangeFilter(_env, new RangeFilter.Arguments() { Column = foldColumn, Min = min, Max = max, Complement = true }, data);
        var testFilter = new RangeFilter(_env, new RangeFilter.Arguments() { Column = foldColumn, Min = min, Max = max, Complement = false }, data);

        models[fold] = estimator.Fit(trainFilter);
        var scoredTest = models[fold].Transform(testFilter);
        metrics[fold] = evaluator.Evaluate(scoredTest, labelColumn: LabelColumn, probabilityColumn: "Probability");
    }

    return new BinaryCrossValidationMetrics(models, metrics);
}