예제 #1
0
        /// <summary>
        /// Builds a legacy experiment graph end to end (text loader, one-hot encoding of
        /// "Categories", concatenation into "Features", SDCA binary classifier with hinge loss,
        /// scorer, binary classification evaluator), runs it, and checks that the overall
        /// metrics contain exactly one row whose "AUC" value equals 0.93 to two decimal places.
        /// </summary>
        public void TestSimpleTrainExperiment()
        {
            var inputPath = GetDataPath("adult.tiny.with-schema.txt");
            var mlContext = new MLContext();
            var graph = mlContext.CreateExperiment();

            // Stage 1: load the text file.
            var loader = new Legacy.Data.TextLoader(inputPath);
            var loaded = graph.Add(loader);

            // Stage 2: one-hot encode the "Categories" column.
            var oneHot = new Legacy.Transforms.CategoricalOneHotVectorizer { Data = loaded.Data };
            oneHot.AddColumn("Categories");
            var encoded = graph.Add(oneHot);

            // Stage 3: concatenate "Categories" and "NumericFeatures" into "Features".
            var concat = new Legacy.Transforms.ColumnConcatenator { Data = encoded.OutputData };
            concat.AddColumn("Features", "Categories", "NumericFeatures");
            var featurized = graph.Add(concat);

            // Stage 4: SDCA binary classifier with hinge loss, one thread and no shuffling.
            var trainer = new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier
            {
                TrainingData = featurized.OutputData,
                LossFunction = new HingeLossSDCAClassificationLossFunction { Margin = 1.1f },
                NumThreads = 1,
                Shuffle = false
            };
            var trained = graph.Add(trainer);

            // Stage 5: score the featurized data with the trained model.
            var scorer = new Legacy.Transforms.DatasetScorer
            {
                Data = featurized.OutputData,
                PredictorModel = trained.PredictorModel
            };
            var scored = graph.Add(scorer);

            // Stage 6: evaluate the scored data.
            var evaluator = new Legacy.Models.BinaryClassificationEvaluator { Data = scored.ScoredData };
            var evaluated = graph.Add(evaluator);

            // Compile, bind the loader's input file, and execute the graph.
            graph.Compile();
            graph.SetInput(loader.InputFile, new SimpleFileHandle(mlContext, inputPath, false, false));
            graph.Run();

            var metrics = graph.GetOutput(evaluated.OverallMetrics);

            // The overall metrics must expose an "AUC" column.
            Assert.True(metrics.Schema.TryGetColumnIndex("AUC", out int aucIndex));

            using (var rows = metrics.GetRowCursor(index => index == aucIndex))
            {
                var readAuc = rows.GetGetter<double>(aucIndex);

                // First row must exist and carry the expected AUC.
                Assert.True(rows.MoveNext());
                double aucValue = 0;
                readAuc(ref aucValue);
                Assert.Equal(0.93, aucValue, 2);

                // There must be no second row.
                Assert.False(rows.MoveNext());
            }
        }
        /// <summary>
        /// Expands a train/test request into a flat list of entry-point nodes: the caller's
        /// training subgraph (rewired to this node's "TrainingData" input and "PredictorModel"
        /// output), followed by a dataset scorer over "TestingData" and a binary classification
        /// evaluator whose outputs are bound to the names found in <c>node.OutputMap</c>.
        /// </summary>
        /// <param name="env">Host environment used for validation and for raising exceptions.</param>
        /// <param name="input">Arguments carrying the subgraph nodes plus the names of its data input and model output variables.</param>
        /// <param name="node">The node being expanded; supplies the target context, input variables and output names.</param>
        /// <returns>A macro output whose <c>Nodes</c> are the rewired subgraph nodes plus the scoring and evaluation nodes, all sharing one stage id.</returns>
        /// <remarks>
        /// Throws the exception produced by <c>env.Except</c> when either the subgraph's data input
        /// variable or its model output variable cannot be found in the subgraph context.
        /// </remarks>
        public static CommonOutputs.MacroOutput<Output> TrainTestBinary(
            IHostEnvironment env,
            Arguments input,
            EntryPointNode node)
        {
            // Validate the subgraph inside its own, temporary run context.
            var innerContext = new RunContext(env);
            var graphNodes = EntryPointNode.ValidateNodes(env, innerContext, input.Nodes);

            // Rewire the subgraph's data input to this node's training data.
            var dataVarName = input.Inputs.Data.VarName;
            if (!innerContext.TryGetVariable(dataVarName, out EntryPointVariable dataVariable))
            {
                throw env.Except($"Invalid variable name '{dataVarName}'.");
            }

            var trainingData = node.GetInputVariable("TrainingData");
            foreach (var graphNode in graphNodes)
            {
                graphNode.RenameInputVariable(dataVariable.Name, trainingData);
            }

            innerContext.RemoveVariable(dataVariable);

            // Rewire the subgraph's model output to this node's predictor model output.
            var modelVarName = input.Outputs.Model.VarName;
            if (!innerContext.TryGetVariable(modelVarName, out EntryPointVariable modelVariable))
            {
                throw env.Except($"Invalid variable name '{modelVarName}'.");
            }

            string predictorModelName = node.GetOutputVariableName("PredictorModel");
            foreach (var graphNode in graphNodes)
            {
                graphNode.RenameOutputVariable(modelVariable.Name, predictorModelName);
            }

            innerContext.RemoveVariable(modelVariable);

            // Merge the remaining subgraph variables into the main context,
            // then point every subgraph node at that context.
            node.Context.AddContextVariables(innerContext);
            foreach (var graphNode in graphNodes)
            {
                graphNode.SetContext(node.Context);
            }

            // Append a scorer that applies the trained model to the testing data.
            var testingData = node.GetInputVariable("TestingData");
            var builder = new Experiment(env);
            var scorer = new Legacy.Transforms.DatasetScorer();
            scorer.Data.VarName = testingData.ToJson();
            scorer.PredictorModel.VarName = predictorModelName;
            var scorerOutput = builder.Add(scorer);
            graphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, builder.GetNodes()));

            // Append an evaluator over the scored data; the builder is reset first
            // so that only the evaluator node is emitted by the next GetNodes call.
            builder.Reset();
            var evaluator = new Legacy.Models.BinaryClassificationEvaluator();
            evaluator.Data.VarName = scorerOutput.ScoredData.VarName;

            // Bind each evaluator output only if the expanded node maps it to a variable name.
            var evaluatorOutput = new Legacy.Models.BinaryClassificationEvaluator.Output();
            if (node.OutputMap.TryGetValue("Warnings", out string warningsName))
            {
                evaluatorOutput.Warnings.VarName = warningsName;
            }

            if (node.OutputMap.TryGetValue("OverallMetrics", out string overallMetricsName))
            {
                evaluatorOutput.OverallMetrics.VarName = overallMetricsName;
            }

            if (node.OutputMap.TryGetValue("PerInstanceMetrics", out string perInstanceMetricsName))
            {
                evaluatorOutput.PerInstanceMetrics.VarName = perInstanceMetricsName;
            }

            if (node.OutputMap.TryGetValue("ConfusionMatrix", out string confusionMatrixName))
            {
                evaluatorOutput.ConfusionMatrix.VarName = confusionMatrixName;
            }

            builder.Add(evaluator, evaluatorOutput);
            graphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, builder.GetNodes()));

            // Tag every produced node with the same freshly generated stage id.
            var sharedStageId = Guid.NewGuid().ToString("N");
            foreach (var graphNode in graphNodes)
            {
                graphNode.StageId = sharedStageId;
            }

            return new CommonOutputs.MacroOutput<Output> { Nodes = graphNodes };
        }