Exemplo n.º 1
0
        // REVIEW: Need to change the help command to use the provided host environment for output,
        // instead of assuming the console.
        /// <summary>
        /// Initializes the help command from its arguments. Component and kind filters that are
        /// null, empty or whitespace-only are stored as null. When a generator is supplied, it is
        /// created with a "maml.exe ? ..." string built from <c>CmdParser.GetSettings</c>.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Command arguments; must not be null.</param>
        public HelpCommand(IHostEnvironment env, Arguments args)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(args, nameof(args));

            _env = env;

            // Blank filters are normalized to null.
            var component = args.Component;
            _component = string.IsNullOrWhiteSpace(component) ? null : component;

            var kind = args.Kind;
            _kind = string.IsNullOrWhiteSpace(kind) ? null : kind;

            _listKinds = args.ListKinds;
            _allComponents = args.AllComponents;
            _extraAssemblies = args.ExtraAssemblies;

            if (args.Generator == null)
            {
                return;
            }

            // The generator receives the command line that is equivalent to this invocation.
            var commandLine = "maml.exe ? " + CmdParser.GetSettings(env, args, new Arguments());
            _generator = args.Generator.CreateComponent(_env, commandLine);
        }
Exemplo n.º 2
0
        /// <summary>
        /// Builds a <c>BinaryClassifierScorer</c> for <paramref name="model"/> over an empty data view
        /// that has the given schema, using "Label" and <paramref name="featureColumn"/> as roles.
        /// </summary>
        /// <param name="env">Host environment passed to every component created here.</param>
        /// <param name="schema">Schema of the empty data view the scorer is bound to.</param>
        /// <param name="featureColumn">Name of the feature column.</param>
        /// <param name="model">Model to wrap in a schema-bindable mapper.</param>
        /// <param name="args">Scorer arguments; also serialized into the scorer factory settings.</param>
        /// <returns>The constructed scorer.</returns>
        private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string featureColumn, TModel model, BinaryClassifierScorer.Arguments args)
        {
            // Serialize the scorer arguments into a "Binary{...}" component specification.
            string scorerSettings = CmdParser.GetSettings(env, args, new BinaryClassifierScorer.Arguments());
            string componentSpec = "Binary{" + scorerSettings + "}";

            var scorerFactory = CmdParser.CreateComponentFactory(
                typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>),
                typeof(SignatureDataScorer),
                componentSpec);

            var mapper = ScoreUtils.GetSchemaBindableMapper(env, model, scorerFactorySettings: scorerFactory);
            var emptyView = new EmptyDataView(env, schema);
            var roleMapped = new RoleMappedData(emptyView, "Label", featureColumn, opt: true);

            var view = roleMapped.Data;
            var boundMapper = mapper.Bind(env, roleMapped.Schema);
            return new BinaryClassifierScorer(env, args, view, boundMapper, roleMapped.Schema);
        }
        /// <summary>
        /// Runs the "Test" command: opens a channel, initializes the server, logs the equivalent
        /// "maml.exe" command line, sends telemetry, and executes <c>RunCore</c> inside a timer scope.
        /// </summary>
        public override void Run()
        {
            const string command = "Test";

            using (var ch = Host.Start(command))
            {
                using (var server = InitServer(ch))
                {
                    var commandSettings = CmdParser.GetSettings(Host, ImplOptions, new Arguments());
                    ch.Info("maml.exe {0} {1}", command, commandSettings);

                    SendTelemetry(Host);

                    using (new TimerScope(Host, ch))
                    {
                        RunCore(ch);
                    }
                }
            }
        }
Exemplo n.º 4
0
        /// <summary>
        /// Runs the command: opens a channel named after <c>LoadName</c>, initializes the server,
        /// logs the equivalent "maml.exe" command line, sends telemetry, and executes
        /// <c>RunCore</c> (which also receives the command line) inside a timer scope.
        /// </summary>
        public override void Run()
        {
            using (var ch = Host.Start(LoadName))
            {
                using (var server = InitServer(ch))
                {
                    var commandSettings = CmdParser.GetSettings(Host, ImplOptions, new Arguments());
                    string commandLine = string.Format("maml.exe {0} {1}", LoadName, commandSettings);

                    ch.Info(commandLine);
                    SendTelemetry(Host);

                    using (new TimerScope(Host, ch))
                    {
                        RunCore(ch, commandLine);
                    }
                }
            }
        }
