/// <summary>
/// Expands a cross-validation macro node into a sub-graph of entry-point nodes:
/// a "Models.CrossValidatorDatasetSplitter" node that splits the input data into
/// <c>input.NumFolds</c> train/test pairs, one "Models.TrainTestEvaluator" node per fold
/// (preceded in the node list by a "Transforms.TwoHeterogeneousModelCombiner" node when an
/// input transform model is bound), nodes that gather the per-fold outputs into arrays, and a
/// final "Models.CrossValidationResultsCombiner" node that writes the macro's outputs.
/// </summary>
/// <param name="env">Host environment used for argument checks and node creation.</param>
/// <param name="input">Macro arguments; <c>input.Nodes</c> is the training sub-graph instantiated once per fold.</param>
/// <param name="node">The macro node being expanded; supplies input/output variable bindings and the run context.</param>
/// <returns>A <see cref="CommonOutputs.MacroOutput{T}"/> whose <c>Nodes</c> are the generated sub-graph nodes.</returns>
public static CommonOutputs.MacroOutput<Output> CrossValidate(
            IHostEnvironment env,
            Arguments input,
            EntryPointNode node)
        {
            env.CheckValue(input, nameof(input));

            // This will be the final resulting list of nodes that is returned from the macro.
            var subGraphNodes = new List<EntryPointNode>();

            // The input transform model, if one is bound to this macro node.
            VariableBinding transformModelVarName = null;
            if (input.TransformModel != null)
            {
                transformModelVarName = node.GetInputVariable(nameof(input.TransformModel));
            }

            // Split the input data into folds.
            var splitArgs = new CVSplit.Input();
            splitArgs.NumFolds = input.NumFolds;
            splitArgs.StratificationColumn = input.StratificationColumn;

            var inputBindingMap = new Dictionary<string, List<ParameterBinding>>();
            var inputMap = new Dictionary<ParameterBinding, VariableBinding>();
            var inputData = node.GetInputVariable(nameof(splitArgs.Data));
            ParameterBinding paramBinding = new SimpleParameterBinding(nameof(splitArgs.Data));
            inputBindingMap.Add(nameof(splitArgs.Data), new List<ParameterBinding>() { paramBinding });
            inputMap.Add(paramBinding, inputData);

            var outputMap = new Dictionary<string, string>();
            var splitOutputTrainData = new ArrayVar<IDataView>();
            var splitOutputTestData = new ArrayVar<IDataView>();
            outputMap.Add(nameof(CVSplit.Output.TrainData), splitOutputTrainData.VarName);
            outputMap.Add(nameof(CVSplit.Output.TestData), splitOutputTestData.VarName);
            var splitNode = EntryPointNode.Create(env, "Models.CrossValidatorDatasetSplitter", splitArgs,
                                                  node.Context, inputBindingMap, inputMap, outputMap);
            subGraphNodes.Add(splitNode);

            // Per-fold output variables, gathered into arrays after the loop.
            var predModelVars = new Var<PredictorModel>[input.NumFolds];
            var warningsVars = new Var<IDataView>[input.NumFolds];
            var overallMetricsVars = new Var<IDataView>[input.NumFolds];
            var instanceMetricsVars = new Var<IDataView>[input.NumFolds];
            var confusionMatrixVars = new Var<IDataView>[input.NumFolds];

            // Instantiate the subgraph for each fold.
            for (int k = 0; k < input.NumFolds; k++)
            {
                // Parse the nodes in input.Nodes into a temporary run context.
                var context = new RunContext(env);
                var graph = EntryPointNode.ValidateNodes(env, context, input.Nodes);

                // Rename all the variables such that they don't conflict with the ones in the outer run context.
                var mapping = new Dictionary<string, string>();
                foreach (var entryPointNode in graph)
                {
                    entryPointNode.RenameAllVariables(mapping);
                }

                // Instantiate a TrainTest entry point for this fold.
                var args = new TrainTestMacro.Arguments
                {
                    Nodes = new JArray(graph.Select(n => n.ToJson()).ToArray()),
                    TransformModel = null,
                    LabelColumn = input.LabelColumn,
                    GroupColumn = input.GroupColumn,
                    WeightColumn = input.WeightColumn,
                    NameColumn = input.NameColumn
                };

                // NOTE(review): the transform model is also forwarded to TrainTestEvaluator here and is
                // combined again below; confirm the evaluator does not apply it a second time.
                if (transformModelVarName != null)
                {
                    args.TransformModel = new Var<TransformModel> { VarName = transformModelVarName.VariableName };
                }

                args.Inputs.Data = new Var<IDataView> { VarName = mapping[input.Inputs.Data.VarName] };
                args.Outputs.PredictorModel = new Var<PredictorModel> { VarName = mapping[input.Outputs.PredictorModel.VarName] };

                // Set train/test trainer kind to match.
                args.Kind = input.Kind;

                // Bind this fold's slice (index k) of the splitter's outputs as the TrainTest inputs.
                // These maps are owned by the TrainTest node only; the optional combiner below uses its own.
                var trainTestInputBindingMap = new Dictionary<string, List<ParameterBinding>>();
                var trainTestInputMap = new Dictionary<ParameterBinding, VariableBinding>();
                var trainingData = new SimpleParameterBinding(nameof(args.TrainingData));
                trainTestInputBindingMap.Add(nameof(args.TrainingData), new List<ParameterBinding> { trainingData });
                trainTestInputMap.Add(trainingData, new ArrayIndexVariableBinding(splitOutputTrainData.VarName, k));
                var testingData = new SimpleParameterBinding(nameof(args.TestingData));
                trainTestInputBindingMap.Add(nameof(args.TestingData), new List<ParameterBinding> { testingData });
                trainTestInputMap.Add(testingData, new ArrayIndexVariableBinding(splitOutputTestData.VarName, k));

                var trainTestOutputMap = new Dictionary<string, string>();
                var predModelVar = new Var<PredictorModel>();
                trainTestOutputMap.Add(nameof(TrainTestMacro.Output.PredictorModel), predModelVar.VarName);
                predModelVars[k] = predModelVar;

                if (transformModelVarName != null && transformModelVarName.VariableName != null)
                {
                    // Prepend the outer transform model to this fold's predictor model. The combiner gets
                    // its own binding maps so the TrainTest node keeps its TrainingData/TestingData inputs
                    // and still produces predModelVar, which is this combiner's input.
                    var combineModelsArgs = new ModelOperations.SimplePredictorModelInput();
                    var combinerInputBindingMap = new Dictionary<string, List<ParameterBinding>>();
                    var combinerInputMap = new Dictionary<ParameterBinding, VariableBinding>();

                    var transformModelParam = new SimpleParameterBinding(nameof(combineModelsArgs.TransformModel));
                    combinerInputBindingMap.Add(nameof(combineModelsArgs.TransformModel), new List<ParameterBinding>() { transformModelParam });
                    combinerInputMap.Add(transformModelParam, new SimpleVariableBinding(transformModelVarName.VariableName));

                    var predictorModelParam = new SimpleParameterBinding(nameof(combineModelsArgs.PredictorModel));
                    combinerInputBindingMap.Add(nameof(combineModelsArgs.PredictorModel), new List<ParameterBinding>() { predictorModelParam });
                    combinerInputMap.Add(predictorModelParam, new SimpleVariableBinding(predModelVar.VarName));

                    // The combined model, not the raw fold model, is what gets reported for this fold.
                    var combinerOutputMap = new Dictionary<string, string>();
                    var combineNodeOutputPredictorModel = new Var<PredictorModel>();
                    predModelVars[k] = combineNodeOutputPredictorModel;
                    combinerOutputMap.Add(nameof(ModelOperations.PredictorModelOutput.PredictorModel), combineNodeOutputPredictorModel.VarName);

                    EntryPointNode combineNode = EntryPointNode.Create(env, "Transforms.TwoHeterogeneousModelCombiner", combineModelsArgs,
                                                                       node.Context, combinerInputBindingMap, combinerInputMap, combinerOutputMap);
                    subGraphNodes.Add(combineNode);
                }

                var warningVar = new Var<IDataView>();
                trainTestOutputMap.Add(nameof(TrainTestMacro.Output.Warnings), warningVar.VarName);
                warningsVars[k] = warningVar;
                var overallMetric = new Var<IDataView>();
                trainTestOutputMap.Add(nameof(TrainTestMacro.Output.OverallMetrics), overallMetric.VarName);
                overallMetricsVars[k] = overallMetric;
                var instanceMetric = new Var<IDataView>();
                trainTestOutputMap.Add(nameof(TrainTestMacro.Output.PerInstanceMetrics), instanceMetric.VarName);
                instanceMetricsVars[k] = instanceMetric;
                var confusionMatrix = new Var<IDataView>();
                trainTestOutputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), confusionMatrix.VarName);
                confusionMatrixVars[k] = confusionMatrix;

                const string trainTestEvaluatorMacroEntryPoint = "Models.TrainTestEvaluator";
                subGraphNodes.Add(EntryPointNode.Create(env, trainTestEvaluatorMacroEntryPoint, args, node.Context,
                                                        trainTestInputBindingMap, trainTestInputMap, trainTestOutputMap));
            }

            // Convert the predictor models to an array of predictor models.
            MacroUtils.ConvertIPredictorModelsToArray(env, node.Context, subGraphNodes, predModelVars, node.GetOutputVariableName(nameof(Output.PredictorModel)));

            // Convert the warnings, overall, per instance and confusion matrix data views into an array.
            var warningsArrayVar = new ArrayVar<IDataView>();
            var overallArrayVar = new ArrayVar<IDataView>();
            var instanceArrayVar = new ArrayVar<IDataView>();
            ArrayVar<IDataView> confusionMatrixArrayVar = null;

            MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, warningsVars, warningsArrayVar.VarName);
            MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, overallMetricsVars, overallArrayVar.VarName);
            MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, instanceMetricsVars, instanceArrayVar.VarName);

            // Confusion matrices are only gathered for binary and multiclass classifier trainers.
            if (input.Kind == MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer ||
                input.Kind == MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer)
            {
                confusionMatrixArrayVar = new ArrayVar<IDataView>();
                MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, confusionMatrixVars, confusionMatrixArrayVar.VarName);
            }

            var combineArgs = new CombineMetricsInput();
            combineArgs.Kind = input.Kind;
            combineArgs.LabelColumn = input.LabelColumn;
            combineArgs.WeightColumn = input.WeightColumn;
            combineArgs.GroupColumn = input.GroupColumn;
            combineArgs.NameColumn = input.NameColumn;

            // Set the input bindings for the CombineMetrics entry point.
            var combineInputBindingMap = new Dictionary<string, List<ParameterBinding>>();
            var combineInputMap = new Dictionary<ParameterBinding, VariableBinding>();

            var warningsArray = new SimpleParameterBinding(nameof(combineArgs.Warnings));
            combineInputBindingMap.Add(nameof(combineArgs.Warnings), new List<ParameterBinding> { warningsArray });
            combineInputMap.Add(warningsArray, new SimpleVariableBinding(warningsArrayVar.VarName));

            var overallArray = new SimpleParameterBinding(nameof(combineArgs.OverallMetrics));
            combineInputBindingMap.Add(nameof(combineArgs.OverallMetrics), new List<ParameterBinding> { overallArray });
            combineInputMap.Add(overallArray, new SimpleVariableBinding(overallArrayVar.VarName));

            var combinePerInstArray = new SimpleParameterBinding(nameof(combineArgs.PerInstanceMetrics));
            combineInputBindingMap.Add(nameof(combineArgs.PerInstanceMetrics), new List<ParameterBinding> { combinePerInstArray });
            combineInputMap.Add(combinePerInstArray, new SimpleVariableBinding(instanceArrayVar.VarName));

            if (confusionMatrixArrayVar != null)
            {
                var combineConfArray = new SimpleParameterBinding(nameof(combineArgs.ConfusionMatrix));
                combineInputBindingMap.Add(nameof(combineArgs.ConfusionMatrix), new List<ParameterBinding> { combineConfArray });
                combineInputMap.Add(combineConfArray, new SimpleVariableBinding(confusionMatrixArrayVar.VarName));
            }

            // The combiner writes directly into this macro node's output variables.
            var combineOutputMap = new Dictionary<string, string>();

            var combineWarningVar = new Var<IDataView>();
            combineWarningVar.VarName = node.GetOutputVariableName(nameof(Output.Warnings));
            combineOutputMap.Add(nameof(Output.Warnings), combineWarningVar.VarName);

            var combineOverallMetric = new Var<IDataView>();
            combineOverallMetric.VarName = node.GetOutputVariableName(nameof(Output.OverallMetrics));
            combineOutputMap.Add(nameof(Output.OverallMetrics), combineOverallMetric.VarName);

            var combineInstanceMetric = new Var<IDataView>();
            combineInstanceMetric.VarName = node.GetOutputVariableName(nameof(Output.PerInstanceMetrics));
            combineOutputMap.Add(nameof(Output.PerInstanceMetrics), combineInstanceMetric.VarName);

            if (confusionMatrixArrayVar != null)
            {
                var combineConfusionMatrix = new Var<IDataView>();
                combineConfusionMatrix.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix));
                combineOutputMap.Add(nameof(Output.ConfusionMatrix), combineConfusionMatrix.VarName);
            }

            var combineMetricsNode = EntryPointNode.Create(env, "Models.CrossValidationResultsCombiner",
                                                           combineArgs, node.Context, combineInputBindingMap, combineInputMap, combineOutputMap);
            subGraphNodes.Add(combineMetricsNode);

            return new CommonOutputs.MacroOutput<Output>()
            {
                Nodes = subGraphNodes
            };
        }
