/// <summary>
/// Expands a cross-validation macro node into the list of entry point nodes that implement it.
/// </summary>
/// <remarks>
/// The generated nodes are, in order:
/// <list type="number">
/// <item><description>One "Models.CrossValidatorDatasetSplitter" node that splits the input data into
/// <c>input.NumFolds</c> train/test pairs.</description></item>
/// <item><description>For every fold, a "Models.TrainTestEvaluator" node running a renamed copy of
/// <c>input.Nodes</c>, preceded by a "Transforms.TwoHeterogeneousModelCombiner" node when a
/// transform model input variable is bound.</description></item>
/// <item><description>The nodes added by <c>MacroUtils</c> that gather the per-fold predictor models and
/// data views into arrays.</description></item>
/// <item><description>One "Models.CrossValidationResultsCombiner" node that writes the macro's
/// Warnings, OverallMetrics, PerInstanceMetrics and (for binary and multi-class trainer kinds only)
/// ConfusionMatrix outputs.</description></item>
/// </list>
/// </remarks>
/// <param name="env">Host environment used for argument validation and node creation.</param>
/// <param name="input">Macro arguments (sub-graph nodes, fold count, column names, trainer kind, ...).</param>
/// <param name="node">The macro node being expanded; supplies the run context and the outer
/// input/output variable names.</param>
/// <returns>A macro output whose <c>Nodes</c> is the generated list of entry point nodes.</returns>
public static CommonOutputs.MacroOutput<Output> CrossValidate(
    IHostEnvironment env,
    Arguments input,
    EntryPointNode node)
{
    env.CheckValue(input, nameof(input));
    // node is dereferenced immediately below, so validate it alongside input.
    env.CheckValue(node, nameof(node));

    // Registers a single "parameter name -> variable" binding in a pair of input maps.
    void Bind(
        Dictionary<string, List<ParameterBinding>> bindingMap,
        Dictionary<ParameterBinding, VariableBinding> variableMap,
        string parameterName,
        VariableBinding variable)
    {
        var parameter = new SimpleParameterBinding(parameterName);
        bindingMap.Add(parameterName, new List<ParameterBinding> { parameter });
        variableMap.Add(parameter, variable);
    }

    // This will be the final resulting list of nodes that is returned from the macro.
    var subGraphNodes = new List<EntryPointNode>();

    // The input transform model, if one was supplied.
    VariableBinding transformModelVarName = null;
    if (input.TransformModel != null)
    {
        transformModelVarName = node.GetInputVariable(nameof(input.TransformModel));
    }

    // Split the input data into folds.
    var splitArgs = new CVSplit.Input
    {
        NumFolds = input.NumFolds,
        StratificationColumn = input.StratificationColumn
    };
    var inputBindingMap = new Dictionary<string, List<ParameterBinding>>();
    var inputMap = new Dictionary<ParameterBinding, VariableBinding>();
    Bind(inputBindingMap, inputMap, nameof(splitArgs.Data), node.GetInputVariable(nameof(splitArgs.Data)));

    var splitOutputTrainData = new ArrayVar<IDataView>();
    var splitOutputTestData = new ArrayVar<IDataView>();
    var outputMap = new Dictionary<string, string>
    {
        { nameof(CVSplit.Output.TrainData), splitOutputTrainData.VarName },
        { nameof(CVSplit.Output.TestData), splitOutputTestData.VarName }
    };
    var splitNode = EntryPointNode.Create(env, "Models.CrossValidatorDatasetSplitter", splitArgs,
        node.Context, inputBindingMap, inputMap, outputMap);
    subGraphNodes.Add(splitNode);

    // Per-fold output variables; slot k is filled in by iteration k of the loop below.
    var predModelVars = new Var<PredictorModel>[input.NumFolds];
    var warningsVars = new Var<IDataView>[input.NumFolds];
    var overallMetricsVars = new Var<IDataView>[input.NumFolds];
    var instanceMetricsVars = new Var<IDataView>[input.NumFolds];
    var confusionMatrixVars = new Var<IDataView>[input.NumFolds];

    // Instantiate the subgraph for each fold.
    for (int k = 0; k < input.NumFolds; k++)
    {
        // Parse the nodes in input.Nodes into a temporary run context.
        var context = new RunContext(env);
        var graph = EntryPointNode.ValidateNodes(env, context, input.Nodes);

        // Rename all the variables such that they don't conflict with the ones in the outer run context.
        var mapping = new Dictionary<string, string>();
        foreach (var entryPointNode in graph)
        {
            entryPointNode.RenameAllVariables(mapping);
        }

        // Instantiate a TrainTest entry point for this fold.
        var args = new TrainTestMacro.Arguments
        {
            Nodes = new JArray(graph.Select(n => n.ToJson()).ToArray()),
            TransformModel = null,
            LabelColumn = input.LabelColumn,
            GroupColumn = input.GroupColumn,
            WeightColumn = input.WeightColumn,
            NameColumn = input.NameColumn
        };

        if (transformModelVarName != null)
        {
            args.TransformModel = new Var<TransformModel> { VarName = transformModelVarName.VariableName };
        }

        // The sub-graph's data input and model output are referenced by their renamed variables.
        args.Inputs.Data = new Var<IDataView> { VarName = mapping[input.Inputs.Data.VarName] };
        args.Outputs.PredictorModel = new Var<PredictorModel> { VarName = mapping[input.Outputs.PredictorModel.VarName] };

        // Set train/test trainer kind to match.
        args.Kind = input.Kind;

        // Set the input bindings for the TrainTest entry point: fold k of the splitter outputs.
        inputBindingMap = new Dictionary<string, List<ParameterBinding>>();
        inputMap = new Dictionary<ParameterBinding, VariableBinding>();
        Bind(inputBindingMap, inputMap, nameof(args.TrainingData),
            new ArrayIndexVariableBinding(splitOutputTrainData.VarName, k));
        Bind(inputBindingMap, inputMap, nameof(args.TestingData),
            new ArrayIndexVariableBinding(splitOutputTestData.VarName, k));

        outputMap = new Dictionary<string, string>();
        var predModelVar = new Var<PredictorModel>();
        outputMap.Add(nameof(TrainTestMacro.Output.PredictorModel), predModelVar.VarName);
        predModelVars[k] = predModelVar;

        if (transformModelVarName != null && transformModelVarName.VariableName != null)
        {
            // Combine the outer transform model with this fold's predictor model.
            //
            // NOTE(review): the input and output maps are replaced here and are not restored
            // afterwards, so the TrainTestEvaluator node created at the end of this iteration is
            // built from the combiner's maps (no TrainingData/TestingData bindings, and
            // PredictorModel mapped to the combiner's output variable). The original wiring is
            // preserved as-is; confirm whether this is intended.
            var combineModelsArgs = new ModelOperations.SimplePredictorModelInput();
            inputBindingMap = new Dictionary<string, List<ParameterBinding>>();
            inputMap = new Dictionary<ParameterBinding, VariableBinding>();
            Bind(inputBindingMap, inputMap, nameof(combineModelsArgs.TransformModel),
                new SimpleVariableBinding(transformModelVarName.VariableName));
            Bind(inputBindingMap, inputMap, nameof(combineModelsArgs.PredictorModel),
                new SimpleVariableBinding(predModelVar.VarName));

            outputMap = new Dictionary<string, string>();
            var combineNodeOutputPredictorModel = new Var<PredictorModel>();
            predModelVars[k] = combineNodeOutputPredictorModel;
            outputMap.Add(nameof(ModelOperations.PredictorModelOutput.PredictorModel), combineNodeOutputPredictorModel.VarName);

            EntryPointNode combineNode = EntryPointNode.Create(env, "Transforms.TwoHeterogeneousModelCombiner",
                combineModelsArgs, node.Context, inputBindingMap, inputMap, outputMap);
            subGraphNodes.Add(combineNode);
        }

        var warningVar = new Var<IDataView>();
        outputMap.Add(nameof(TrainTestMacro.Output.Warnings), warningVar.VarName);
        warningsVars[k] = warningVar;

        var overallMetric = new Var<IDataView>();
        outputMap.Add(nameof(TrainTestMacro.Output.OverallMetrics), overallMetric.VarName);
        overallMetricsVars[k] = overallMetric;

        var instanceMetric = new Var<IDataView>();
        outputMap.Add(nameof(TrainTestMacro.Output.PerInstanceMetrics), instanceMetric.VarName);
        instanceMetricsVars[k] = instanceMetric;

        var confusionMatrix = new Var<IDataView>();
        outputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), confusionMatrix.VarName);
        confusionMatrixVars[k] = confusionMatrix;

        const string trainTestEvaluatorMacroEntryPoint = "Models.TrainTestEvaluator";
        subGraphNodes.Add(EntryPointNode.Create(env, trainTestEvaluatorMacroEntryPoint, args,
            node.Context, inputBindingMap, inputMap, outputMap));
    }

    // Convert the predictor models to an array of predictor models.
    MacroUtils.ConvertIPredictorModelsToArray(env, node.Context, subGraphNodes, predModelVars,
        node.GetOutputVariableName(nameof(Output.PredictorModel)));

    // Convert the warnings, overall, per instance and confusion matrix data views into an array.
    var warningsArrayVar = new ArrayVar<IDataView>();
    var overallArrayVar = new ArrayVar<IDataView>();
    var instanceArrayVar = new ArrayVar<IDataView>();
    ArrayVar<IDataView> confusionMatrixArrayVar = null;
    MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, warningsVars, warningsArrayVar.VarName);
    MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, overallMetricsVars, overallArrayVar.VarName);
    MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, instanceMetricsVars, instanceArrayVar.VarName);

    // The confusion matrix is only gathered for binary and multi-class classifier trainers.
    if (input.Kind == MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer ||
        input.Kind == MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer)
    {
        confusionMatrixArrayVar = new ArrayVar<IDataView>();
        MacroUtils.ConvertIdataViewsToArray(env, node.Context, subGraphNodes, confusionMatrixVars, confusionMatrixArrayVar.VarName);
    }

    var combineArgs = new CombineMetricsInput
    {
        Kind = input.Kind,
        LabelColumn = input.LabelColumn,
        WeightColumn = input.WeightColumn,
        GroupColumn = input.GroupColumn,
        NameColumn = input.NameColumn
    };

    // Set the input bindings for the CombineMetrics entry point.
    var combineInputBindingMap = new Dictionary<string, List<ParameterBinding>>();
    var combineInputMap = new Dictionary<ParameterBinding, VariableBinding>();
    Bind(combineInputBindingMap, combineInputMap, nameof(combineArgs.Warnings),
        new SimpleVariableBinding(warningsArrayVar.VarName));
    Bind(combineInputBindingMap, combineInputMap, nameof(combineArgs.OverallMetrics),
        new SimpleVariableBinding(overallArrayVar.VarName));
    Bind(combineInputBindingMap, combineInputMap, nameof(combineArgs.PerInstanceMetrics),
        new SimpleVariableBinding(instanceArrayVar.VarName));
    if (confusionMatrixArrayVar != null)
    {
        Bind(combineInputBindingMap, combineInputMap, nameof(combineArgs.ConfusionMatrix),
            new SimpleVariableBinding(confusionMatrixArrayVar.VarName));
    }

    // Route the combiner's outputs to the macro node's own output variables.
    var combineOutputMap = new Dictionary<string, string>();

    var combineWarningVar = new Var<IDataView> { VarName = node.GetOutputVariableName(nameof(Output.Warnings)) };
    combineOutputMap.Add(nameof(Output.Warnings), combineWarningVar.VarName);

    var combineOverallMetric = new Var<IDataView> { VarName = node.GetOutputVariableName(nameof(Output.OverallMetrics)) };
    combineOutputMap.Add(nameof(Output.OverallMetrics), combineOverallMetric.VarName);

    var combineInstanceMetric = new Var<IDataView> { VarName = node.GetOutputVariableName(nameof(Output.PerInstanceMetrics)) };
    combineOutputMap.Add(nameof(Output.PerInstanceMetrics), combineInstanceMetric.VarName);

    if (confusionMatrixArrayVar != null)
    {
        var combineConfusionMatrix = new Var<IDataView> { VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix)) };
        // Keyed via this macro's Output type, like the other entries (same resulting string).
        combineOutputMap.Add(nameof(Output.ConfusionMatrix), combineConfusionMatrix.VarName);
    }

    var combineMetricsNode = EntryPointNode.Create(env, "Models.CrossValidationResultsCombiner", combineArgs,
        node.Context, combineInputBindingMap, combineInputMap, combineOutputMap);
    subGraphNodes.Add(combineMetricsNode);

    return new CommonOutputs.MacroOutput<Output> { Nodes = subGraphNodes };
}
/// <summary>
/// Expands a one-versus-all macro node into the list of entry point nodes that implement it.
/// </summary>
/// <remarks>
/// One sub-graph is produced per class by <c>ProcessClass</c>; the resulting predictor models are
/// gathered into an array and fed, together with the macro's training data input, to a
/// "Models.OvaModelCombiner" node that writes the macro's PredictorModel output.
/// </remarks>
/// <param name="env">Host environment used for argument validation and node creation.</param>
/// <param name="input">Macro arguments (sub-graph nodes, column names, combiner options).</param>
/// <param name="node">The macro node being expanded; supplies the run context and the outer
/// input/output variable names.</param>
/// <returns>A macro output whose <c>Nodes</c> is the generated list of entry point nodes.</returns>
public static CommonOutputs.MacroOutput<Output> OneVersusAll(
    IHostEnvironment env,
    Arguments input,
    EntryPointNode node)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(input, nameof(input));
    env.Assert(input.Nodes.Count > 0);

    var classCount = GetNumberOfClasses(env, input, out var label);
    var perClassModels = new Var<PredictorModel>[classCount];

    // Accumulates every node produced by the macro; this is what gets returned.
    var expandedNodes = new List<EntryPointNode>();

    // Build the sub-graph for each label value and remember its model variable.
    for (int classIndex = 0; classIndex < classCount; classIndex++)
    {
        perClassModels[classIndex] = ProcessClass(env, expandedNodes, classIndex, label, input, node);
    }

    // Gather the per-class predictor models into a single array variable.
    var modelArrayVar = new Var<PredictorModel[]>();
    MacroUtils.ConvertIPredictorModelsToArray(env, node.Context, expandedNodes, perClassModels, modelArrayVar.VarName);

    // The OVA model combiner takes the array of binary predictor models and
    // produces a single multiclass predictor model.
    var combinerArgs = new ModelOperations.CombineOvaPredictorModelsInput
    {
        Caching = input.Caching,
        FeatureColumn = input.FeatureColumn,
        LabelColumn = input.LabelColumn,
        NormalizeFeatures = input.NormalizeFeatures,
        UseProbabilities = input.UseProbabilities
    };

    var modelArrayVariable = new SimpleVariableBinding(modelArrayVar.VarName);
    var modelArrayParameter = new SimpleParameterBinding(nameof(combinerArgs.ModelArray));
    var trainingDataParameter = new SimpleParameterBinding(nameof(combinerArgs.TrainingData));

    var bindingsByParameter = new Dictionary<string, List<ParameterBinding>>
    {
        { nameof(combinerArgs.ModelArray), new List<ParameterBinding> { modelArrayParameter } },
        { nameof(combinerArgs.TrainingData), new List<ParameterBinding> { trainingDataParameter } }
    };
    var variablesByBinding = new Dictionary<ParameterBinding, VariableBinding>
    {
        { modelArrayParameter, modelArrayVariable },
        { trainingDataParameter, node.GetInputVariable(nameof(input.TrainingData)) }
    };

    // The combiner writes straight into the macro node's PredictorModel output variable.
    var outputsByName = new Dictionary<string, string>
    {
        { nameof(Output.PredictorModel), node.GetOutputVariableName(nameof(Output.PredictorModel)) }
    };

    var combinerNode = EntryPointNode.Create(env, "Models.OvaModelCombiner", combinerArgs,
        node.Context, bindingsByParameter, variablesByBinding, outputsByName);
    expandedNodes.Add(combinerNode);

    return new CommonOutputs.MacroOutput<Output> { Nodes = expandedNodes };
}