public static CommonOutputs.MacroOutput <Output> PipelineSweep(
            IHostEnvironment env,
            Arguments input,
            EntryPointNode node)
        {
            env.Check(input.StateArguments != null || input.State is AutoInference.AutoMlMlState,
                      "Must have a valid AutoML State, or pass arguments to create one.");
            env.Check(input.BatchSize > 0, "Batch size must be > 0.");

            // Get the user-defined column roles (if any)
            var dataRoles = GetDataRoles(env, input);

            // If no current state, create object and set data.
            if (input.State == null)
            {
                input.State = input.StateArguments?.CreateComponent(env);

                if (input.State is AutoInference.AutoMlMlState inState)
                {
                    inState.SetTrainTestData(input.TrainingData, input.TestingData);
                }
                else
                {
                    throw env.Except($"Incompatible type. Expecting type {typeof(AutoInference.AutoMlMlState)}, received type {input.State?.GetType()}.");
                }

                var result = node.AddNewVariable("State", input.State);
                node.Context.AddInputVariable(result.Item2, typeof(IMlState));
            }
            var autoMlState = (AutoInference.AutoMlMlState)input.State;

            // The indicators are just so the macro knows those pipelines need to
            // be run before performing the next expansion. If we add them as inputs
            // to the next iteration, the next iteration cannot run until they have
            // their values set. Thus, indicators are needed.
            var pipelineIndicators = new List <Var <IDataView> >();

            var expNodes = new List <EntryPointNode>();

            // Keep versions of the training and testing var names
            var training = new Var <IDataView> {
                VarName = node.GetInputVariable("TrainingData").VariableName
            };
            var testing = new Var <IDataView> {
                VarName = node.GetInputVariable("TestingData").VariableName
            };
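            // Also wrap the AutoML state variable, so the sweep-result extractor and the
            // recursive sweeper node below can reference it by variable name.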
            var amlsVarObj = new Var <IMlState> {
                VarName = node.GetInputVariable(nameof(input.State)).VariableName
            };

            // Make sure the search space is defined. If not, infer it
            // using the default number of transform levels.
            if (!autoMlState.IsSearchSpaceDefined())
            {
                autoMlState.InferSearchSpace(numTransformLevels: 1, dataRoles);
            }

            // Extract performance summaries and assign to previous candidate pipelines.
            foreach (var pipeline in autoMlState.BatchCandidates)
            {
                if (node.Context.TryGetVariable(ExperimentUtils.GenerateOverallMetricVarName(pipeline.UniqueId), out var v) &&
                    node.Context.TryGetVariable(AutoMlUtils.GenerateOverallTrainingMetricVarName(pipeline.UniqueId), out var v2))
                {
                    pipeline.PerformanceSummary = AutoMlUtils.ExtractRunSummary(env, (IDataView)v.Value, autoMlState.Metric.Name, (IDataView)v2.Value);
                    autoMlState.AddEvaluated(pipeline);
                }
            }

            node.OutputMap.TryGetValue("Results", out string outDvName);
            var outDvVar = new Var <IDataView>()
            {
                VarName = outDvName
            };

            node.OutputMap.TryGetValue("State", out string outStateName);
            var outStateVar = new Var <IMlState>()
            {
                VarName = outStateName
            };

            // Get next set of candidates.
            var candidatePipelines = autoMlState.GetNextCandidates(input.BatchSize);

            // Check if termination condition was met, i.e. no more candidates were returned.
            // If so, end expansion and add a node to extract the sweep result.
            if (candidatePipelines == null || candidatePipelines.Length == 0)
            {
                // Add a node to extract the sweep result.
                var resultSubgraph = new Experiment(env);
                var resultNode     = new Microsoft.ML.Legacy.Models.SweepResultExtractor()
                {
                    State = amlsVarObj
                };
                var resultOutput = new Legacy.Models.SweepResultExtractor.Output()
                {
                    State = outStateVar, Results = outDvVar
                };
                resultSubgraph.Add(resultNode, resultOutput);
                var resultSubgraphNodes = EntryPointNode.ValidateNodes(env, node.Context, resultSubgraph.GetNodes());
                expNodes.AddRange(resultSubgraphNodes);
                return(new CommonOutputs.MacroOutput <Output>()
                {
                    Nodes = expNodes
                });
            }

            // Prep all returned candidates
            foreach (var p in candidatePipelines)
            {
                // Add train test experiments to current graph for candidate pipeline
                var subgraph        = new Experiment(env);
                var trainTestOutput = p.AddAsTrainTest(training, testing, autoMlState.TrainerKind, subgraph, true);

                // Change variable name to reference pipeline ID in output map, context and entrypoint output.
                var uniqueName         = ExperimentUtils.GenerateOverallMetricVarName(p.UniqueId);
                var uniqueNameTraining = AutoMlUtils.GenerateOverallTrainingMetricVarName(p.UniqueId);
                var sgNode             = EntryPointNode.ValidateNodes(env, node.Context,
                                                                      new JArray(subgraph.GetNodes().Last())).Last();
                sgNode.RenameOutputVariable(trainTestOutput.OverallMetrics.VarName, uniqueName, cascadeChanges: true);
                sgNode.RenameOutputVariable(trainTestOutput.TrainingOverallMetrics.VarName, uniqueNameTraining, cascadeChanges: true);
                trainTestOutput.OverallMetrics.VarName         = uniqueName;
                trainTestOutput.TrainingOverallMetrics.VarName = uniqueNameTraining;
                expNodes.Add(sgNode);

                // Store indicators, to pass to next iteration of macro.
                pipelineIndicators.Add(trainTestOutput.OverallMetrics);
            }

            // Add recursive macro node
            var macroSubgraph = new Experiment(env);
            var macroNode     = new Legacy.Models.PipelineSweeper()
            {
                BatchSize        = input.BatchSize,
                CandidateOutputs = new ArrayVar <IDataView>(pipelineIndicators.ToArray()),
                TrainingData     = training,
                TestingData      = testing,
                State            = amlsVarObj
            };
            var output = new Legacy.Models.PipelineSweeper.Output()
            {
                Results = outDvVar, State = outStateVar
            };

            macroSubgraph.Add(macroNode, output);

            var subgraphNodes = EntryPointNode.ValidateNodes(env, node.Context, macroSubgraph.GetNodes());

            expNodes.AddRange(subgraphNodes);

            return(new CommonOutputs.MacroOutput <Output>()
            {
                Nodes = expNodes
            });
        }
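        // In summary: each expansion emits one train/test subgraph per candidate pipeline plus a
        // recursive PipelineSweeper node that consumes their metrics; once no candidates remain,
        // it emits a single SweepResultExtractor node producing the final Results and State.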
Example #2
        /// <summary>
        /// Extract all values of one column of the data view in the form of an <see cref="IEnumerable{T}"/>.
        /// </summary>
        /// <typeparam name="T">The type of the values. This must match the actual column type.</typeparam>
        /// <param name="data">The data view to get the column from.</param>
        /// <param name="env">The current host environment.</param>
        /// <param name="columnName">The name of the column to extract.</param>
        public static IEnumerable <T> GetColumn <T>(this IDataView data, IHostEnvironment env, string columnName)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckNonEmpty(columnName, nameof(columnName));

            if (!data.Schema.TryGetColumnIndex(columnName, out int col))
            {
                throw env.ExceptSchemaMismatch(nameof(columnName), "input", columnName);
            }

            // There are two decisions that we make here:
            // - Is the T an array type?
            //     - If yes, we need to map VBuffer to array and densify.
            //     - If no, this is not needed.
            // - Does T (or the item type of T, if it's an array) equal the data view's column type?
            //     - If this is the same type, we can map directly.
            //     - Otherwise, we need a conversion delegate.

            var colType = data.Schema[col].Type;

            if (colType.RawType == typeof(T))
            {
                // Direct mapping is possible.
                return(GetColumnDirect <T>(data, col));
            }
            else if (typeof(T) == typeof(string) && colType.IsText)
            {
                // Special case of ROM<char> to string conversion.
                Delegate convert = (Func <ReadOnlyMemory <char>, string>)((ReadOnlyMemory <char> txt) => txt.ToString());
                Func <IDataView, int, Func <int, T>, IEnumerable <T> > del = GetColumnConvert;
                var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(typeof(T), colType.RawType);
                return((IEnumerable <T>)(meth.Invoke(null, new object[] { data, col, convert })));
            }
            else if (typeof(T).IsArray)
            {
                // Output is an array type.
                if (!colType.IsVector)
                {
                    throw env.ExceptSchemaMismatch(nameof(columnName), "input", columnName, "vector", "scalar");
                }
                var elementType = typeof(T).GetElementType();
                if (elementType == colType.ItemType.RawType)
                {
                    // Direct mapping of items.
                    Func <IDataView, int, IEnumerable <int[]> > del = GetColumnArrayDirect <int>;
                    var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(elementType);
                    return((IEnumerable <T>)meth.Invoke(null, new object[] { data, col }));
                }
                else if (elementType == typeof(string) && colType.ItemType.IsText)
                {
                    // Conversion of ReadOnlyMemory<char> items to string items.
                    Delegate convert = (Func <ReadOnlyMemory <char>, string>)((ReadOnlyMemory <char> txt) => txt.ToString());
                    Func <IDataView, int, Func <int, long>, IEnumerable <long[]> > del = GetColumnArrayConvert;
                    var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(elementType, colType.ItemType.RawType);
                    return((IEnumerable <T>)meth.Invoke(null, new object[] { data, col, convert }));
                }
                // Fall through to the failure.
            }
            throw env.Except($"Could not map a data view column '{columnName}' of type {colType} to {typeof(T)}.");
        }
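        // Usage sketch (not part of the original listing): reading a scalar float column through
        // the extension above. The method name and the column name "Score" are placeholders; the
        // environment and data view are assumed to be supplied by the caller.
        public static void PrintScoreColumn(IHostEnvironment env, IDataView data)
        {
            // Direct mapping applies because the column's raw type matches float.
            foreach (float value in data.GetColumn<float>(env, "Score"))
                Console.WriteLine(value);
        }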
Example #3
        private static void RunGraphCore(EnvironmentBlock *penv, IHostEnvironment env, string graphStr, int cdata, DataSourceBlock **ppdata)
        {
            Contracts.AssertValue(env);

            var    args = new RunGraphArgs();
            string err  = null;

            if (!CmdParser.ParseArguments(env, graphStr, args, e => err = err ?? e))
            {
                throw env.Except(err);
            }

            // Respect the thread limit imposed by the native caller: take the smaller of the
            // requested parallelism and that limit; if the caller imposes no limit, use the
            // value from the parsed arguments as-is.
            int? maxThreadsAllowed = Math.Min(args.parallel > 0 ? args.parallel.Value : penv->maxThreadsAllowed, penv->maxThreadsAllowed);

            maxThreadsAllowed = penv->maxThreadsAllowed > 0 ? maxThreadsAllowed : args.parallel;
            var host = env.Register("RunGraph", args.randomSeed, null);

            JObject graph;

            try
            {
                graph = JObject.Parse(args.graph);
            }
            catch (JsonReaderException ex)
            {
                throw host.Except(ex, "Failed to parse experiment graph: {0}", ex.Message);
            }

            var runner = new GraphRunner(host, graph["nodes"] as JArray);

            var dvNative = new IDataView[cdata];

            try
            {
                for (int i = 0; i < cdata; i++)
                {
                    dvNative[i] = new NativeDataView(host, ppdata[i]);
                }

                // Setting inputs.
                var jInputs = graph["inputs"] as JObject;
                if (graph["inputs"] != null && jInputs == null)
                {
                    throw host.Except("Unexpected value for 'inputs': {0}", graph["inputs"]);
                }
                int iDv = 0;
                if (jInputs != null)
                {
                    foreach (var kvp in jInputs)
                    {
                        var pathValue = kvp.Value as JValue;
                        if (pathValue == null)
                        {
                            throw host.Except("Invalid value for input: {0}", kvp.Value);
                        }

                        var path    = pathValue.Value <string>();
                        var varName = kvp.Key;
                        var type    = runner.GetPortDataKind(varName);

                        switch (type)
                        {
                        case TlcModule.DataKind.FileHandle:
                            var fh = new SimpleFileHandle(host, path, false, false);
                            runner.SetInput(varName, fh);
                            break;

                        case TlcModule.DataKind.DataView:
                            IDataView dv;
                            if (!string.IsNullOrWhiteSpace(path))
                            {
                                var extension = Path.GetExtension(path);
                                if (extension == ".txt")
                                {
                                    dv = TextLoader.LoadFile(host, new TextLoader.Options(), new MultiFileSource(path));
                                }
                                else
                                {
                                    dv = new BinaryLoader(host, new BinaryLoader.Arguments(), path);
                                }
                            }
                            else
                            {
                                Contracts.Assert(iDv < dvNative.Length);
                                // prefetch all columns
                                dv = dvNative[iDv++];
                                var prefetch = new int[dv.Schema.Count];
                                for (int i = 0; i < prefetch.Length; i++)
                                {
                                    prefetch[i] = i;
                                }
                                dv = new CacheDataView(host, dv, prefetch);
                            }
                            runner.SetInput(varName, dv);
                            break;

                        case TlcModule.DataKind.PredictorModel:
                            PredictorModel pm;
                            if (!string.IsNullOrWhiteSpace(path))
                            {
                                using (var fs = File.OpenRead(path))
                                    pm = new PredictorModelImpl(host, fs);
                            }
                            else
                            {
                                throw host.Except("Model must be loaded from a file");
                            }
                            runner.SetInput(varName, pm);
                            break;

                        case TlcModule.DataKind.TransformModel:
                            TransformModel tm;
                            if (!string.IsNullOrWhiteSpace(path))
                            {
                                using (var fs = File.OpenRead(path))
                                    tm = new TransformModelImpl(host, fs);
                            }
                            else
                            {
                                throw host.Except("Model must be loaded from a file");
                            }
                            runner.SetInput(varName, tm);
                            break;

                        default:
                            throw host.Except("Port type {0} not supported", type);
                        }
                    }
                }
                runner.RunAll();

                // Reading outputs.
                using (var ch = host.Start("Reading outputs"))
                {
                    var jOutputs = graph["outputs"] as JObject;
                    if (jOutputs != null)
                    {
                        foreach (var kvp in jOutputs)
                        {
                            var pathValue = kvp.Value as JValue;
                            if (pathValue == null)
                            {
                                throw host.Except("Invalid value for input: {0}", kvp.Value);
                            }
                            var path    = pathValue.Value <string>();
                            var varName = kvp.Key;
                            var type    = runner.GetPortDataKind(varName);

                            switch (type)
                            {
                            case TlcModule.DataKind.FileHandle:
                                var fh = runner.GetOutput <IFileHandle>(varName);
                                throw host.ExceptNotSupp("File handle outputs not yet supported.");

                            case TlcModule.DataKind.DataView:
                                var idv = runner.GetOutput <IDataView>(varName);
                                if (!string.IsNullOrWhiteSpace(path))
                                {
                                    SaveIdvToFile(idv, path, host);
                                }
                                else
                                {
                                    var infos = ProcessColumns(ref idv, args.maxSlots, host);
                                    SendViewToNative(ch, penv, idv, infos);
                                }
                                break;

                            case TlcModule.DataKind.PredictorModel:
                                var pm = runner.GetOutput <PredictorModel>(varName);
                                if (!string.IsNullOrWhiteSpace(path))
                                {
                                    SavePredictorModelToFile(pm, path, host);
                                }
                                else
                                {
                                    throw host.Except("Returning in-memory models is not supported");
                                }
                                break;

                            case TlcModule.DataKind.TransformModel:
                                var tm = runner.GetOutput <TransformModel>(varName);
                                if (!string.IsNullOrWhiteSpace(path))
                                {
                                    using (var fs = File.OpenWrite(path))
                                        tm.Save(host, fs);
                                }
                                else
                                {
                                    throw host.Except("Returning in-memory models is not supported");
                                }
                                break;

                            case TlcModule.DataKind.Array:
                                var objArray = runner.GetOutput <object[]>(varName);
                                if (objArray is PredictorModel[])
                                {
                                    var modelArray = (PredictorModel[])objArray;
                                    // Save each model separately
                                    for (var i = 0; i < modelArray.Length; i++)
                                    {
                                        var modelPath = string.Format(CultureInfo.InvariantCulture, path, i);
                                        SavePredictorModelToFile(modelArray[i], modelPath, host);
                                    }
                                }
                                else
                                {
                                    throw host.Except("DataKind.Array type {0} not supported", objArray.First().GetType());
                                }
                                break;

                            default:
                                throw host.Except("Port type {0} not supported", type);
                            }
                        }
                    }
                }
            }
            finally
            {
                // The raw data view is disposable so it lets go of unmanaged raw pointers before we return.
                for (int i = 0; i < dvNative.Length; i++)
                {
                    var view = dvNative[i];
                    if (view == null)
                    {
                        continue;
                    }
                    host.Assert(view is IDisposable);
                    var disp = (IDisposable)dvNative[i];
                    disp.Dispose();
                }
            }
        }
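        // Illustrative sketch of the graph JSON that RunGraphCore parses (the entry point name,
        // variable names, and paths are placeholders, not taken from the listing above): a "nodes"
        // array plus optional "inputs" and "outputs" maps from graph variable names to file paths;
        // an empty input path means the corresponding data view is passed in memory through ppdata.
        private const string ExampleGraphJson = @"{
            ""nodes"":   [ { ""Name"": ""Transforms.NoOperation"",
                             ""Inputs"": { ""Data"": ""$data"" },
                             ""Outputs"": { ""OutputData"": ""$output"" } } ],
            ""inputs"":  { ""data"": """" },
            ""outputs"": { ""output"": ""output.idv"" }
        }";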
        public static CommonOutputs.MacroOutput <Output> TrainTestBinary(
            IHostEnvironment env,
            Arguments input,
            EntryPointNode node)
        {
            // Parse the subgraph.
            var subGraphRunContext = new RunContext(env);
            var subGraphNodes      = EntryPointNode.ValidateNodes(env, subGraphRunContext, input.Nodes, node.Catalog);

            // Change the subgraph to use the training data as input.
            var varName = input.Inputs.Data.VarName;
            EntryPointVariable variable;

            if (!subGraphRunContext.TryGetVariable(varName, out variable))
            {
                throw env.Except($"Invalid variable name '{varName}'.");
            }
            var trainingVar = node.GetInputVariable("TrainingData");

            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.RenameInputVariable(variable.Name, trainingVar);
            }
            subGraphRunContext.RemoveVariable(variable);

            // Change the subgraph to use the model variable as output.
            varName = input.Outputs.Model.VarName;
            if (!subGraphRunContext.TryGetVariable(varName, out variable))
            {
                throw env.Except($"Invalid variable name '{varName}'.");
            }
            string outputVarName = node.GetOutputVariableName("PredictorModel");

            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.RenameOutputVariable(variable.Name, outputVarName);
            }
            subGraphRunContext.RemoveVariable(variable);

            // Move the variables from the subcontext to the main context.
            node.Context.AddContextVariables(subGraphRunContext);

            // Change all the subgraph nodes to use the main context.
            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.SetContext(node.Context);
            }

            // Add the scoring node.
            var testingVar = node.GetInputVariable("TestingData");
            var exp        = new Experiment(env);
            var scoreNode  = new Legacy.Transforms.DatasetScorer();

            scoreNode.Data.VarName           = testingVar.ToJson();
            scoreNode.PredictorModel.VarName = outputVarName;
            var scoreNodeOutput = exp.Add(scoreNode);

            subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog));

            // Add the evaluator node.
            exp.Reset();
            var evalNode = new Legacy.Models.BinaryClassificationEvaluator();

            evalNode.Data.VarName = scoreNodeOutput.ScoredData.VarName;
            var    evalOutput = new Legacy.Models.BinaryClassificationEvaluator.Output();
            string outVariableName;

            if (node.OutputMap.TryGetValue("Warnings", out outVariableName))
            {
                evalOutput.Warnings.VarName = outVariableName;
            }
            if (node.OutputMap.TryGetValue("OverallMetrics", out outVariableName))
            {
                evalOutput.OverallMetrics.VarName = outVariableName;
            }
            if (node.OutputMap.TryGetValue("PerInstanceMetrics", out outVariableName))
            {
                evalOutput.PerInstanceMetrics.VarName = outVariableName;
            }
            if (node.OutputMap.TryGetValue("ConfusionMatrix", out outVariableName))
            {
                evalOutput.ConfusionMatrix.VarName = outVariableName;
            }
            exp.Add(evalNode, evalOutput);
            subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog));

            var stageId = Guid.NewGuid().ToString("N");
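            // The shared stage ID groups all subgraph nodes into one unit for the graph runner.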

            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.StageId = stageId;
            }

            return(new CommonOutputs.MacroOutput <Output>()
            {
                Nodes = subGraphNodes
            });
        }
Example #5
        public static CommonOutputs.MacroOutput <Output> TrainTest(
            IHostEnvironment env,
            Arguments input,
            EntryPointNode node)
        {
            // Create a default pipeline ID if one is not given.
            input.PipelineId = input.PipelineId ?? Guid.NewGuid().ToString("N");

            // Parse the subgraph.
            var subGraphRunContext = new RunContext(env);
            var subGraphNodes      = EntryPointNode.ValidateNodes(env, subGraphRunContext, input.Nodes, label: input.LabelColumn,
                                                                  input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null,
                                                                  input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null,
                                                                  input.NameColumn.IsExplicit ? input.NameColumn.Value : null);

            // Change the subgraph to use the training data as input.
            var             varName = input.Inputs.Data.VarName;
            VariableBinding transformModelVarName = null;

            if (input.TransformModel != null)
            {
                transformModelVarName = node.GetInputVariable(nameof(input.TransformModel));
            }

            if (!subGraphRunContext.TryGetVariable(varName, out var dataVariable))
            {
                throw env.Except($"Invalid variable name '{varName}'.");
            }
            var trainingVar = node.GetInputVariable(nameof(input.TrainingData));

            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.RenameInputVariable(dataVariable.Name, trainingVar);
            }
            subGraphRunContext.RemoveVariable(dataVariable);

            // Change the subgraph to use the model variable as output.
            varName = input.Outputs.PredictorModel.VarName;
            if (!subGraphRunContext.TryGetVariable(varName, out dataVariable))
            {
                throw env.Except($"Invalid variable name '{varName}'.");
            }

            string predictorModelVarName = node.GetOutputVariableName(nameof(Output.PredictorModel));

            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.RenameOutputVariable(dataVariable.Name, predictorModelVarName);
            }
            subGraphRunContext.RemoveVariable(dataVariable);

            // Move the variables from the subcontext to the main context.
            node.Context.AddContextVariables(subGraphRunContext);

            // Change all the subgraph nodes to use the main context.
            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.SetContext(node.Context);
            }

            // Testing using test data set
            var testingVar = node.GetInputVariable(nameof(input.TestingData));
            //var exp = new Experiment(env);

            Dictionary <string, List <ParameterBinding> >  inputBindingMap;
            Dictionary <ParameterBinding, VariableBinding> inputMap;
            ParameterBinding            paramBinding;
            Dictionary <string, string> outputMap;

            // Combine the predictor model with any potential transform model passed from the outer graph.
            if (transformModelVarName != null && transformModelVarName.VariableName != null)
            {
                var combineArgs = new ModelOperations.SimplePredictorModelInput();
                inputBindingMap = new Dictionary <string, List <ParameterBinding> >();
                inputMap        = new Dictionary <ParameterBinding, VariableBinding>();

                var inputTransformModel = new SimpleVariableBinding(transformModelVarName.VariableName);
                var inputPredictorModel = new SimpleVariableBinding(predictorModelVarName);
                paramBinding = new SimpleParameterBinding(nameof(combineArgs.TransformModel));
                inputBindingMap.Add(nameof(combineArgs.TransformModel), new List <ParameterBinding>()
                {
                    paramBinding
                });
                inputMap.Add(paramBinding, inputTransformModel);
                paramBinding = new SimpleParameterBinding(nameof(combineArgs.PredictorModel));
                inputBindingMap.Add(nameof(combineArgs.PredictorModel), new List <ParameterBinding>()
                {
                    paramBinding
                });
                inputMap.Add(paramBinding, inputPredictorModel);
                outputMap = new Dictionary <string, string>();

                var combineNodeOutputPredictorModel = new Var <PredictorModel>();
                predictorModelVarName = combineNodeOutputPredictorModel.VarName;
                outputMap.Add(nameof(ModelOperations.PredictorModelOutput.PredictorModel), combineNodeOutputPredictorModel.VarName);
                EntryPointNode combineNode = EntryPointNode.Create(env, "Transforms.TwoHeterogeneousModelCombiner", combineArgs,
                                                                   node.Context, inputBindingMap, inputMap, outputMap);
                subGraphNodes.Add(combineNode);
            }

            // Add the scoring node for testing.
            var args = new ScoreModel.Input();

            inputBindingMap = new Dictionary <string, List <ParameterBinding> >();
            inputMap        = new Dictionary <ParameterBinding, VariableBinding>();
            paramBinding    = new SimpleParameterBinding(nameof(args.Data));
            inputBindingMap.Add(nameof(args.Data), new List <ParameterBinding>()
            {
                paramBinding
            });
            inputMap.Add(paramBinding, testingVar);
            var scoreNodeInputPredictorModel = new SimpleVariableBinding(predictorModelVarName);

            paramBinding = new SimpleParameterBinding(nameof(args.PredictorModel));
            inputBindingMap.Add(nameof(args.PredictorModel), new List <ParameterBinding>()
            {
                paramBinding
            });
            inputMap.Add(paramBinding, scoreNodeInputPredictorModel);

            var scoreNodeOutputScoredData       = new Var <IDataView>();
            var scoreNodeOutputScoringTransform = new Var <TransformModel>();

            outputMap = new Dictionary <string, string>();
            outputMap.Add(nameof(ScoreModel.Output.ScoredData), scoreNodeOutputScoredData.VarName);
            outputMap.Add(nameof(ScoreModel.Output.ScoringTransform), scoreNodeOutputScoringTransform.VarName);

            EntryPointNode scoreNode = EntryPointNode.Create(env, "Transforms.DatasetScorer", args,
                                                             node.Context, inputBindingMap, inputMap, outputMap);

            subGraphNodes.Add(scoreNode);
            var evalDataVarName = scoreNodeOutputScoredData.VarName;

            // REVIEW: add similar support for FeatureColumn.
            var settings = new MacroUtils.EvaluatorSettings
            {
                LabelColumn  = input.LabelColumn,
                WeightColumn = input.WeightColumn.IsExplicit ? input.WeightColumn.Value : null,
                GroupColumn  = input.GroupColumn.IsExplicit ? input.GroupColumn.Value : null,
                NameColumn   = input.NameColumn.IsExplicit ? input.NameColumn.Value : null
            };

            if (input.IncludeTrainingMetrics)
            {
                string evalTrainingDataVarName;
                args            = new ScoreModel.Input();
                inputBindingMap = new Dictionary <string, List <ParameterBinding> >();
                inputMap        = new Dictionary <ParameterBinding, VariableBinding>();
                paramBinding    = new SimpleParameterBinding(nameof(args.Data));
                inputBindingMap.Add(nameof(args.Data), new List <ParameterBinding>()
                {
                    paramBinding
                });
                inputMap.Add(paramBinding, trainingVar);
                scoreNodeInputPredictorModel = new SimpleVariableBinding(predictorModelVarName);
                paramBinding = new SimpleParameterBinding(nameof(args.PredictorModel));
                inputBindingMap.Add(nameof(args.PredictorModel), new List <ParameterBinding>()
                {
                    paramBinding
                });
                inputMap.Add(paramBinding, scoreNodeInputPredictorModel);

                scoreNodeOutputScoredData       = new Var <IDataView>();
                scoreNodeOutputScoringTransform = new Var <TransformModel>();
                outputMap = new Dictionary <string, string>();
                outputMap.Add(nameof(ScoreModel.Output.ScoredData), scoreNodeOutputScoredData.VarName);
                outputMap.Add(nameof(ScoreModel.Output.ScoringTransform), scoreNodeOutputScoringTransform.VarName);

                scoreNode = EntryPointNode.Create(env, "Transforms.DatasetScorer", args,
                                                  node.Context, inputBindingMap, inputMap, outputMap);
                subGraphNodes.Add(scoreNode);
                evalTrainingDataVarName = scoreNodeOutputScoredData.VarName;

                // Add the evaluator node for training.
                var evalTrainingArgs = MacroUtils.GetEvaluatorArgs(input.Kind, out var evalTrainingEntryPointName, settings);
                inputBindingMap = new Dictionary <string, List <ParameterBinding> >();
                inputMap        = new Dictionary <ParameterBinding, VariableBinding>();
                var evalTrainingNodeInputData = new SimpleVariableBinding(evalTrainingDataVarName);
                paramBinding = new SimpleParameterBinding(nameof(evalTrainingArgs.Data));
                inputBindingMap.Add(nameof(evalTrainingArgs.Data), new List <ParameterBinding>()
                {
                    paramBinding
                });
                inputMap.Add(paramBinding, evalTrainingNodeInputData);

                outputMap = new Dictionary <string, string>();
                if (node.OutputMap.TryGetValue(nameof(Output.TrainingWarnings), out var outTrainingVariableName))
                {
                    outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.Warnings), outTrainingVariableName);
                }
                if (node.OutputMap.TryGetValue(nameof(Output.TrainingOverallMetrics), out outTrainingVariableName))
                {
                    outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.OverallMetrics), outTrainingVariableName);
                }
                if (node.OutputMap.TryGetValue(nameof(Output.TrainingPerInstanceMetrics), out outTrainingVariableName))
                {
                    outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.PerInstanceMetrics), outTrainingVariableName);
                }
                if (node.OutputMap.TryGetValue(nameof(Output.TrainingConfusionMatrix), out outTrainingVariableName))
                {
                    outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.ConfusionMatrix), outTrainingVariableName);
                }
                EntryPointNode evalTrainingNode = EntryPointNode.Create(env, evalTrainingEntryPointName, evalTrainingArgs, node.Context, inputBindingMap, inputMap, outputMap);
                subGraphNodes.Add(evalTrainingNode);
            }

            // Add the evaluator node for testing.
            var evalArgs = MacroUtils.GetEvaluatorArgs(input.Kind, out var evalEntryPointName, settings);

            inputBindingMap = new Dictionary <string, List <ParameterBinding> >();
            inputMap        = new Dictionary <ParameterBinding, VariableBinding>();
            var evalNodeInputData = new SimpleVariableBinding(evalDataVarName);

            paramBinding = new SimpleParameterBinding(nameof(evalArgs.Data));
            inputBindingMap.Add(nameof(evalArgs.Data), new List <ParameterBinding>()
            {
                paramBinding
            });
            inputMap.Add(paramBinding, evalNodeInputData);

            outputMap = new Dictionary <string, string>();
            if (node.OutputMap.TryGetValue(nameof(Output.Warnings), out var outVariableName))
            {
                outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.Warnings), outVariableName);
            }
            if (node.OutputMap.TryGetValue(nameof(Output.OverallMetrics), out outVariableName))
            {
                outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.OverallMetrics), outVariableName);
            }
            if (node.OutputMap.TryGetValue(nameof(Output.PerInstanceMetrics), out outVariableName))
            {
                outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.PerInstanceMetrics), outVariableName);
            }
            if (node.OutputMap.TryGetValue(nameof(Output.ConfusionMatrix), out outVariableName))
            {
                outputMap.Add(nameof(CommonOutputs.ClassificationEvaluateOutput.ConfusionMatrix), outVariableName);
            }
            EntryPointNode evalNode = EntryPointNode.Create(env, evalEntryPointName, evalArgs, node.Context, inputBindingMap, inputMap, outputMap);

            subGraphNodes.Add(evalNode);

            // Marks as an atomic unit that can be run in
            // a distributed fashion.
            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.StageId = input.PipelineId;
            }

            return(new CommonOutputs.MacroOutput <Output>()
            {
                Nodes = subGraphNodes
            });
        }
        static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
        {
            sourceCtx = input;
            env.CheckValue(args.tag, "Tag cannot be empty.");
            if (TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == args.tag).Any())
            {
                throw env.Except("Tag '{0}' is already used.", args.tag);
            }
            env.CheckValue(args.selectTag, "Selected tag cannot be empty.");

            // Two cases: with no filename, find the view already tagged with 'selectTag' and return
            // it, cross-tagging it with the current input; with a filename, load that file, tag the
            // loaded view with 'selectTag', and return the loaded view.
            if (string.IsNullOrEmpty(args.filename))
            {
                var selected = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == args.selectTag);
                if (!selected.Any())
                {
                    throw env.Except("Unable to find a view to select with tag '{0}'. Did you forget to specify a filename?", args.selectTag);
                }
                var first = selected.First();
                if (selected.Skip(1).Any())
                {
                    throw env.Except("Tag '{0}' is ambiguous, {1} views were found.", args.selectTag, selected.Count());
                }
                var tagged = input as ITaggedDataView;
                if (tagged == null)
                {
                    var ag = new TagViewTransform.Arguments {
                        tag = args.tag
                    };
                    tagged = new TagViewTransform(env, ag, input);
                }
                first.Item2.AddRange(new[] { new Tuple <string, ITaggedDataView>(args.tag, tagged) });
                tagged.AddRange(new[] { new Tuple <string, ITaggedDataView>(args.selectTag, first.Item2) });
#if (DEBUG_TIP)
                long count = DataViewUtils.ComputeRowCount(tagged);
                if (count == 0)
                {
                    throw env.Except("Replaced view is empty.");
                }
                count = DataViewUtils.ComputeRowCount(first.Item2);
                if (count == 0)
                {
                    throw env.Except("Selected view is empty.");
                }
#endif
                var tr = first.Item2 as IDataTransform;
                env.AssertValue(tr);
                return(tr);
            }
            else
            {
                if (!File.Exists(args.filename))
                {
                    throw env.Except("Unable to find file '{0}'.", args.filename);
                }
                var selected = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == args.selectTag);
                if (selected.Any())
                {
                    throw env.Except("Tag '{0}' was already given. It cannot be assigned to the new file.", args.selectTag);
                }
                var loaderArgs   = new BinaryLoader.Arguments();
                var file         = new MultiFileSource(args.filename);
                var loadSettings = ScikitSubComponent <IDataLoader, SignatureDataLoader> .AsSubComponent(args.loaderSettings);

                IDataView loader = loadSettings.CreateInstance(env, file);

                var ag = new TagViewTransform.Arguments {
                    tag = args.selectTag
                };
                var newInput = new TagViewTransform(env, ag, loader);
                var tagged   = input as ITaggedDataView;
                if (tagged == null)
                {
                    ag = new TagViewTransform.Arguments {
                        tag = args.tag
                    };
                    tagged = new TagViewTransform(env, ag, input);
                }

                newInput.AddRange(new[] { new Tuple <string, ITaggedDataView>(args.tag, tagged) });
                tagged.AddRange(new[] { new Tuple <string, ITaggedDataView>(args.selectTag, newInput) });

                var schema = loader.Schema;
                if (schema.Count == 0)
                {
                    throw env.Except("The loaded view '{0}' is empty (empty schema).", args.filename);
                }
                return(newInput);
            }
        }
        IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
        {
            sourceCtx = input;
            Contracts.CheckValue(env, "env");
            env.CheckValue(args, "args");
            env.CheckValue(input, "input");
            env.CheckValue(args.tag, "tag is empty");
            env.CheckValue(args.trainer, "trainer",
                           "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");

            var views = TagHelper.EnumerateTaggedView(true, input).Where(c => c.Item1 == args.tag);

            if (views.Any())
            {
                throw env.Except("Tag '{0}' is already used.", args.tag);
            }

            var host = env.Register("TagTrainOrScoreTransform");

            using (var ch = host.Start("Train"))
            {
                ch.Trace("Constructing trainer");
                var trainerSett = ScikitSubComponent <ITrainer, SignatureTrainer> .AsSubComponent(args.trainer);

                ITrainer trainer    = trainerSett.CreateInstance(host);
                var      customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);

                string             feat;
                string             group;
                var                data       = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), args, out feat, out group);
                ICalibratorTrainer calibrator = args.calibrator == null
                                    ? null
                                    : ScikitSubComponent <ICalibratorTrainer, SignatureCalibrator> .AsSubComponent(args.calibrator).CreateInstance(host);

                // Derive a compact trainer name from the sub-component string: strip braces,
                // spaces, and '=', and encode '+'/'-' as 'Y'/'N'.
                var nameTrainer = args.trainer.ToString().Replace("{", "").Replace("}", "").Replace(" ", "").Replace("=", "").Replace("+", "Y").Replace("-", "N");
                var extTrainer  = new ExtendedTrainer(trainer, nameTrainer);
                _predictor = extTrainer.Train(host, ch, data, null, calibrator, args.maxCalibrationExamples);

                if (!string.IsNullOrEmpty(args.outputModel))
                {
                    ch.Info("Saving model into '{0}'", args.outputModel);
                    using (var fs = File.Create(args.outputModel))
                        TrainUtils.SaveModel(env, ch, fs, _predictor, data);
                    ch.Info("Done.");
                }

                if (_cali != null)
                {
                    throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
                }

                ch.Trace("Scoring");
                if (_args.scorer != null)
                {
                    var mapper   = new SchemaBindablePredictorWrapper(_predictor);
                    var roles    = new RoleMappedSchema(input.Schema, null, feat, group: group);
                    var bound    = (mapper as ISchemaBindableMapper).Bind(_host, roles);
                    var scorPars = ScikitSubComponent <IDataScorerTransform, SignatureDataScorer> .AsSubComponent(_args.scorer);

                    _scorer = scorPars.CreateInstance(_host, input, bound, roles);
                }
                else
                {
                    _scorer = PredictorHelper.CreateDefaultScorer(_host, input, feat, group, _predictor);
                }

                ch.Info("Tagging with tag '{0}'.", args.tag);

                var ar = new TagViewTransform.Arguments {
                    tag = args.tag
                };
                var res = new TagViewTransform(env, ar, _scorer, _predictor);
                return(res);
            }
        }
Example #8
        public static CommonOutputs.MacroOutput <Output> TrainTest(
            IHostEnvironment env,
            Arguments input,
            EntryPointNode node)
        {
            // Create a default pipeline ID if one is not given.
            input.PipelineId = input.PipelineId ?? Guid.NewGuid().ToString("N");

            // Parse the subgraph.
            var subGraphRunContext = new RunContext(env);
            var subGraphNodes      = EntryPointNode.ValidateNodes(env, subGraphRunContext, input.Nodes, node.Catalog);

            // Change the subgraph to use the training data as input.
            var             varName = input.Inputs.Data.VarName;
            VariableBinding transformModelVarName = null;

            if (input.TransformModel != null)
            {
                transformModelVarName = node.GetInputVariable(nameof(input.TransformModel));
            }

            if (!subGraphRunContext.TryGetVariable(varName, out var dataVariable))
            {
                throw env.Except($"Invalid variable name '{varName}'.");
            }
            var trainingVar = node.GetInputVariable(nameof(input.TrainingData));

            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.RenameInputVariable(dataVariable.Name, trainingVar);
            }
            subGraphRunContext.RemoveVariable(dataVariable);

            // Change the subgraph to use the model variable as output.
            varName = input.Outputs.PredictorModel == null ? input.Outputs.TransformModel.VarName : input.Outputs.PredictorModel.VarName;
            if (!subGraphRunContext.TryGetVariable(varName, out dataVariable))
            {
                throw env.Except($"Invalid variable name '{varName}'.");
            }

            string outputVarName = input.Outputs.PredictorModel == null ? node.GetOutputVariableName(nameof(Output.TransformModel)) :
                                       node.GetOutputVariableName(nameof(Output.PredictorModel));

            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.RenameOutputVariable(dataVariable.Name, outputVarName);
            }
            subGraphRunContext.RemoveVariable(dataVariable);

            // Move the variables from the subcontext to the main context.
            node.Context.AddContextVariables(subGraphRunContext);

            // Change all the subgraph nodes to use the main context.
            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.SetContext(node.Context);
            }

            // Testing using test data set
            var testingVar = node.GetInputVariable(nameof(input.TestingData));
            var exp        = new Experiment(env);

            DatasetScorer.Output scoreNodeOutput = null;
            ML.Models.DatasetTransformer.Output datasetTransformNodeOutput = null;
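            // If the subgraph produced only a transform model, apply it to the test data;
            // otherwise score the test data with the predictor model.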
            if (input.Outputs.PredictorModel == null)
            {
                // Combine the predictor model with any potential transform model passed from the outer graph.
                if (transformModelVarName != null && transformModelVarName.VariableName != null)
                {
                    var modelCombine = new ML.Transforms.ModelCombiner
                    {
                        Models = new ArrayVar <ITransformModel>(new Var <ITransformModel>[]
                        {
                            new Var <ITransformModel> { VarName = transformModelVarName.VariableName },
                            new Var <ITransformModel> { VarName = outputVarName }
                        })
                    };

                    var modelCombineOutput = exp.Add(modelCombine);
                    outputVarName = modelCombineOutput.OutputModel.VarName;
                }

                var datasetTransformerNode = new Models.DatasetTransformer
                {
                    Data           = { VarName = testingVar.ToJson() },
                    TransformModel = { VarName = outputVarName }
                };

                datasetTransformNodeOutput = exp.Add(datasetTransformerNode);
            }
            else
            {
                // Combine the predictor model with any potential transform model passed from the outer graph.
                if (transformModelVarName != null && transformModelVarName.VariableName != null)
                {
                    var modelCombine = new TwoHeterogeneousModelCombiner
                    {
                        TransformModel = { VarName = transformModelVarName.VariableName },
                        PredictorModel = { VarName = outputVarName }
                    };

                    var modelCombineOutput = exp.Add(modelCombine);
                    outputVarName = modelCombineOutput.PredictorModel.VarName;
                }

                // Add the scoring node for testing.
                var scoreNode = new DatasetScorer
                {
                    Data           = { VarName = testingVar.ToJson() },
                    PredictorModel = { VarName = outputVarName }
                };

                scoreNodeOutput = exp.Add(scoreNode);
            }

            subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog));

            // Do not double-add previous nodes.
            exp.Reset();

            // REVIEW: we need to extract the proper label column name here to pass to the evaluators.
            // This is where you would add code to do it.
            var settings = new MacroUtils.EvaluatorSettings
            {
                LabelColumn = DefaultColumnNames.Label
            };

            string outVariableName;

            if (input.IncludeTrainingMetrics)
            {
                DatasetScorer.Output scoreNodeTrainingOutput = null;
                ML.Models.DatasetTransformer.Output datasetTransformNodeTrainingOutput = null;
                if (input.Outputs.PredictorModel == null)
                {
                    // Transform the training data (not the test data) for the training metrics.
                    var datasetTransformerNode = new Models.DatasetTransformer
                    {
                        Data           = { VarName = trainingVar.ToJson() },
                        TransformModel = { VarName = outputVarName }
                    };

                    datasetTransformNodeTrainingOutput = exp.Add(datasetTransformerNode);
                }
                else
                {
                    // Add the scoring node for training.
                    var scoreNodeTraining = new DatasetScorer
                    {
                        Data           = { VarName = trainingVar.ToJson() },
                        PredictorModel = { VarName = outputVarName }
                    };
                    scoreNodeTrainingOutput = exp.Add(scoreNodeTraining);
                }

                subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog));

                // Do not double-add previous nodes.
                exp.Reset();

                // Add the evaluator node for training.
                var evalInputOutputTraining = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings);
                var evalNodeTraining        = evalInputOutputTraining.Item1;
                var evalOutputTraining      = evalInputOutputTraining.Item2;
                evalNodeTraining.Data.VarName = input.Outputs.PredictorModel == null ? datasetTransformNodeTrainingOutput.OutputData.VarName :
                                                scoreNodeTrainingOutput.ScoredData.VarName;

                if (node.OutputMap.TryGetValue(nameof(Output.TrainingWarnings), out outVariableName))
                {
                    evalOutputTraining.Warnings.VarName = outVariableName;
                }
                if (node.OutputMap.TryGetValue(nameof(Output.TrainingOverallMetrics), out outVariableName))
                {
                    evalOutputTraining.OverallMetrics.VarName = outVariableName;
                }
                if (node.OutputMap.TryGetValue(nameof(Output.TrainingPerInstanceMetrics), out outVariableName))
                {
                    evalOutputTraining.PerInstanceMetrics.VarName = outVariableName;
                }
                if (node.OutputMap.TryGetValue(nameof(Output.TrainingConfusionMatrix), out outVariableName) &&
                    evalOutputTraining is CommonOutputs.IClassificationEvaluatorOutput eoTraining)
                {
                    eoTraining.ConfusionMatrix.VarName = outVariableName;
                }

                exp.Add(evalNodeTraining, evalOutputTraining);
                subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog));
            }

            // Do not double-add previous nodes.
            exp.Reset();

            // Add the evaluator node for testing.
            var evalInputOutput = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings);
            var evalNode        = evalInputOutput.Item1;
            var evalOutput      = evalInputOutput.Item2;

            evalNode.Data.VarName = input.Outputs.PredictorModel == null ? datasetTransformNodeOutput.OutputData.VarName : scoreNodeOutput.ScoredData.VarName;

            if (node.OutputMap.TryGetValue(nameof(Output.Warnings), out outVariableName))
            {
                evalOutput.Warnings.VarName = outVariableName;
            }
            if (node.OutputMap.TryGetValue(nameof(Output.OverallMetrics), out outVariableName))
            {
                evalOutput.OverallMetrics.VarName = outVariableName;
            }
            if (node.OutputMap.TryGetValue(nameof(Output.PerInstanceMetrics), out outVariableName))
            {
                evalOutput.PerInstanceMetrics.VarName = outVariableName;
            }
            if (node.OutputMap.TryGetValue(nameof(Output.ConfusionMatrix), out outVariableName) &&
                evalOutput is CommonOutputs.IClassificationEvaluatorOutput eo)
            {
                eo.ConfusionMatrix.VarName = outVariableName;
            }

            exp.Add(evalNode, evalOutput);
            subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog));

            // Mark the subgraph nodes as an atomic unit that can be run in
            // a distributed fashion.
            foreach (var subGraphNode in subGraphNodes)
            {
                subGraphNode.StageId = input.PipelineId;
            }

            return(new CommonOutputs.MacroOutput <Output>()
            {
                Nodes = subGraphNodes
            });
        }
Example #9
        public static ScikitOnnxContext ToOnnx(IDataTransform trans, ref string[] inputs, ref string[] outputs,
                                               string name             = null, string producer = "Scikit.ML",
                                               long version            = 0, string domain      = "onnx.ai.ml",
                                               OnnxVersion onnxVersion = OnnxVersion.Stable,
                                               IDataView[] begin       = null, IHostEnvironment host = null)
        {
            if (host == null)
            {
                using (var env = new DelegateEnvironment())
                    return(ToOnnx(trans, ref inputs, ref outputs, name, producer, version, domain, onnxVersion, begin, env));
            }
            if (name == null)
            {
                name = trans.GetType().Name;
            }

            GuessOutputs(trans, ref outputs);
            GuessInputs(trans, ref inputs, begin);

            if (inputs == null || inputs.Length == 0)
            {
                throw host.Except("Inputs cannot be empty.");
            }
            if (outputs == null || outputs.Length == 0)
            {
                throw host.Except("Outputs cannot be empty.");
            }

            var assembly    = System.Reflection.Assembly.GetExecutingAssembly();
            var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
            var ctx         = new ScikitOnnxContext(host, name, producer, versionInfo.FileVersion,
                                                    version, domain, onnxVersion);

            var hasin        = new HashSet <string>(inputs);
            var uniqueVars   = EnumerateVariables(trans, begin).ToArray();
            var mapInputType = new Dictionary <string, ColumnType>();

            foreach (var it in uniqueVars.Where(c => c.position == TagHelper.GraphPositionEnum.first))
            {
                mapInputType[it.variableName] = it.variableType;
            }

            foreach (var col in inputs)
            {
                ctx.AddInputVariable(mapInputType[col], col);
            }

            var views      = TagHelper.EnumerateAllViews(trans, begin);
            var transforms = views.Where(c => (c.Item1 as IDataTransform) != null)
                             .Select(c => c.Item1)
                             .ToArray();

            foreach (var tr in transforms.Reverse())
            {
                var tron = tr as ICanSaveOnnx;
                if (tron == null)
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} cannot be saved in Onnx format.");
                }
                if (!tron.CanSaveOnnx(ctx))
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} cannot be saved in ONNX format.");
                }
                var tron2 = tron as ISaveAsOnnx;
                if (tron2 == null)
                {
                    throw host.ExceptNotSupp($"Transform {tr.GetType()} does not implement ISaveAsOnnx.");
                }
                tron2.SaveAsOnnx(ctx);
            }

            var mapOutputType = new Dictionary <string, ColumnType>();

            foreach (var it in uniqueVars.Where(c => c.position == TagHelper.GraphPositionEnum.last))
            {
                mapOutputType[it.variableName] = it.variableType;
            }

            foreach (var col in outputs)
            {
                var variableName     = ctx.TryGetVariableName(col);
                var trueVariableName = ctx.AddIntermediateVariable(null, col, true);
                ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
                ctx.AddOutputVariable(mapOutputType[col], trueVariableName);
            }

            return(ctx);
        }
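
        // A minimal usage sketch (not part of the original source): it assumes an
        // existing IDataTransform 'trans' and leaves the input/output arrays null
        // so that GuessInputs/GuessOutputs infer them; the method name and the
        // "ExportedPipeline" label are illustrative only.
        public static ScikitOnnxContext ExportToOnnxSketch(IDataTransform trans)
        {
            string[] inputs  = null;   // filled in by GuessInputs inside ToOnnx
            string[] outputs = null;   // filled in by GuessOutputs inside ToOnnx

            // Relies only on the ToOnnx signature defined above; producer, version
            // and domain keep their default values.
            return(ToOnnx(trans, ref inputs, ref outputs,
                          name: "ExportedPipeline",
                          onnxVersion: OnnxVersion.Stable));
        }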
Example #10
        public JArray GetNodes()
        {
            JObject json;

            try
            {
                json = JObject.Parse($"{{'nodes': [{string.Join(",", _jsonNodes)}]}}");
            }
            catch (JsonReaderException ex)
            {
                throw _env.Except(ex, "Failed to parse experiment graph: {0}", ex.Message);
            }

            return(json["nodes"] as JArray);
        }
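
        // A small usage sketch (not part of the original source): it assumes the
        // serialized nodes expose a "Name" field, which is an illustrative
        // assumption about the entry-point JSON layout.
        public void PrintNodeNames()
        {
            foreach (var jnode in GetNodes())
            {
                // Each element is a JToken describing one node of the graph.
                System.Console.WriteLine((string)jnode["Name"]);
            }
        }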
        protected virtual IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx, IPredictor overwritePredictor)
        {
            sourceCtx = input;
            Contracts.CheckValue(env, "env");
            env.CheckValue(args, "args");
            env.CheckValue(input, "input");

            IPredictor predictor;

            if (overwritePredictor == null)
            {
                throw env.Except("No defined predictor.");
            }
            else
            {
                predictor = overwritePredictor;
            }

            // Note: this method both returns a transform and stores the predictor in a class member (side effect).
            _predictor = predictor;

            string feat = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
                                                              "featureColumn", args.featureColumn, DefaultColumnNames.Features);
            int index = SchemaHelper.GetColumnIndex(input.Schema, feat);
            var type  = input.Schema[index].Type;

            if (!type.IsVector() || type.AsVector().ItemType().RawKind() != DataKind.R4)
            {
                throw env.Except("Features must a vector of floats");
            }

            if (args.useProb)
            {
                var valueMapper = predictor as IValueMapperDist;
                if (valueMapper == null)
                {
                    throw env.Except("Predictor must be a IValueMapper.");
                }
                var output = valueMapper.DistType;
                if (output.IsVector())
                {
                    return(CreateTransformValueMapperDist <VBuffer <float>, VBuffer <float>, VBuffer <float> >(valueMapper, feat, args.outputColumn));
                }
                else
                {
                    return(CreateTransformValueMapperDist <VBuffer <float>, VBuffer <float>, float>(valueMapper, feat, args.outputColumn));
                }
            }
            else
            {
                var valueMapper = predictor as IValueMapper;
                if (valueMapper == null)
                {
                    throw env.Except("Predictor must be a IValueMapper.");
                }
                var output = valueMapper.OutputType;
                if (output.IsVector())
                {
                    return(CreateTransformValueMapper <VBuffer <float>, VBuffer <float> >(valueMapper, feat, args.outputColumn));
                }
                else
                {
                    return(CreateTransformValueMapper <VBuffer <float>, float>(valueMapper, feat, args.outputColumn));
                }
            }
        }