示例#2
0
        /// <summary>
        /// Expands a one-versus-all macro node. For each class index, <c>ProcessClass</c> adds a
        /// per-class sub-graph and returns its predictor-model variable. Those variables are gathered
        /// into an array and fed, together with the macro's training data, to a
        /// "Models.OvaModelCombiner" node. That node writes the macro's PredictorModel output.
        /// </summary>
        /// <param name="env">Host environment; must be non-null.</param>
        /// <param name="input">Macro arguments; must be non-null.</param>
        /// <param name="node">The macro node being expanded; supplies input/output variable bindings and the run context.</param>
        /// <returns>A <see cref="CommonOutputs.MacroOutput{T}"/> holding every node generated here.</returns>
        public static CommonOutputs.MacroOutput<Output> OneVersusAll(
            IHostEnvironment env,
            Arguments input,
            EntryPointNode node)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));
            env.Assert(input.Nodes.Count > 0);

            var classCount = GetNumberOfClasses(env, input, out var label);
            var perClassModels = new Var<PredictorModel>[classCount];

            // Every node this macro emits is accumulated here.
            var nodes = new List<EntryPointNode>();

            // One sub-graph per class index; each contributes one predictor-model variable.
            for (int classIndex = 0; classIndex < classCount; classIndex++)
            {
                perClassModels[classIndex] = ProcessClass(env, nodes, classIndex, label, input, node);
            }

            // Gather the per-class predictor models into a single array variable.
            var modelArrayVar = new Var<PredictorModel[]>();
            MacroUtils.ConvertIPredictorModelsToArray(env, node.Context, nodes, perClassModels, modelArrayVar.VarName);

            // Combiner settings are copied straight from the macro arguments.
            var combinerArgs = new ModelOperations.CombineOvaPredictorModelsInput
            {
                Caching = input.Caching,
                FeatureColumn = input.FeatureColumn,
                LabelColumn = input.LabelColumn,
                NormalizeFeatures = input.NormalizeFeatures,
                UseProbabilities = input.UseProbabilities
            };

            var modelArrayParam = new SimpleParameterBinding(nameof(combinerArgs.ModelArray));
            var trainingDataParam = new SimpleParameterBinding(nameof(combinerArgs.TrainingData));

            var combinerBindings = new Dictionary<string, List<ParameterBinding>>
            {
                { nameof(combinerArgs.ModelArray), new List<ParameterBinding> { modelArrayParam } },
                { nameof(combinerArgs.TrainingData), new List<ParameterBinding> { trainingDataParam } }
            };

            // ModelArray comes from the gathered array; TrainingData is the macro node's own input.
            var combinerInputs = new Dictionary<ParameterBinding, VariableBinding>
            {
                { modelArrayParam, new SimpleVariableBinding(modelArrayVar.VarName) },
                { trainingDataParam, node.GetInputVariable(nameof(input.TrainingData)) }
            };

            // The combined model is written directly to this macro node's PredictorModel output.
            var combinerOutputs = new Dictionary<string, string>
            {
                { nameof(Output.PredictorModel), node.GetOutputVariableName(nameof(Output.PredictorModel)) }
            };

            nodes.Add(EntryPointNode.Create(env, "Models.OvaModelCombiner",
                                            combinerArgs, node.Context, combinerBindings, combinerInputs, combinerOutputs));

            return new CommonOutputs.MacroOutput<Output> { Nodes = nodes };
        }