Exemplo n.º 5
0
        /// <summary>
        /// Runs the "Train" command: opens a channel, initializes the server, logs the equivalent
        /// "maml.exe" command line, sends telemetry, and executes <c>RunCore</c> (which also
        /// receives the command line) inside a timer scope.
        /// </summary>
        public override void Run()
        {
            const string command = "Train";

            using (var ch = Host.Start(command))
            {
                using (var server = InitServer(ch))
                {
                    var commandSettings = CmdParser.GetSettings(Host, Args, new Arguments());
                    string commandLine = string.Format("maml.exe {0} {1}", command, commandSettings);

                    ch.Info(commandLine);
                    SendTelemetry(Host);

                    using (new TimerScope(Host, ch))
                    {
                        RunCore(ch, commandLine);
                    }
                }
            }
        }
Exemplo n.º 6
0
        /// <summary>
        /// Processes a script to be parsed (from the input resource) and checks that it round-trips:
        /// the text is parsed into a clone of <paramref name="defaults"/>, the resulting settings
        /// string is parsed again into a second clone, and any difference between the two passes
        /// (in <c>ToString</c> output or settings string) is written out as a "*** BUG" line.
        /// </summary>
        /// <param name="wrt">Writer that receives all output, including parser messages.</param>
        /// <param name="text">The script text to parse.</param>
        /// <param name="defaults">Default arguments; cloned before each parse and used as the baseline for settings.</param>
        private static void Process(IndentingTextWriter wrt, string text, ArgsBase defaults)
        {
            var env = new TlcEnvironment(seed: 42);

            using (wrt.Nest())
            {
                // Pass 1: parse the raw script text.
                var firstArgs = defaults.Clone();
                using (wrt.Nest())
                {
                    bool parsed = CmdParser.ParseArguments(env, text, firstArgs, msg => wrt.WriteLine("*** {0}", msg));
                    if (!parsed)
                    {
                        wrt.WriteLine("*** Failed!");
                    }
                }

                string firstText = firstArgs.ToString();
                wrt.WriteLine("ToString: {0}", firstText);

                string firstSettings = CmdParser.GetSettings(env, firstArgs, defaults, SettingsFlags.None);
                wrt.WriteLine("Settings: {0}", firstSettings);

                // Pass 2: parse the settings string produced by pass 1; it must parse cleanly.
                var secondArgs = defaults.Clone();
                using (wrt.Nest())
                {
                    bool reparsed = CmdParser.ParseArguments(env, firstSettings, secondArgs, msg => wrt.WriteLine("*** BUG: {0}", msg));
                    if (!reparsed)
                    {
                        wrt.WriteLine("*** BUG: parsing result of GetSettings failed!");
                    }
                }

                // Both passes must agree on ToString output and on the settings string.
                string secondText = secondArgs.ToString();
                if (firstText != secondText)
                {
                    wrt.WriteLine("*** BUG: ToString Mismatch: {0}", secondText);
                }

                string secondSettings = CmdParser.GetSettings(env, secondArgs, defaults, SettingsFlags.None);
                if (firstSettings != secondSettings)
                {
                    wrt.WriteLine("*** BUG: Settings Mismatch: {0}", secondSettings);
                }
            }
        }
