/// <summary>
/// Suggests exactly one learner recipe based on stochastic dual coordinate ascent (SDCA).
/// When <paramref name="predictorType"/> is the multiclass trainer signature, the recipe uses the
/// <c>SdcaMultiClassTrainer</c> load name. Otherwise it uses the <c>LinearClassificationTrainer</c>
/// load name and attaches an SDCA binary-classifier entry-point pipeline node.
/// </summary>
/// <param name="predictorType">Trainer signature type requested by recipe inference.</param>
/// <param name="transforms">Transforms carried unchanged into the yielded recipe.</param>
/// <returns>A lazily evaluated sequence containing a single <see cref="SuggestedRecipe"/>.</returns>
protected override IEnumerable<SuggestedRecipe> ApplyCore(Type predictorType, TransformInference.SuggestedTransform[] transforms)
{
    var sdcaLearner = new SuggestedRecipe.SuggestedLearner();
    bool isMultiClass = predictorType == typeof(SignatureMultiClassClassifierTrainer);

    var trainerLoadName = isMultiClass
        ? Learners.SdcaMultiClassTrainer.LoadNameValue
        : Learners.LinearClassificationTrainer.LoadNameValue;
    sdcaLearner.LoadableClassInfo = ComponentCatalog.GetLoadableClassInfo<SignatureTrainer>(trainerLoadName);

    if (!isMultiClass)
    {
        // Only the binary path attaches an entry-point pipeline node.
        // NOTE(review): the multiclass path leaves PipelineNode unset; confirm this is intentional.
        sdcaLearner.PipelineNode = new TrainerPipelineNode(new Trainers.StochasticDualCoordinateAscentBinaryClassifier());
    }

    sdcaLearner.Settings = "";

    var learners = new[] { sdcaLearner };
    yield return new SuggestedRecipe(ToString(), transforms, learners);
}
/// <summary>
/// End-to-end check of the one-versus-all macro. An SDCA binary classifier subgraph is wrapped in
/// <c>Models.OneVersusAll</c> (with probabilities) and trained on iris.txt. The same data is then
/// scored and evaluated, and the macro accuracy must round to 0.96 at two decimal places.
/// </summary>
public void TestOvaMacro()
{
    var irisPath = GetDataPath(@"iris.txt");
    using (var env = new TlcEnvironment(42))
    {
        // Inner graph: the binary learner handed to the one-versus-all node.
        var ovaSubGraph = env.CreateExperiment();
        ovaSubGraph.Add(new Trainers.StochasticDualCoordinateAscentBinaryClassifier { NumThreads = 1 });

        // Outer graph: load -> one-versus-all training -> scoring -> multiclass evaluation.
        var pipeline = env.CreateExperiment();

        // Column 0 is the label; columns 1..4 form the feature vector.
        var loader = new ML.Data.TextLoader(irisPath);
        loader.Arguments.Column = new[]
        {
            new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
            new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1, 4) } }
        };
        var loaded = pipeline.Add(loader);

        var ova = pipeline.Add(new Models.OneVersusAll
        {
            TrainingData = loaded.Data,
            Nodes = ovaSubGraph,
            UseProbabilities = true,
        });

        // Scoring runs on the training data itself.
        var scored = pipeline.Add(new ML.Transforms.DatasetScorer
        {
            Data = loaded.Data,
            PredictorModel = ova.PredictorModel
        });

        var evaluated = pipeline.Add(new ML.Models.ClassificationEvaluator { Data = scored.ScoredData });

        pipeline.Compile();
        pipeline.SetInput(loader.InputFile, new SimpleFileHandle(env, irisPath, false, false));
        pipeline.Run();

        var metrics = pipeline.GetOutput(evaluated.OverallMetrics);
        bool hasAccuracy = metrics.Schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accuracyCol);
        Assert.True(hasAccuracy);

        using (var cursor = metrics.GetRowCursor(col => col == accuracyCol))
        {
            var getAccuracy = cursor.GetGetter<double>(accuracyCol);

            // The overall-metrics view must hold exactly one row.
            Assert.True(cursor.MoveNext());

            double accuracy = 0;
            getAccuracy(ref accuracy);
            Assert.Equal(0.96, accuracy, 2);

            Assert.False(cursor.MoveNext());
        }
    }
}