public void TestSimpleTrainExperiment()
{
    // Arrange: build a legacy entry-point experiment over the tiny adult sample file.
    string dataPath = GetDataPath("adult.tiny.with-schema.txt");
    var mlContext = new MLContext();
    var pipeline = mlContext.CreateExperiment();

    // Load the text file.
    var loader = new Legacy.Data.TextLoader(dataPath);
    var loaderOutput = pipeline.Add(loader);

    // One-hot encode the "Categories" column.
    var oneHot = new Legacy.Transforms.CategoricalOneHotVectorizer { Data = loaderOutput.Data };
    oneHot.AddColumn("Categories");
    var oneHotOutput = pipeline.Add(oneHot);

    // Concatenate the categorical and numeric columns into "Features".
    var concat = new Legacy.Transforms.ColumnConcatenator { Data = oneHotOutput.OutputData };
    concat.AddColumn("Features", "Categories", "NumericFeatures");
    var concatOutput = pipeline.Add(concat);

    // Train SDCA with hinge loss on one thread and without shuffling.
    var trainer = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier
    {
        TrainingData = concatOutput.OutputData,
        LossFunction = new HingeLossSDCAClassificationLossFunction() { Margin = 1.1f },
        NumThreads = 1,
        Shuffle = false
    };
    var trainerOutput = pipeline.Add(trainer);

    // Score the same (training) data with the trained predictor, then evaluate it.
    var scorer = new Legacy.Transforms.DatasetScorer
    {
        Data = concatOutput.OutputData,
        PredictorModel = trainerOutput.PredictorModel
    };
    var scorerOutput = pipeline.Add(scorer);

    var evaluator = new Legacy.Models.BinaryClassificationEvaluator { Data = scorerOutput.ScoredData };
    var evaluatorOutput = pipeline.Add(evaluator);

    // Act: compile, bind the input file, and run.
    pipeline.Compile();
    pipeline.SetInput(loader.InputFile, new SimpleFileHandle(mlContext, dataPath, false, false));
    pipeline.Run();

    // Assert: the overall-metrics view exposes an "AUC" column holding exactly one row,
    // whose value equals 0.93 at two decimal places.
    var metrics = pipeline.GetOutput(evaluatorOutput.OverallMetrics);
    bool hasAuc = metrics.Schema.TryGetColumnIndex("AUC", out int aucColumn);
    Assert.True(hasAuc);

    using (var cursor = metrics.GetRowCursor(col => col == aucColumn))
    {
        var readAuc = cursor.GetGetter<double>(aucColumn);

        Assert.True(cursor.MoveNext());
        double auc = 0;
        readAuc(ref auc);
        Assert.Equal(0.93, auc, 2);

        Assert.False(cursor.MoveNext());
    }
}
public static CommonOutputs.MacroOutput<Output> TrainTestBinary(
    IHostEnvironment env,
    Arguments input,
    EntryPointNode node)
{
    // Validate the user-supplied training subgraph in its own run context, so its
    // data/model variables can be renamed before merging into node.Context.
    var innerContext = new RunContext(env);
    var graph = EntryPointNode.ValidateNodes(env, innerContext, input.Nodes);

    // Wire the subgraph's data input to this macro's "TrainingData" input.
    EntryPointVariable dataVariable = LookupInner(input.Inputs.Data.VarName);
    var trainingVar = node.GetInputVariable("TrainingData");
    foreach (var graphNode in graph)
    {
        graphNode.RenameInputVariable(dataVariable.Name, trainingVar);
    }
    innerContext.RemoveVariable(dataVariable);

    // Wire the subgraph's model output to this macro's "PredictorModel" output.
    EntryPointVariable modelVariable = LookupInner(input.Outputs.Model.VarName);
    string predictorVarName = node.GetOutputVariableName("PredictorModel");
    foreach (var graphNode in graph)
    {
        graphNode.RenameOutputVariable(modelVariable.Name, predictorVarName);
    }
    innerContext.RemoveVariable(modelVariable);

    // Move the remaining inner variables to the outer context and re-home every node there.
    node.Context.AddContextVariables(innerContext);
    foreach (var graphNode in graph)
    {
        graphNode.SetContext(node.Context);
    }

    // Append a scorer that applies the trained predictor to "TestingData".
    var testingVar = node.GetInputVariable("TestingData");
    var builder = new Experiment(env);
    var scorer = new Legacy.Transforms.DatasetScorer();
    scorer.Data.VarName = testingVar.ToJson();
    scorer.PredictorModel.VarName = predictorVarName;
    var scorerOutput = builder.Add(scorer);
    graph.AddRange(EntryPointNode.ValidateNodes(env, node.Context, builder.GetNodes()));

    // Reset the experiment before adding the evaluator; the scorer's nodes were already collected above.
    builder.Reset();
    var evaluator = new Legacy.Models.BinaryClassificationEvaluator();
    evaluator.Data.VarName = scorerOutput.ScoredData.VarName;

    // Bind each evaluator output that the caller mapped to the caller's variable name.
    var evaluatorOutput = new Legacy.Models.BinaryClassificationEvaluator.Output();
    if (node.OutputMap.TryGetValue("Warnings", out var warningsName))
    {
        evaluatorOutput.Warnings.VarName = warningsName;
    }
    if (node.OutputMap.TryGetValue("OverallMetrics", out var overallName))
    {
        evaluatorOutput.OverallMetrics.VarName = overallName;
    }
    if (node.OutputMap.TryGetValue("PerInstanceMetrics", out var perInstanceName))
    {
        evaluatorOutput.PerInstanceMetrics.VarName = perInstanceName;
    }
    if (node.OutputMap.TryGetValue("ConfusionMatrix", out var confusionName))
    {
        evaluatorOutput.ConfusionMatrix.VarName = confusionName;
    }
    builder.Add(evaluator, evaluatorOutput);
    graph.AddRange(EntryPointNode.ValidateNodes(env, node.Context, builder.GetNodes()));

    // Tag every node emitted by this macro with one shared stage id.
    string stage = Guid.NewGuid().ToString("N");
    foreach (var graphNode in graph)
    {
        graphNode.StageId = stage;
    }

    return new CommonOutputs.MacroOutput<Output> { Nodes = graph };

    // Resolves a variable of the inner (subgraph) context by name, or throws through env.
    EntryPointVariable LookupInner(string name)
    {
        if (!innerContext.TryGetVariable(name, out EntryPointVariable found))
        {
            throw env.Except($"Invalid variable name '{name}'.");
        }
        return found;
    }
}