Exemplo n.º 7
0
        /// <summary>
        /// Builds the loader argument string: an optional "header+" flag, the separator, and one
        /// "col=" entry per pipe. A warning is issued on <paramref name="ch"/> for every
        /// variable-length vector column that is not the last one.
        /// </summary>
        /// <param name="schema">Schema used to look up each source column's name and type.</param>
        /// <param name="pipes">Writers whose <c>Source</c> gives the column index in <paramref name="schema"/>.</param>
        /// <param name="hasHeader">Whether to emit the "header+" flag.</param>
        /// <param name="ch">Channel for warnings; also passed to <c>CmdParser.GetSettings</c>.</param>
        /// <returns>The assembled argument string.</returns>
        private string CreateLoaderArguments(ISchema schema, ValueWriter[] pipes, bool hasHeader, IChannel ch)
        {
            var builder = new StringBuilder();

            if (hasHeader)
            {
                builder.Append("header+ ");
            }
            builder.AppendFormat("sep={0}", SeparatorCharToString(_sepChar));

            // Start index of the current column; null once it can no longer be determined.
            int? startIndex = 0;
            int lastPipe = pipes.Length - 1;

            for (int pipeIdx = 0; pipeIdx <= lastPipe; pipeIdx++)
            {
                int srcCol = pipes[pipeIdx].Source;
                string colName = schema.GetColumnName(srcCol);
                var colType = schema.GetColumnType(srcCol);

                var column = GetColumn(colName, colType, startIndex);
                builder.Append(" col=");

                // Fall back to the quoted long form when the column has no short form.
                bool unparsed = column.TryUnparse(builder);
                if (!unparsed)
                {
                    var columnSettings = CmdParser.GetSettings(ch, column, new TextLoader.Column());
                    CmdQuoter.QuoteValue(columnSettings, builder, true);
                }

                bool isVariableLength = colType.IsVector && !colType.IsKnownSizeVector;
                if (isVariableLength && pipeIdx != lastPipe)
                {
                    ch.Warning("Column '{0}' is variable length, so it must be the last, or the file will be unreadable. Consider switching to binary format or use xf=Choose to make '{0}' the last column.", colName);
                    startIndex = null;
                }

                // Adding to a null index leaves it null, so later columns stay undetermined.
                startIndex = startIndex + colType.ValueCount;
            }

            return builder.ToString();
        }
Exemplo n.º 8
0
 /// <summary>
 /// Formats this parameter as " p=dp{...}", where the braces enclose the settings string that
 /// <c>CmdParser.GetSettings</c> produces for <c>_args</c> against a default
 /// <c>DiscreteParamArguments</c>.
 /// </summary>
 /// <param name="env">Host environment passed to <c>CmdParser.GetSettings</c>.</param>
 /// <returns>The formatted parameter string, including the leading space.</returns>
 public string ToStringParameter(IHostEnvironment env)
 {
     string settings = CmdParser.GetSettings(env, _args, new DiscreteParamArguments());
     return " p=dp{" + settings + "}";
 }
        /// <summary>
        /// Runs the cross-validation command body: builds the data pipeline, obtains the
        /// per-fold tasks from <c>FoldHelper</c>, prints per-fold and overall metrics, sends the
        /// metrics to telemetry, and optionally saves the per-instance results.
        /// </summary>
        /// <param name="ch">Channel used for tracing, warnings and exceptions.</param>
        /// <param name="cmd">Command line string; forwarded to <c>FoldHelper</c>.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);

            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing data pipeline");
            IDataLoader loader = CreateRawLoader();

            // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
            // NOTE(review): this check uses IsNullOrEmpty while the save step below uses
            // IsNullOrWhiteSpace, so a whitespace-only OutputDataFile adds the transform but saves
            // nothing. Confirm which is intended.
            var preXf = Args.PreTransform;

            if (!string.IsNullOrEmpty(Args.OutputDataFile))
            {
                string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
                if (name == null)
                {
                    var args = new GenerateNumberTransform.Arguments();
                    args.Column = new[] { new GenerateNumberTransform.Column()
                                          {
                                              Name = DefaultColumnNames.Name
                                          }, };
                    args.UseCounter = true;
                    var options = CmdParser.GetSettings(ch, args, new GenerateNumberTransform.Arguments());
                    preXf = preXf.Concat(
                        new[]
                    {
                        new KeyValuePair <string, SubComponent <IDataTransform, SignatureDataTransform> >(
                            "", new SubComponent <IDataTransform, SignatureDataTransform>(
                                GenerateNumberTransform.LoadName, options))
                    }).ToArray();
                }
            }
            loader = CompositeDataLoader.Create(Host, loader, preXf);

            ch.Trace("Binding label and features columns");

            IDataView pipe = loader;
            var       stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
            var       scorer    = Args.Scorer;
            var       evaluator = Args.Evaluator;

            Func <IDataView> validDataCreator = null;

            if (Args.ValidationFile != null)
            {
                validDataCreator =
                    () =>
                {
                    // Fork the command.
                    var impl = new CrossValidationCommand(this);
                    return(impl.CreateRawLoader(dataFile: Args.ValidationFile));
                };
            }

            FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
                                             Args, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
                                             validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(Args.OutputDataFile));
            var tasks = fold.GetCrossValidationTasks();

            // No evaluator was specified: pick one from the score schema of the first fold.
            if (!evaluator.IsGood())
            {
                evaluator = EvaluateUtils.GetEvaluatorType(ch, tasks[0].Result.ScoreSchema);
            }
            var eval = evaluator.CreateInstance(Host);

            // Print confusion matrix and fold results for each fold.
            // Each task's Result is read individually here, in fold order.
            for (int i = 0; i < tasks.Length; i++)
            {
                var dict = tasks[i].Result.Metrics;
                MetricWriter.PrintWarnings(ch, dict);
                eval.PrintFoldResults(ch, dict);
            }

            // Every task's Result has been read by the loop above, so the per-fold metrics can be
            // collected once here and shared by all consumers below.
            var foldMetrics = tasks.Select(t => t.Result.Metrics).ToArray();

            // Print the overall results.
            if (!TryGetOverallMetrics(foldMetrics, out var overallList))
            {
                throw ch.Except("No overall metrics found");
            }

            var overall = eval.GetOverallResults(overallList.ToArray());

            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds);
            eval.PrintAdditionalMetrics(ch, foldMetrics);
            SendTelemetryMetric(foldMetrics);

            // Save the per-instance results.
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, Args.CollateMetrics,
                                                                                Args.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
                if (variableSizeVectorColumnNames.Length > 0)
                {
                    ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.",
                               string.Join(", ", variableSizeVectorColumnNames));
                }
                if (Args.CollateMetrics)
                {
                    // Collated output is asserted to be a single data view.
                    ch.Assert(perInstance.Length == 1);
                    MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]);
                }
                else
                {
                    // One output file per fold, named from the fold index.
                    int i = 0;
                    foreach (var idv in perInstance)
                    {
                        MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, i), idv);
                        i++;
                    }
                }
            }
        }
