/// <summary>
/// Entry point that builds a <see cref="GenerateNumberTransform"/> over <c>input.Data</c> and
/// returns it both as the output data view and wrapped in a <see cref="TransformModel"/>.
/// </summary>
/// <param name="env">Host environment handed to <c>EntryPointUtils.CheckArgsAndCreateHost</c>.</param>
/// <param name="input">Transform arguments; <c>input.Data</c> is the data view the transform is applied to.</param>
/// <returns>A <see cref="CommonOutputs.TransformOutput"/> whose <c>Model</c> and <c>OutputData</c> both refer to the new transform.</returns>
public static CommonOutputs.TransformOutput Generate(IHostEnvironment env, GenerateNumberTransform.Arguments input)
        {
            // The host is registered under the "GenerateNumber" name and is shared by the transform and the model.
            var host      = EntryPointUtils.CheckArgsAndCreateHost(env, "GenerateNumber", input);
            var transform = new GenerateNumberTransform(host, input, input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model      = new TransformModel(host, transform, input.Data);
            result.OutputData = transform;
            return result;
        }
Exemple #2
0
        /// <summary>
        /// Shapes the evaluator's per-instance data view for output: keeps the fold index column when it
        /// is present, guarantees an "Instance" column, keeps the weight column (if any) plus the columns
        /// returned by <c>GetPerInstanceColumnsToSave</c>, and hands the result to
        /// <c>GetPerInstanceMetricsCore</c>.
        /// </summary>
        /// <param name="perInst">Role-mapped per-instance data produced by the evaluator.</param>
        /// <returns>The value returned by <c>GetPerInstanceMetricsCore</c> for the reshaped view.</returns>
        private IDataView WrapPerInstance(RoleMappedData perInst)
        {
            IDataView view  = perInst.Data;
            var       roles = perInst.Schema;

            // Columns to duplicate under a new name, and the final set of columns to keep.
            var renames = new List <(string Source, string Name)>();
            var kept    = new List <string>();

            // Cross-validation results may carry a fold id column; only its presence matters here.
            if (roles.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out int _))
            {
                kept.Add(MetricKinds.ColumnNames.FoldIndex);
            }

            // An "Instance" column is always emitted: copied from the name column when there is one,
            // otherwise produced by a counter-based GenerateNumberTransform.
            if (roles.Name != null)
            {
                renames.Add((roles.Name.Name, "Instance"));
            }
            else
            {
                var counterArgs = new GenerateNumberTransform.Arguments
                {
                    Column = new[] { new GenerateNumberTransform.Column {
                                         Name = "Instance"
                                     } },
                    UseCounter = true
                };
                view = new GenerateNumberTransform(Host, counterArgs, view);
            }
            kept.Add("Instance");

            // The weight column is kept whenever one is mapped.
            if (roles.Weight != null)
            {
                kept.Add(roles.Weight.Name);
            }

            // Remaining columns are chosen by the evaluator.
            kept.AddRange(GetPerInstanceColumnsToSave(roles));

            view = new ColumnsCopyingTransformer(Host, renames.ToArray()).Transform(view);
            view = ColumnSelectingTransformer.CreateKeep(Host, view, kept.ToArray());
            return GetPerInstanceMetricsCore(view, roles);
        }
        /// <summary>
        /// Chooses the column used to split the data into folds and, when necessary, appends a transform
        /// that produces such a column.
        /// </summary>
        /// <remarks>
        /// Resolution order: the explicitly configured stratification column; otherwise the group column
        /// when its type reports <c>KeyCount &gt; 0</c>; otherwise a freshly generated column. An existing
        /// column whose type is rejected by <c>RangeFilter.IsValidRangeFilterColumnType</c> is hashed into
        /// a new column.
        /// </remarks>
        /// <param name="ch">Channel used for messages and for creating the user-argument exception.</param>
        /// <param name="input">Data view whose schema is inspected.</param>
        /// <param name="output">
        /// Set to <paramref name="input"/>, or to a transform over it when a column had to be generated or hashed.
        /// </param>
        /// <returns>The name of the column to split on.</returns>
        private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
        {
            // The stratification column and/or group column, if they exist at all, must be present at this point.
            var schema = input.Schema;

            output = input;

            string stratificationColumn = null;

            if (string.IsNullOrWhiteSpace(Args.StratificationColumn))
            {
                // No explicit stratification column: fall back to the group column, but only when it is
                // a key type with known cardinality.
                string groupName = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
                int    groupIndex;
                bool   groupIsKnownKey = groupName != null &&
                                         schema.TryGetColumnIndex(groupName, out groupIndex) &&
                                         schema.GetColumnType(groupIndex).KeyCount > 0;
                if (groupIsKnownKey)
                {
                    stratificationColumn = groupName;
                }
            }
            else
            {
                stratificationColumn = Args.StratificationColumn;
            }

            if (string.IsNullOrEmpty(stratificationColumn))
            {
                // Nothing usable exists: generate a column under a name not already in the schema.
                string generatedName = "StratificationColumn";
                int    takenIndex;
                for (int suffix = 1; input.Schema.TryGetColumnIndex(generatedName, out takenIndex); suffix++)
                {
                    generatedName = string.Format("StratificationColumn_{0:000}", suffix);
                }

                var generatorArgs = new GenerateNumberTransform.Arguments();
                generatorArgs.Column = new[] { new GenerateNumberTransform.Column {
                                                   Name = generatedName
                                               } };
                output = new GenerateNumberTransform(Host, generatorArgs, input);
                return(generatedName);
            }

            int stratIndex;
            if (!input.Schema.TryGetColumnIndex(stratificationColumn, out stratIndex))
            {
                throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
            }

            if (RangeFilter.IsValidRangeFilterColumnType(ch, input.Schema.GetColumnType(stratIndex)))
            {
                // The column can be used as-is.
                return(stratificationColumn);
            }

            // The column type is not accepted by the range filter: hash it into a new, uniquely named column.
            ch.Info("Hashing the stratification column");
            string sourceName = stratificationColumn;
            string hashedName = sourceName;
            int    clashIndex;
            for (int attempt = 1; input.Schema.TryGetColumnIndex(hashedName, out clashIndex); attempt++)
            {
                hashedName = string.Format("{0}_{1:000}", sourceName, attempt);
            }

            var hashArgs = new HashTransform.Arguments();
            hashArgs.Column = new[] { new HashTransform.Column {
                                          Source = sourceName, Name = hashedName
                                      } };
            hashArgs.HashBits = 30;
            output            = new HashTransform(Host, hashArgs, input);
            return(hashedName);
        }
        /// <summary>
        /// Runs the cross-validation command: builds the data pipeline, picks the split column, runs the
        /// folds through <c>FoldHelper</c>, prints per-fold and overall metrics, sends telemetry, and
        /// optionally saves the per-instance results.
        /// </summary>
        /// <param name="ch">Channel used for tracing, warnings and exceptions.</param>
        /// <param name="cmd">Command string forwarded to <c>FoldHelper</c>.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);

            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing data pipeline");
            IDataLoader loader = CreateRawLoader();

            // Decide once whether per-instance results are wanted. The same predicate gates the name
            // column injection, the FoldHelper flag and the final save, so a whitespace-only file name
            // can no longer trigger per-instance work whose output is never written.
            bool savePerInstance = !string.IsNullOrWhiteSpace(Args.OutputDataFile);

            // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
            var preXf = Args.PreTransform;

            if (savePerInstance)
            {
                string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
                if (name == null)
                {
                    var args = new GenerateNumberTransform.Arguments();
                    args.Column = new[] { new GenerateNumberTransform.Column()
                                          {
                                              Name = DefaultColumnNames.Name
                                          }, };
                    args.UseCounter = true;
                    // Serialize the arguments as settings (relative to a default instance) so the
                    // transform can be appended as a sub-component.
                    var options = CmdParser.GetSettings(ch, args, new GenerateNumberTransform.Arguments());
                    preXf = preXf.Concat(
                        new[]
                    {
                        new KeyValuePair <string, SubComponent <IDataTransform, SignatureDataTransform> >(
                            "", new SubComponent <IDataTransform, SignatureDataTransform>(
                                GenerateNumberTransform.LoadName, options))
                    }).ToArray();
                }
            }
            loader = CompositeDataLoader.Create(Host, loader, preXf);

            ch.Trace("Binding label and features columns");

            IDataView pipe = loader;
            var       stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
            var       scorer    = Args.Scorer;
            var       evaluator = Args.Evaluator;

            Func <IDataView> validDataCreator = null;

            if (Args.ValidationFile != null)
            {
                validDataCreator =
                    () =>
                {
                    // Fork the command.
                    var impl = new CrossValidationCommand(this);
                    return(impl.CreateRawLoader(dataFile: Args.ValidationFile));
                };
            }

            FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
                                             Args, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
                                             validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, savePerInstance);
            var tasks = fold.GetCrossValidationTasks();

            // When no evaluator was specified, infer one from the score schema of the first fold.
            if (!evaluator.IsGood())
            {
                evaluator = EvaluateUtils.GetEvaluatorType(ch, tasks[0].Result.ScoreSchema);
            }
            var eval = evaluator.CreateInstance(Host);

            // Print confusion matrix and fold results for each fold.
            for (int i = 0; i < tasks.Length; i++)
            {
                var dict = tasks[i].Result.Metrics;
                MetricWriter.PrintWarnings(ch, dict);
                eval.PrintFoldResults(ch, dict);
            }

            // Every fold has been awaited by the loop above; collect the metrics once and reuse the array.
            Dictionary <string, IDataView>[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray();

            // Print the overall results.
            if (!TryGetOverallMetrics(metricValues, out var overallList))
            {
                throw ch.Except("No overall metrics found");
            }

            var overall = eval.GetOverallResults(overallList.ToArray());

            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds);
            eval.PrintAdditionalMetrics(ch, metricValues);
            SendTelemetryMetric(metricValues);

            // Save the per-instance results.
            if (savePerInstance)
            {
                var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, Args.CollateMetrics,
                                                                                Args.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
                if (variableSizeVectorColumnNames.Length > 0)
                {
                    ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.",
                               string.Join(", ", variableSizeVectorColumnNames));
                }
                if (Args.CollateMetrics)
                {
                    // Collated output is expected to be a single data view.
                    ch.Assert(perInstance.Length == 1);
                    MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]);
                }
                else
                {
                    // One file per fold, named via ConstructPerFoldName.
                    int i = 0;
                    foreach (var idv in perInstance)
                    {
                        MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, i), idv);
                        i++;
                    }
                }
            }
        }
        /// <summary>
        /// Runs the cross-validation command: builds the data pipeline, picks the split column, runs the
        /// folds through <c>FoldHelper</c>, prints per-fold and overall results, sends telemetry, and
        /// optionally saves the per-instance results (with a fold index column when requested).
        /// </summary>
        /// <param name="ch">Channel used for tracing, warnings and exceptions.</param>
        /// <param name="cmd">Command string forwarded to <c>FoldHelper</c>.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);

            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing data pipeline");
            IDataLoader loader = CreateRawLoader();

            // Decide once whether per-instance results are wanted. The same predicate gates the name
            // column injection, the FoldHelper flag and the final save, so a whitespace-only file name
            // can no longer trigger per-instance work whose output is never written.
            bool savePerInstance = !string.IsNullOrWhiteSpace(Args.OutputDataFile);

            // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
            var preXf = Args.PreTransform;

            if (savePerInstance)
            {
                string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
                if (name == null)
                {
                    var args = new GenerateNumberTransform.Arguments();
                    args.Column = new[] { new GenerateNumberTransform.Column()
                                          {
                                              Name = DefaultColumnNames.Name
                                          }, };
                    args.UseCounter = true;
                    // Serialize the arguments as settings (relative to a default instance) so the
                    // transform can be appended as a sub-component.
                    var options = CmdParser.GetSettings(ch, args, new GenerateNumberTransform.Arguments());
                    preXf = preXf.Concat(
                        new[]
                    {
                        new KeyValuePair <string, SubComponent <IDataTransform, SignatureDataTransform> >(
                            "", new SubComponent <IDataTransform, SignatureDataTransform>(
                                GenerateNumberTransform.LoadName, options))
                    }).ToArray();
                }
            }
            loader = CompositeDataLoader.Create(Host, loader, preXf);

            ch.Trace("Binding label and features columns");

            IDataView pipe = loader;
            var       stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
            var       scorer    = Args.Scorer;
            var       evaluator = Args.Evaluator;

            Func <IDataView> validDataCreator = null;

            if (Args.ValidationFile != null)
            {
                validDataCreator =
                    () =>
                {
                    // Fork the command.
                    var impl = new CrossValidationCommand(this);
                    return(impl.CreateRawLoader(dataFile: Args.ValidationFile));
                };
            }

            FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
                                             Args, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
                                             validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, savePerInstance);
            var tasks = fold.GetCrossValidationTasks();

            // When no evaluator was specified, infer one from the score schema of the first fold.
            if (!evaluator.IsGood())
            {
                evaluator = EvaluateUtils.GetEvaluatorType(ch, tasks[0].Result.ScoreSchema);
            }
            var eval = evaluator.CreateInstance(Host);

            // Print confusion matrix and fold results for each fold.
            for (int i = 0; i < tasks.Length; i++)
            {
                var dict = tasks[i].Result.Metrics;
                MetricWriter.PrintWarnings(ch, dict);
                eval.PrintFoldResults(ch, dict);
            }

            // Every fold has been awaited by the loop above; collect the metrics once and reuse the array.
            Dictionary <string, IDataView>[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray();

            // Print the overall results.
            eval.PrintOverallResults(ch, Args.SummaryFilename, metricValues);
            SendTelemetryMetric(metricValues);

            // Save the per-instance results.
            if (savePerInstance)
            {
                Func <Task <FoldHelper.FoldResult>, int, IDataView> getPerInstance =
                    (task, i) =>
                {
                    if (!Args.OutputExampleFoldIndex)
                    {
                        return(task.Result.PerInstanceResults);
                    }

                    // If the fold index is requested, add a column containing it. We use the first column in the data view
                    // as an input column to the LambdaColumnMapper, because it must have an input.
                    var inputColName = task.Result.PerInstanceResults.Schema.GetColumnName(0);
                    var inputColType = task.Result.PerInstanceResults.Schema.GetColumnType(0);
                    // The fold index passed on is 1-based (i + 1), with Args.NumFolds passed alongside it.
                    return(Utils.MarshalInvoke(EvaluateUtils.AddKeyColumn <int>, inputColType.RawType, Host,
                                               task.Result.PerInstanceResults, inputColName, MetricKinds.ColumnNames.FoldIndex,
                                               inputColType, Args.NumFolds, i + 1, "FoldIndex", default(ValueGetter <VBuffer <DvText> >)));
                };

                var foldDataViews = tasks.Select(getPerInstance).ToArray();
                if (Args.CollateMetrics)
                {
                    // Collated output: append all folds into one data view and save it to a single file.
                    var perInst = AppendPerInstanceDataViews(foldDataViews, ch);
                    MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInst);
                }
                else
                {
                    // One file per fold, named via ConstructPerFoldName.
                    int i = 0;
                    foreach (var idv in foldDataViews)
                    {
                        MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, i), idv);
                        i++;
                    }
                }
            }
        }