/// <summary>
/// Produces a single suggested recipe that pairs the supplied transforms with one
/// averaged-perceptron based learner configured for 10 iterations.
/// </summary>
/// <param name="predictorType">
/// Trainer signature type being requested. When it is
/// <see cref="SignatureMultiClassClassifierTrainer"/>, the learner is the "OVA"
/// loadable class with the averaged perceptron passed through its settings string;
/// otherwise the averaged perceptron trainer is used directly.
/// </param>
/// <param name="transforms">Suggested transforms forwarded unchanged into the recipe.</param>
/// <returns>A lazily evaluated sequence containing exactly one <see cref="SuggestedRecipe"/>.</returns>
protected override IEnumerable<SuggestedRecipe> ApplyCore(Type predictorType, TransformInference.SuggestedTransform[] transforms)
{
    var suggestion = new SuggestedRecipe.SuggestedLearner();
    bool wantsMultiClass = predictorType == typeof(SignatureMultiClassClassifierTrainer);

    if (!wantsMultiClass)
    {
        // Direct averaged perceptron: this is the only branch that also attaches a pipeline node.
        suggestion.LoadableClassInfo =
            ComponentCatalog.GetLoadableClassInfo<SignatureTrainer>(Learners.AveragedPerceptronTrainer.LoadNameValue);
        suggestion.Settings = "iter=10";
        suggestion.PipelineNode = new TrainerPipelineNode(
            new Legacy.Trainers.AveragedPerceptronBinaryClassifier
            {
                NumIterations = 10
            });
    }
    else
    {
        // Multiclass request: use the "OVA" component, configured purely through the settings string.
        // No pipeline node is assigned on this path.
        suggestion.LoadableClassInfo = ComponentCatalog.GetLoadableClassInfo<SignatureTrainer>("OVA");
        suggestion.Settings = "p=AveragedPerceptron{iter=10}";
    }

    SuggestedRecipe.SuggestedLearner[] learners = { suggestion };

    // NOTE(review): Int32.MaxValue is the last constructor argument; its meaning
    // (presumably a priority/ordering value) is not visible here — confirm against SuggestedRecipe.
    yield return new SuggestedRecipe(ToString(), transforms, learners, Int32.MaxValue);
}
// Builds an experiment that loads iris.txt, trains a one-versus-all model whose
// per-class learner is an averaged perceptron binary classifier (Shuffle disabled)
// with UseProbabilities set to true, scores the same data, and evaluates it.
// Asserts that the overall metrics contain exactly one row and that its macro
// accuracy equals 0.71 when compared to 2 decimal places.
public void TestOvaMacroWithUncalibratedLearner()
{
    var irisPath = GetDataPath(@"iris.txt");
    var env = new MLContext(42);

    // Sub-experiment holding the binary learner that OVA trains per class.
    var ovaSubGraph = env.CreateExperiment();
    var binaryLearner = new Legacy.Trainers.AveragedPerceptronBinaryClassifier
    {
        Shuffle = false
    };
    subGraphAddLearner();

    // Main experiment: load -> OVA train -> score -> evaluate.
    var experiment = env.CreateExperiment();

    var loader = new Legacy.Data.TextLoader(irisPath);
    var labelColumn = new TextLoaderColumn
    {
        Name = "Label",
        Source = new[] { new TextLoaderRange(0) }
    };
    var featuresColumn = new TextLoaderColumn
    {
        Name = "Features",
        Source = new[] { new TextLoaderRange(1, 4) }
    };
    loader.Arguments.Column = new TextLoaderColumn[] { labelColumn, featuresColumn };
    var loaded = experiment.Add(loader);

    var ovaTrainer = new Legacy.Models.OneVersusAll
    {
        TrainingData = loaded.Data,
        Nodes = ovaSubGraph,
        UseProbabilities = true
    };
    var trained = experiment.Add(ovaTrainer);

    var scorer = new Legacy.Transforms.DatasetScorer
    {
        Data = loaded.Data,
        PredictorModel = trained.PredictorModel
    };
    var scored = experiment.Add(scorer);

    var evaluator = new Legacy.Models.ClassificationEvaluator
    {
        Data = scored.ScoredData
    };
    var evaluated = experiment.Add(evaluator);

    experiment.Compile();
    experiment.SetInput(loader.InputFile, new SimpleFileHandle(env, irisPath, false, false));
    experiment.Run();

    var metrics = experiment.GetOutput(evaluated.OverallMetrics);
    var metricsSchema = metrics.Schema;
    Assert.True(metricsSchema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol));

    using (var cursor = metrics.GetRowCursor(col => col == accCol))
    {
        var readAccuracy = cursor.GetGetter<double>(accCol);

        // First (and only) row carries the macro accuracy.
        Assert.True(cursor.MoveNext());
        double macroAccuracy = 0;
        readAccuracy(ref macroAccuracy);
        Assert.Equal(0.71, macroAccuracy, 2);

        // No further rows are expected.
        Assert.False(cursor.MoveNext());
    }

    // Registers the binary learner in the OVA subgraph; the node output is not needed afterwards.
    void subGraphAddLearner()
    {
        ovaSubGraph.Add(binaryLearner);
    }
}