Exemplo n.º 10
0
            // REVIEW: include Subcomponent array for testing once it is supported
            //[Argument(ArgumentType.Multiple)]
            //public SubComponent[] sub4 = new SubComponent[] { new SubComponent("sub4", "settings4"), new SubComponent("sub5", "settings5") };

            /// <summary>
            /// Renders this instance through <c>CmdParser.GetSettings</c>, which is the primary
            /// thing this test exercises.
            /// </summary>
            /// <param name="ectx">Exception context passed to <c>CmdParser.GetSettings</c>.</param>
            /// <returns>The settings string for this instance against a default <c>SimpleArg</c>.</returns>
            public string ToString(IExceptionContext ectx)
            {
                var defaults = new SimpleArg();
                return CmdParser.GetSettings(ectx, this, defaults, SettingsFlags.None);
            }
        /// <summary>
        /// Runs the cross-validation command body: builds the data pipeline, obtains the
        /// per-fold tasks from <c>FoldHelper</c>, prints per-fold and overall results, sends the
        /// metrics to telemetry, and optionally saves the per-instance results.
        /// </summary>
        /// <param name="ch">Channel used for tracing and warnings.</param>
        /// <param name="cmd">Command line string; forwarded to <c>FoldHelper</c>.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);

            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing data pipeline");
            IDataLoader loader = CreateRawLoader();

            // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
            // NOTE(review): this check uses IsNullOrEmpty while the save step below uses
            // IsNullOrWhiteSpace, so a whitespace-only OutputDataFile adds the transform but saves
            // nothing. Confirm which is intended.
            var preXf = Args.PreTransform;

            if (!string.IsNullOrEmpty(Args.OutputDataFile))
            {
                string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
                if (name == null)
                {
                    var args = new GenerateNumberTransform.Arguments();
                    args.Column = new[] { new GenerateNumberTransform.Column()
                                          {
                                              Name = DefaultColumnNames.Name
                                          }, };
                    args.UseCounter = true;
                    var options = CmdParser.GetSettings(ch, args, new GenerateNumberTransform.Arguments());
                    preXf = preXf.Concat(
                        new[]
                    {
                        new KeyValuePair <string, SubComponent <IDataTransform, SignatureDataTransform> >(
                            "", new SubComponent <IDataTransform, SignatureDataTransform>(
                                GenerateNumberTransform.LoadName, options))
                    }).ToArray();
                }
            }
            loader = CompositeDataLoader.Create(Host, loader, preXf);

            ch.Trace("Binding label and features columns");

            IDataView pipe = loader;
            var       stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
            var       scorer    = Args.Scorer;
            var       evaluator = Args.Evaluator;

            Func <IDataView> validDataCreator = null;

            if (Args.ValidationFile != null)
            {
                validDataCreator =
                    () =>
                {
                    // Fork the command.
                    var impl = new CrossValidationCommand(this);
                    return(impl.CreateRawLoader(dataFile: Args.ValidationFile));
                };
            }

            FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
                                             Args, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
                                             validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(Args.OutputDataFile));
            var tasks = fold.GetCrossValidationTasks();

            // No evaluator was specified: pick one from the score schema of the first fold.
            if (!evaluator.IsGood())
            {
                evaluator = EvaluateUtils.GetEvaluatorType(ch, tasks[0].Result.ScoreSchema);
            }
            var eval = evaluator.CreateInstance(Host);

            // Print confusion matrix and fold results for each fold.
            // Each task's Result is read individually here, in fold order.
            for (int i = 0; i < tasks.Length; i++)
            {
                var dict = tasks[i].Result.Metrics;
                MetricWriter.PrintWarnings(ch, dict);
                eval.PrintFoldResults(ch, dict);
            }

            // Every task's Result has been read by the loop above, so the per-fold metrics can be
            // collected once here and shared by the overall printout and telemetry.
            var foldMetrics = tasks.Select(t => t.Result.Metrics).ToArray();

            // Print the overall results.
            eval.PrintOverallResults(ch, Args.SummaryFilename, foldMetrics);
            SendTelemetryMetric(foldMetrics);

            // Save the per-instance results.
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                Func <Task <FoldHelper.FoldResult>, int, IDataView> getPerInstance =
                    (task, i) =>
                {
                    if (!Args.OutputExampleFoldIndex)
                    {
                        return(task.Result.PerInstanceResults);
                    }

                    // If the fold index is requested, add a column containing it. We use the first column in the data view
                    // as an input column to the LambdaColumnMapper, because it must have an input.
                    var inputColName = task.Result.PerInstanceResults.Schema.GetColumnName(0);
                    var inputColType = task.Result.PerInstanceResults.Schema.GetColumnType(0);
                    return(Utils.MarshalInvoke(EvaluateUtils.AddKeyColumn <int>, inputColType.RawType, Host,
                                               task.Result.PerInstanceResults, inputColName, MetricKinds.ColumnNames.FoldIndex,
                                               inputColType, Args.NumFolds, i + 1, "FoldIndex", default(ValueGetter <VBuffer <DvText> >)));
                };

                var foldDataViews = tasks.Select(getPerInstance).ToArray();
                if (Args.CollateMetrics)
                {
                    // Collated output: append all folds into a single data view and save it once.
                    var perInst = AppendPerInstanceDataViews(foldDataViews, ch);
                    MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInst);
                }
                else
                {
                    // One output file per fold, named from the fold index.
                    int i = 0;
                    foreach (var idv in foldDataViews)
                    {
                        MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, i), idv);
                        i++;
                    }
                }
            }
        }
Exemplo n.º 12
0
 /// <summary>
 /// Formats this parameter as " p=fp{...}", where the braces enclose the settings string that
 /// <c>CmdParser.GetSettings</c> produces for <c>_options</c> against a default
 /// <c>FloatParamOptions</c>.
 /// </summary>
 /// <param name="env">Host environment passed to <c>CmdParser.GetSettings</c>.</param>
 /// <returns>The formatted parameter string, including the leading space.</returns>
 public string ToStringParameter(IHostEnvironment env)
 {
     string settings = CmdParser.GetSettings(env, _options, new FloatParamOptions());
     return " p=fp{" + settings + "}";
 }
            // REVIEW: include Subcomponent array for testing once it is supported
            //[Argument(ArgumentType.Multiple)]
            //public SubComponent[] sub4 = new SubComponent[] { new SubComponent("sub4", "settings4"), new SubComponent("sub5", "settings5") };

            /// <summary>
            /// Renders this instance through <c>CmdParser.GetSettings</c>, which is the primary
            /// thing this test exercises.
            /// </summary>
            /// <param name="env">Host environment passed to <c>CmdParser.GetSettings</c>.</param>
            /// <returns>The settings string for this instance against a default <c>SimpleArg</c>.</returns>
            public string ToString(IHostEnvironment env)
            {
                var defaults = new SimpleArg();
                return CmdParser.GetSettings(env, this, defaults, SettingsFlags.None);
            }
        /// <summary>
        /// Creates the tree featurizer transform. When a trained model file is given, the predictor
        /// is loaded from it and wrapped in a <c>GenericScorer</c>; otherwise a
        /// <c>TrainAndScoreTransform</c> is built from the specified trainer.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transform arguments; a model file or a trainer, and a feature column, are required.</param>
        /// <param name="input">Input data view; must not be null.</param>
        /// <returns>The resulting transform.</returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Tree Featurizer Transform");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(!string.IsNullOrWhiteSpace(args.TrainedModelFile) || args.Trainer.IsGood(), nameof(args.TrainedModelFile),
                              "Please specify either a trainer or an input model file.");
            host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn), nameof(args.FeatureColumn), "Transform needs an input features column");

            IDataTransform xf;

            using (var ch = host.Start("Create Tree Ensemble Scorer"))
            {
                var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments()
                {
                    Suffix = args.Suffix
                };
                if (!string.IsNullOrWhiteSpace(args.TrainedModelFile))
                {
                    // The model file takes precedence when both a model and a trainer are given.
                    if (args.Trainer.IsGood())
                    {
                        ch.Warning("Both an input model and a trainer were specified. Using the model file.");
                    }

                    ch.Trace("Loading model");
                    IPredictor predictor;
                    using (Stream strm = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
                    {
                        using (var rep = RepositoryReader.Open(strm, ch))
                        {
                            ModelLoadContext.LoadModel <IPredictor, SignatureLoadModel>(host, out predictor, rep, ModelFileUtils.DirPredictor);
                        }
                    }

                    ch.Trace("Creating scorer");
                    var data = TrainAndScoreTransform.CreateDataFromArgs(ch, input, args);

                    // Make sure that the given predictor has the correct number of input features.
                    // A calibrated predictor is unwrapped first so the check runs on its sub-predictor.
                    var calibrated = predictor as CalibratedPredictorBase;
                    if (calibrated != null)
                    {
                        predictor = calibrated.SubPredictor;
                    }
                    // Predictor should be a FastTreePredictionWrapper, which implements IValueMapper, so this should
                    // be non-null.
                    var vm = predictor as IValueMapper;
                    ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");
                    if (vm.InputType.VectorSize != data.Schema.Feature.Type.VectorSize)
                    {
                        throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                                               "Predictor in model file expects {0} features, but data has {1} features",
                                               vm.InputType.VectorSize, data.Schema.Feature.Type.VectorSize);
                    }

                    var bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
                    var bound    = bindable.Bind(env, data.Schema);
                    xf = new GenericScorer(env, scorerArgs, input, bound, data.Schema);
                }
                else
                {
                    ch.Assert(args.Trainer.IsGood());

                    ch.Trace("Creating TrainAndScoreTransform");
                    string scorerSettings = CmdParser.GetSettings(ch, scorerArgs,
                                                                  new TreeEnsembleFeaturizerBindableMapper.Arguments());
                    var scorer =
                        new SubComponent <IDataScorerTransform, SignatureDataScorer>(
                            TreeEnsembleFeaturizerBindableMapper.LoadNameShort, scorerSettings);

                    var trainScoreArgs = new TrainAndScoreTransform.Arguments();
                    args.CopyTo(trainScoreArgs);
                    trainScoreArgs.Trainer = new SubComponent <ITrainer, SignatureTrainer>(args.Trainer.Kind,
                                                                                           args.Trainer.Settings);

                    var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
                    trainScoreArgs.Scorer = scorer;
                    var scoreXf = TrainAndScoreTransform.Create(host, trainScoreArgs, labelInput);

                    // Assign the result instead of returning from inside the using block, so this
                    // branch also reaches ch.Done() below, as the model-file branch does.
                    if (input == labelInput)
                    {
                        // No label transform was appended; the scoring transform is the result.
                        xf = scoreXf;
                    }
                    else
                    {
                        xf = (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput);
                    }
                }

                ch.Done();
            }
            return xf;
        }