Example #1
 public Arguments()
 {
     BasePredictors = new[]
     {
         ComponentFactoryUtils.CreateFromFunction(env => new LinearSvmTrainer(env, LabelColumnName, FeatureColumnName))
     };
 }
Example #2
 public Arguments()
 {
     BasePredictors = new[]
     {
         ComponentFactoryUtils.CreateFromFunction(env => new OnlineGradientDescentTrainer(env, LabelColumnName, FeatureColumnName))
     };
 }
        public void TestDeterministicSweeperAsyncCancellation()
        {
            var random = new Random(42);
            var env    = new MLContext(42);
            var args   = new DeterministicSweeperAsync.Arguments();

            args.BatchSize  = 5;
            args.Relaxation = 1;

            args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
                environ => new KdoSweeper(environ,
                                          new KdoSweeper.Arguments()
            {
                SweptParameters = new IComponentFactory <INumericValueGenerator>[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        t => new FloatValueGenerator(new FloatParamArguments()
                    {
                        Name = "foo", Min = 1, Max = 5
                    })),
                    ComponentFactoryUtils.CreateFromFunction(
                        t => new LongValueGenerator(new LongParamArguments()
                    {
                        Name = "bar", Min = 1, Max = 1000, LogBase = true
                    }))
                }
            }));

            var sweeper = new DeterministicSweeperAsync(env, args);

            int sweeps       = 20;
            var tasks        = new List <Task <ParameterSetWithId> >();
            int numCompleted = 0;

            for (int i = 0; i < sweeps; i++)
            {
                var task = sweeper.Propose();
                if (i < args.BatchSize - args.Relaxation)
                {
                    Assert.True(task.IsCompleted);
                    sweeper.Update(task.Result.Id, new RunResult(task.Result.ParameterSet, random.NextDouble(), true));
                    numCompleted++;
                }
                else
                {
                    tasks.Add(task);
                }
            }
            // Cancel after the first barrier and check if the number of registered actions
            // is indeed 2 * batchSize.
            sweeper.Cancel();
            Task.WaitAll(tasks.ToArray());
            foreach (var task in tasks)
            {
                if (task.Result != null)
                {
                    numCompleted++;
                }
            }
            Assert.Equal(args.BatchSize + args.BatchSize, numCompleted);
        }
        public void TestSmacSweeper()
        {
            RunMTAThread(() =>
            {
                var random = new Random(42);
                using (var env = new ConsoleEnvironment(42))
                {
                    int maxInitSweeps = 5;
                    var args          = new SmacSweeper.Arguments()
                    {
                        NumberInitialPopulation = 20,
                        SweptParameters         = new IComponentFactory <INumericValueGenerator>[] {
                            ComponentFactoryUtils.CreateFromFunction(
                                environ => new FloatValueGenerator(new FloatParamArguments()
                            {
                                Name = "foo", Min = 1, Max = 5
                            })),
                            ComponentFactoryUtils.CreateFromFunction(
                                environ => new LongValueGenerator(new LongParamArguments()
                            {
                                Name = "bar", Min = 1, Max = 100, LogBase = true
                            }))
                        }
                    };

                    var sweeper = new SmacSweeper(env, args);
                    var results = new List <IRunResult>();
                    var sweeps  = sweeper.ProposeSweeps(maxInitSweeps, results);
                    Assert.Equal(Math.Min(args.NumberInitialPopulation, maxInitSweeps), sweeps.Length);

                    for (int i = 1; i < 10; i++)
                    {
                        foreach (var parameterSet in sweeps)
                        {
                            foreach (var parameterValue in parameterSet)
                            {
                                if (parameterValue.Name == "foo")
                                {
                                    var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
                                    Assert.InRange(val, 1, 5);
                                }
                                else if (parameterValue.Name == "bar")
                                {
                                    var val = long.Parse(parameterValue.ValueText);
                                    Assert.InRange(val, 1, 1000);
                                }
                                else
                                {
                                    Assert.True(false, "Wrong parameter");
                                }
                            }
                            results.Add(new RunResult(parameterSet, random.NextDouble(), true));
                        }

                        sweeps = sweeper.ProposeSweeps(5, results);
                    }
                    Assert.Equal(5, sweeps.Length);
                }
            });
        }
 public Arguments()
 {
     BasePredictors = new[]
     {
         ComponentFactoryUtils.CreateFromFunction(
             env => new MulticlassLogisticRegression(env, new MulticlassLogisticRegression.Arguments()))
     };
 }
 public Arguments()
 {
     BasePredictors = new[]
     {
         ComponentFactoryUtils.CreateFromFunction(
             env => new OnlineGradientDescentTrainer(env, new OnlineGradientDescentTrainer.Arguments()))
     };
 }
 public Arguments()
 {
     BasePredictors = new[]
     {
         ComponentFactoryUtils.CreateFromFunction(
             env => new MulticlassLogisticRegression(env, LabelColumn, FeatureColumn))
     };
 }
 public Arguments()
 {
     BasePredictors = new[]
     {
         ComponentFactoryUtils.CreateFromFunction(
             env => new LinearSvm(env, new LinearSvm.Arguments()))
     };
 }
        public void TestSmacSweeper()
        {
            var       random        = new Random(42);
            var       env           = new MLContext(42);
            const int maxInitSweeps = 5;
            var       args          = new SmacSweeper.Options()
            {
                NumberInitialPopulation = 20,
                SweptParameters         = new IComponentFactory <INumericValueGenerator>[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new FloatValueGenerator(new FloatParamOptions()
                    {
                        Name = "foo", Min = 1, Max = 5
                    })),
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new LongValueGenerator(new LongParamOptions()
                    {
                        Name = "bar", Min = 1, Max = 100, LogBase = true
                    }))
                }
            };

            var sweeper = new SmacSweeper(env, args);
            var results = new List <IRunResult>();
            var sweeps  = sweeper.ProposeSweeps(maxInitSweeps, results);

            Assert.Equal(Math.Min(args.NumberInitialPopulation, maxInitSweeps), sweeps.Length);

            for (int i = 1; i < 10; i++)
            {
                foreach (var parameterSet in sweeps)
                {
                    foreach (var parameterValue in parameterSet)
                    {
                        if (parameterValue.Name == "foo")
                        {
                            var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
                            Assert.InRange(val, 1, 5);
                        }
                        else if (parameterValue.Name == "bar")
                        {
                            var val = long.Parse(parameterValue.ValueText);
                            Assert.InRange(val, 1, 1000);
                        }
                        else
                        {
                            Assert.True(false, "Wrong parameter");
                        }
                    }
                    results.Add(new RunResult(parameterSet, random.NextDouble(), true));
                }

                sweeps = sweeper.ProposeSweeps(5, results);
            }
            // Because only unique configurations are considered, the number asked for may exceed the number actually returned.
            Assert.True(sweeps.Length <= 5);
        }
Example #10
        /// <summary>
        /// Determine the default scorer for a schema bound mapper. This looks for text-valued ScoreColumnKind
        /// metadata on the first column of the mapper. If that text is found and maps to a scorer loadable class,
        /// that component is used. Otherwise, the GenericScorer is used.
        /// </summary>
        /// <param name="environment">The host environment.</param>.
        /// <param name="mapper">The schema bound mapper to get the default scorer.</param>.
        /// <param name="suffix">An optional suffix to append to the default column names.</param>
        public static TScorerFactory GetScorerComponent(
            IHostEnvironment environment,
            ISchemaBoundMapper mapper,
            string suffix = null)
        {
            Contracts.CheckValue(environment, nameof(environment));
            Contracts.AssertValue(mapper);

            ComponentCatalog.LoadableClassInfo info = null;
            ReadOnlyMemory <char> scoreKind         = default;

            if (mapper.OutputSchema.Count > 0 &&
                mapper.OutputSchema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreColumnKind, 0, ref scoreKind) &&
                !scoreKind.IsEmpty)
            {
                var loadName = scoreKind.ToString();
                info = environment.ComponentCatalog.GetLoadableClassInfo <SignatureDataScorer>(loadName);
                if (info == null || !typeof(IDataScorerTransform).IsAssignableFrom(info.Type))
                {
                    info = null;
                }
            }

            Func <IHostEnvironment, IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> factoryFunc;

            if (info == null)
            {
                factoryFunc = (env, data, innerMapper, trainSchema) =>
                              new GenericScorer(
                    env,
                    new GenericScorer.Arguments()
                {
                    Suffix = suffix
                },
                    data,
                    innerMapper,
                    trainSchema);
            }
            else
            {
                factoryFunc = (env, data, innerMapper, trainSchema) =>
                {
                    object args = info.CreateArguments();
                    if (args is ScorerArgumentsBase scorerArgs)
                    {
                        scorerArgs.Suffix = suffix;
                    }
                    return((IDataScorerTransform)info.CreateInstance(
                               env,
                               args,
                               new object[] { data, innerMapper, trainSchema }));
                };
            }

            return(ComponentFactoryUtils.CreateFromFunction(factoryFunc));
        }
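A minimal, hypothetical call site for the helper above. It assumes TScorerFactory is the ScoreUtils alias for IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>; the environment, bound mapper, data view, and training schema are supplied by the caller, and the method name is illustrative only.
        private static IDataScorerTransform ApplyDefaultScorer(
            IHostEnvironment environment,
            ISchemaBoundMapper mapper,
            IDataView data,
            RoleMappedSchema trainSchema,
            string suffix = null)
        {
            // Resolve the factory: a scorer matching the ScoreColumnKind metadata when one is
            // registered, otherwise the GenericScorer fallback described in the summary above.
            var scorerFactory = GetScorerComponent(environment, mapper, suffix);

            // Instantiate the scorer over the given data, mapper, and training schema.
            return scorerFactory.CreateComponent(environment, data, mapper, trainSchema);
        }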
 public Arguments()
 {
      // REVIEW: Perhaps we can have a better non-parametric learner.
     BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
         env => new Ova(env, new Ova.Arguments()
     {
         PredictorType = ComponentFactoryUtils.CreateFromFunction(
             e => new FastTreeBinaryClassificationTrainer(e, DefaultColumnNames.Label, DefaultColumnNames.Features))
     }));
 }
 public Arguments()
 {
      // REVIEW: Perhaps we can have a better non-parametric learner.
     BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
         env => new Ova(env, new Ova.Arguments()
     {
         PredictorType = ComponentFactoryUtils.CreateFromFunction(
             e => new AveragedPerceptronTrainer(e, new AveragedPerceptronTrainer.Arguments()))
     }));
 }
        public void TestNelderMeadSweeperWithDefaultFirstBatchSweeper()
        {
            var random = new Random(42);
            var env    = new MLContext(42);
            var param  = new IComponentFactory <INumericValueGenerator>[] {
                ComponentFactoryUtils.CreateFromFunction(
                    environ => new FloatValueGenerator(new FloatParamArguments()
                {
                    Name = "foo", Min = 1, Max = 5
                })),
                ComponentFactoryUtils.CreateFromFunction(
                    environ => new LongValueGenerator(new LongParamArguments()
                {
                    Name = "bar", Min = 1, Max = 1000, LogBase = true
                }))
            };

            var args = new NelderMeadSweeper.Arguments();

            args.SweptParameters = param;
            var sweeper = new NelderMeadSweeper(env, args);
            var sweeps  = sweeper.ProposeSweeps(5, new List <RunResult>());

            Assert.Equal(3, sweeps.Length);

            var results = new List <IRunResult>();

            for (int i = 1; i < 10; i++)
            {
                foreach (var parameterSet in sweeps)
                {
                    foreach (var parameterValue in parameterSet)
                    {
                        if (parameterValue.Name == "foo")
                        {
                            var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
                            Assert.InRange(val, 1, 5);
                        }
                        else if (parameterValue.Name == "bar")
                        {
                            var val = long.Parse(parameterValue.ValueText);
                            Assert.InRange(val, 1, 1000);
                        }
                        else
                        {
                            Assert.True(false, "Wrong parameter");
                        }
                    }
                    results.Add(new RunResult(parameterSet, random.NextDouble(), true));
                }

                sweeps = sweeper.ProposeSweeps(5, results);
            }
            Assert.True(sweeps == null || sweeps.Length <= 5);
        }
Example #14
 public Arguments()
 {
     BasePredictors = new[]
     {
         ComponentFactoryUtils.CreateFromFunction(
             env => {
             var trainerEstimator = new LinearSvmTrainer(env);
             return(TrainerUtils.MapTrainerEstimatorToTrainer <LinearSvmTrainer,
                                                               LinearBinaryModelParameters, LinearBinaryModelParameters>(env, trainerEstimator));
         })
     };
 }
 public Arguments()
 {
     BasePredictors = new[]
     {
         ComponentFactoryUtils.CreateFromFunction(
             env => {
             var trainerEstimator = new OnlineGradientDescentTrainer(env);
             return(TrainerUtils.MapTrainerEstimatorToTrainer <OnlineGradientDescentTrainer,
                                                               LinearRegressionModelParameters, LinearRegressionModelParameters>(env, trainerEstimator));
         })
     };
 }
Example #16
 public Arguments()
 {
     BasePredictors = new[]
     {
         // Note that this illustrates a fundamental problem with the mixture of `ITrainer` and `ITrainerEstimator`
         // present in this class. The options to the estimator have no way of being communicated to the `ITrainer`
         // implementation, so there is a fundamental disconnect if someone chooses to ever use the *estimator* with
          // non-default column names. Unfortunately no method of resolving this temporarily strikes me as being any
         // less laborious than the proper fix, which is that this "meta" component should itself be a trainer
         // estimator, as opposed to a regular trainer.
         ComponentFactoryUtils.CreateFromFunction(env => new LbfgsMaximumEntropyMulticlassTrainer(env, LabelColumnName, FeatureColumnName))
     };
 }
        public void TestRandomSweeper()
        {
            using (var env = new ConsoleEnvironment(42))
            {
                var args = new SweeperBase.ArgumentsBase()
                {
                    SweptParameters = new[] {
                        ComponentFactoryUtils.CreateFromFunction(
                            environ => new LongValueGenerator(new LongParamArguments()
                        {
                            Name = "foo", Min = 10, Max = 20
                        })),
                        ComponentFactoryUtils.CreateFromFunction(
                            environ => new LongValueGenerator(new LongParamArguments()
                        {
                            Name = "bar", Min = 100, Max = 200
                        }))
                    }
                };

                var sweeper     = new UniformRandomSweeper(env, args);
                var initialList = sweeper.ProposeSweeps(5, new List <RunResult>());
                Assert.Equal(5, initialList.Length);
                foreach (var parameterSet in initialList)
                {
                    foreach (var parameterValue in parameterSet)
                    {
                        if (parameterValue.Name == "foo")
                        {
                            var val = long.Parse(parameterValue.ValueText);
                            Assert.InRange(val, 10, 20);
                        }
                        else if (parameterValue.Name == "bar")
                        {
                            var val = long.Parse(parameterValue.ValueText);
                            Assert.InRange(val, 100, 200);
                        }
                        else
                        {
                            Assert.True(false, "Wrong parameter");
                        }
                    }
                }
            }
        }
Example #18
 public Arguments()
 {
     BasePredictors = new[]
     {
         ComponentFactoryUtils.CreateFromFunction(
             env => {
              // Note that this illustrates a fundamental problem with the mixture of `ITrainer` and `ITrainerEstimator`
              // present in this class. The options to the estimator have no way of being communicated to the `ITrainer`
              // implementation, so there is a fundamental disconnect if someone chooses to ever use the *estimator* with
              // non-default column names. Unfortunately no method of resolving this temporarily strikes me as being any
             // less laborious than the proper fix, which is that this "meta" component should itself be a trainer
             // estimator, as opposed to a regular trainer.
             var trainerEstimator = new MulticlassLogisticRegression(env, LabelColumn, FeatureColumn);
             return(TrainerUtils.MapTrainerEstimatorToTrainer <MulticlassLogisticRegression,
                                                               MulticlassLogisticRegressionModelParameters, MulticlassLogisticRegressionModelParameters>(env, trainerEstimator));
         })
     };
 }
        public void Metacomponents()
        {
            using (var env = new LocalEnvironment())
            {
                var loader  = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename)));
                var term    = TermTransform.Create(env, loader, "Label");
                var concat  = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term);
                var trainer = new Ova(env, new Ova.Arguments
                {
                    PredictorType = ComponentFactoryUtils.CreateFromFunction(
                        e => new AveragedPerceptronTrainer(env, new AveragedPerceptronTrainer.Arguments()))
                });

                IDataView trainData  = trainer.Info.WantCaching ? (IDataView) new CacheDataView(env, concat, prefetch: null) : concat;
                var       trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new TrainContext(trainRoles));
            }
        }
Example #20
        public void Metacomponents()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new TlcEnvironment())
            {
                var loader  = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var term    = new TermTransform(env, loader, "Label");
                var concat  = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");
                var trainer = new Ova(env, new Ova.Arguments
                {
                    PredictorType = ComponentFactoryUtils.CreateFromFunction(
                        e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
                });

                IDataView trainData  = trainer.Info.WantCaching ? (IDataView) new CacheDataView(env, concat, prefetch: null) : concat;
                var       trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new TrainContext(trainRoles));
            }
        }
        void Metacomponents()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new TlcEnvironment())
            {
                var loader  = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var term    = new TermTransform(env, loader, "Label");
                var concat  = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");
                var trainer = new Ova(env, new Ova.Arguments
                {
                    PredictorType = ComponentFactoryUtils.CreateFromFunction(
                        (e) => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
                });

                IDataView trainData  = trainer.Info.WantCaching ? (IDataView) new CacheDataView(env, concat, prefetch: null) : concat;
                var       trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                var keyToValue = new KeyToValueTransform(env, scorer, "PredictedLabel");
                var model      = env.CreatePredictionEngine <IrisData, IrisPrediction>(keyToValue);

                var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var testData   = testLoader.AsEnumerable <IrisData>(env, false);
                foreach (var input in testData.Take(20))
                {
                    var prediction = model.Predict(input);
                    Assert.True(prediction.PredictedLabel == input.Label);
                }
            }
        }
Example #22
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);

            IPredictor inputPredictor = null;

            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing data pipeline");
            IDataLoader loader = CreateRawLoader();

            // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
            var preXf = Args.PreTransform;

            if (!string.IsNullOrEmpty(Args.OutputDataFile))
            {
                string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
                if (name == null)
                {
                    preXf = preXf.Concat(
                        new[]
                    {
                        new KeyValuePair <string, IComponentFactory <IDataView, IDataTransform> >(
                            "", ComponentFactoryUtils.CreateFromFunction <IDataView, IDataTransform>(
                                (env, input) =>
                        {
                            var args    = new GenerateNumberTransform.Arguments();
                            args.Column = new[] { new GenerateNumberTransform.Column()
                                                  {
                                                      Name = DefaultColumnNames.Name
                                                  }, };
                            args.UseCounter = true;
                            return(new GenerateNumberTransform(env, args, input));
                        }))
                    }).ToArray();
                }
            }
            loader = CompositeDataLoader.Create(Host, loader, preXf);

            ch.Trace("Binding label and features columns");

            IDataView pipe = loader;
            var       stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
            var       scorer    = Args.Scorer;
            var       evaluator = Args.Evaluator;

            Func <IDataView> validDataCreator = null;

            if (Args.ValidationFile != null)
            {
                validDataCreator =
                    () =>
                {
                    // Fork the command.
                    var impl = new CrossValidationCommand(this);
                    return(impl.CreateRawLoader(dataFile: Args.ValidationFile));
                };
            }

            FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
                                             Args, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
                                             validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(Args.OutputDataFile));
            var tasks = fold.GetCrossValidationTasks();

            var eval = evaluator?.CreateComponent(Host) ??
                       EvaluateUtils.GetEvaluator(Host, tasks[0].Result.ScoreSchema);

            // Print confusion matrix and fold results for each fold.
            for (int i = 0; i < tasks.Length; i++)
            {
                var dict = tasks[i].Result.Metrics;
                MetricWriter.PrintWarnings(ch, dict);
                eval.PrintFoldResults(ch, dict);
            }

            // Print the overall results.
            if (!TryGetOverallMetrics(tasks.Select(t => t.Result.Metrics).ToArray(), out var overallList))
            {
                throw ch.Except("No overall metrics found");
            }

            var overall = eval.GetOverallResults(overallList.ToArray());

            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds);
            eval.PrintAdditionalMetrics(ch, tasks.Select(t => t.Result.Metrics).ToArray());
            Dictionary <string, IDataView>[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray();
            SendTelemetryMetric(metricValues);

            // Save the per-instance results.
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, Args.CollateMetrics,
                                                                                Args.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
                if (variableSizeVectorColumnNames.Length > 0)
                {
                    ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.",
                               string.Join(", ", variableSizeVectorColumnNames));
                }
                if (Args.CollateMetrics)
                {
                    ch.Assert(perInstance.Length == 1);
                    MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]);
                }
                else
                {
                    int i = 0;
                    foreach (var idv in perInstance)
                    {
                        MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, i), idv);
                        i++;
                    }
                }
            }
        }
        // Factory method for SignatureDataTransform.
        private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Tree Featurizer Transform");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(!string.IsNullOrWhiteSpace(args.TrainedModelFile) || args.Trainer != null, nameof(args.TrainedModelFile),
                              "Please specify either a trainer or an input model file.");
            host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn), nameof(args.FeatureColumn), "Transform needs an input features column");

            IDataTransform xf;

            using (var ch = host.Start("Create Tree Ensemble Scorer"))
            {
                var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments()
                {
                    Suffix = args.Suffix
                };
                if (!string.IsNullOrWhiteSpace(args.TrainedModelFile))
                {
                    if (args.Trainer != null)
                    {
                        ch.Warning("Both an input model and a trainer were specified. Using the model file.");
                    }

                    ch.Trace("Loading model");
                    IPredictor predictor;
                    using (Stream strm = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
                        using (var rep = RepositoryReader.Open(strm, ch))
                            ModelLoadContext.LoadModel <IPredictor, SignatureLoadModel>(host, out predictor, rep, ModelFileUtils.DirPredictor);

                    ch.Trace("Creating scorer");
                    var data = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args);
                    Contracts.Assert(data.Schema.Feature.HasValue);

                    // Make sure that the given predictor has the correct number of input features.
                    if (predictor is CalibratedPredictorBase)
                    {
                        predictor = ((CalibratedPredictorBase)predictor).SubPredictor;
                    }
                    // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should
                    // be non-null.
                    var vm = predictor as IValueMapper;
                    ch.CheckUserArg(vm != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");
                    if (vm.InputType.VectorSize != data.Schema.Feature.Value.Type.VectorSize)
                    {
                        throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                                               "Predictor in model file expects {0} features, but data has {1} features",
                                               vm.InputType.VectorSize, data.Schema.Feature.Value.Type.VectorSize);
                    }

                    ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, predictor);
                    var bound = bindable.Bind(env, data.Schema);
                    xf = new GenericScorer(env, scorerArgs, input, bound, data.Schema);
                }
                else
                {
                    ch.AssertValue(args.Trainer);

                    ch.Trace("Creating TrainAndScoreTransform");

                    var trainScoreArgs = new TrainAndScoreTransformer.Arguments();
                    args.CopyTo(trainScoreArgs);
                    trainScoreArgs.Trainer = args.Trainer;

                    trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction <IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
                        (e, data, mapper, trainSchema) => Create(e, scorerArgs, data, mapper, trainSchema));

                    var mapperFactory = ComponentFactoryUtils.CreateFromFunction <IPredictor, ISchemaBindableMapper>(
                        (e, predictor) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, predictor));

                    var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
                    var scoreXf    = TrainAndScoreTransformer.Create(host, trainScoreArgs, labelInput, mapperFactory);

                    if (input == labelInput)
                    {
                        return(scoreXf);
                    }
                    return((IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput));
                }
            }
            return(xf);
        }
Example #24
        public ParameterMixingCalibratedPredictor TrainKMeansAndLR()
        {
            using (var env = new ConsoleEnvironment(seed: 1))
            {
                // Pipeline
                var loader = TextLoader.ReadFile(env,
                                                 new TextLoader.Arguments()
                {
                    HasHeader = true,
                    Separator = ",",
                    Column    = new[] {
                        new TextLoader.Column("Label", DataKind.R4, 14),
                        new TextLoader.Column("CatFeatures", DataKind.TX,
                                              new [] {
                            new TextLoader.Range()
                            {
                                Min = 1, Max = 1
                            },
                            new TextLoader.Range()
                            {
                                Min = 3, Max = 3
                            },
                            new TextLoader.Range()
                            {
                                Min = 5, Max = 9
                            },
                            new TextLoader.Range()
                            {
                                Min = 13, Max = 13
                            }
                        }),
                        new TextLoader.Column("NumFeatures", DataKind.R4,
                                              new [] {
                            new TextLoader.Range()
                            {
                                Min = 0, Max = 0
                            },
                            new TextLoader.Range()
                            {
                                Min = 2, Max = 2
                            },
                            new TextLoader.Range()
                            {
                                Min = 4, Max = 4
                            },
                            new TextLoader.Range()
                            {
                                Min = 10, Max = 12
                            }
                        })
                    }
                }, new MultiFileSource(_dataPath));

                IDataView trans = new CategoricalEstimator(env, "CatFeatures").Fit(loader).Transform(loader);

                trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "NumFeatures");
                trans = new ConcatTransform(env, "Features", "NumFeatures", "CatFeatures").Transform(trans);
                trans = TrainAndScoreTransform.Create(env, new TrainAndScoreTransform.Arguments
                {
                    Trainer = ComponentFactoryUtils.CreateFromFunction(host =>
                                                                       new KMeansPlusPlusTrainer(host, "Features", advancedSettings: s =>
                    {
                        s.K = 100;
                    })),
                    FeatureColumn = "Features"
                }, trans);
                trans = new ConcatTransform(env, "Features", "Features", "Score").Transform(trans);

                // Train
                var trainer    = new LogisticRegression(env, "Features", "Label", advancedSettings: args => { args.EnforceNonNegativity = true; args.OptTol = 1e-3f; });
                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                return(trainer.Train(trainRoles));
            }
        }
        public void TestSimpleSweeperAsync()
        {
            var       random  = new Random(42);
            var       env     = new MLContext(42);
            const int sweeps  = 100;
            var       sweeper = new SimpleAsyncSweeper(env, new SweeperBase.ArgumentsBase
            {
                SweptParameters = new IComponentFactory <IValueGenerator>[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new FloatValueGenerator(new FloatParamArguments()
                    {
                        Name = "foo", Min = 1, Max = 5
                    })),
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new LongValueGenerator(new LongParamArguments()
                    {
                        Name = "bar", Min = 1, Max = 1000, LogBase = true
                    }))
                }
            });

            var paramSets = new List <ParameterSet>();

            for (int i = 0; i < sweeps; i++)
            {
                var task = sweeper.Propose();
                Assert.True(task.IsCompleted);
                paramSets.Add(task.Result.ParameterSet);
                var result = new RunResult(task.Result.ParameterSet, random.NextDouble(), true);
                sweeper.Update(task.Result.Id, result);
            }
            Assert.Equal(sweeps, paramSets.Count);
            CheckAsyncSweeperResult(paramSets);

            // Test consumption without ever calling Update.
            var gridArgs = new RandomGridSweeper.Arguments();

            gridArgs.SweptParameters = new IComponentFactory <INumericValueGenerator>[] {
                ComponentFactoryUtils.CreateFromFunction(
                    environ => new FloatValueGenerator(new FloatParamArguments()
                {
                    Name = "foo", Min = 1, Max = 5
                })),
                ComponentFactoryUtils.CreateFromFunction(
                    environ => new LongValueGenerator(new LongParamArguments()
                {
                    Name = "bar", Min = 1, Max = 100, LogBase = true
                }))
            };
            var gridSweeper = new SimpleAsyncSweeper(env, gridArgs);

            paramSets.Clear();
            for (int i = 0; i < sweeps; i++)
            {
                var task = gridSweeper.Propose();
                Assert.True(task.IsCompleted);
                paramSets.Add(task.Result.ParameterSet);
            }
            Assert.Equal(sweeps, paramSets.Count);
            CheckAsyncSweeperResult(paramSets);
        }
Example #26
 public Arguments()
 {
     BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
         env => new FastTreeRegressionTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features));
 }
        public void TestDeterministicSweeperAsync()
        {
            var random = new Random(42);
            var env    = new MLContext(42);
            var args   = new DeterministicSweeperAsync.Arguments();

            args.BatchSize  = 5;
            args.Relaxation = args.BatchSize - 1;

            args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
                environ => new SmacSweeper(environ,
                                           new SmacSweeper.Arguments()
            {
                SweptParameters = new IComponentFactory <INumericValueGenerator>[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        t => new FloatValueGenerator(new FloatParamArguments()
                    {
                        Name = "foo", Min = 1, Max = 5
                    })),
                    ComponentFactoryUtils.CreateFromFunction(
                        t => new LongValueGenerator(new LongParamArguments()
                    {
                        Name = "bar", Min = 1, Max = 1000, LogBase = true
                    }))
                }
            }));

            var sweeper = new DeterministicSweeperAsync(env, args);

            // Test single-threaded consumption.
            int sweeps    = 10;
            var paramSets = new List <ParameterSet>();

            for (int i = 0; i < sweeps; i++)
            {
                var task = sweeper.Propose();
                Assert.True(task.IsCompleted);
                paramSets.Add(task.Result.ParameterSet);
                var result = new RunResult(task.Result.ParameterSet, random.NextDouble(), true);
                sweeper.Update(task.Result.Id, result);
            }
            Assert.Equal(sweeps, paramSets.Count);
            CheckAsyncSweeperResult(paramSets);

            // Create two batches and test if the 2nd batch is executed after the synchronization barrier is reached.
            object mlock = new object();
            var    tasks = new Task <ParameterSetWithId> [sweeps];

            args.Relaxation = args.Relaxation - 1;
            sweeper         = new DeterministicSweeperAsync(env, args);
            paramSets.Clear();
            var results = new List <KeyValuePair <int, IRunResult> >();

            for (int i = 0; i < args.BatchSize; i++)
            {
                var task = sweeper.Propose();
                Assert.True(task.IsCompleted);
                tasks[i] = task;
                if (task.Result == null)
                {
                    continue;
                }
                results.Add(new KeyValuePair <int, IRunResult>(task.Result.Id, new RunResult(task.Result.ParameterSet, 0.42, true)));
            }
            // Register consumers for the 2nd batch. Those consumers will await until at least one run
            // in the previous batch has been posted to the sweeper.
            for (int i = args.BatchSize; i < 2 * args.BatchSize; i++)
            {
                var task = sweeper.Propose();
                Assert.False(task.IsCompleted);
                tasks[i] = task;
            }
            // Call update to unblock the 2nd batch.
            foreach (var run in results)
            {
                sweeper.Update(run.Key, run.Value);
            }

            Task.WaitAll(tasks);
            Assert.True(tasks.All(t => t.IsCompleted));
        }
        public void TestDeterministicSweeperAsyncParallel()
        {
            var       random    = new Random(42);
            var       env       = new MLContext(42);
            const int batchSize = 5;
            const int sweeps    = 20;
            var       paramSets = new List <ParameterSet>();
            var       args      = new DeterministicSweeperAsync.Arguments();

            args.BatchSize  = batchSize;
            args.Relaxation = batchSize - 2;

            args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
                environ => new SmacSweeper(environ,
                                           new SmacSweeper.Arguments()
            {
                SweptParameters = new IComponentFactory <INumericValueGenerator>[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        t => new FloatValueGenerator(new FloatParamArguments()
                    {
                        Name = "foo", Min = 1, Max = 5
                    })),
                    ComponentFactoryUtils.CreateFromFunction(
                        t => new LongValueGenerator(new LongParamArguments()
                    {
                        Name = "bar", Min = 1, Max = 1000, LogBase = true
                    }))
                }
            }));

            var sweeper = new DeterministicSweeperAsync(env, args);

            var mlock   = new object();
            var options = new ParallelOptions();

            options.MaxDegreeOfParallelism = 4;

            // Sleep randomly to simulate doing work.
            int[] sleeps = new int[sweeps];
            for (int i = 0; i < sleeps.Length; i++)
            {
                sleeps[i] = random.Next(10, 100);
            }
            var r = Parallel.For(0, sweeps, options, (int i) =>
            {
                var task = sweeper.Propose();
                task.Wait();
                Assert.Equal(TaskStatus.RanToCompletion, task.Status);
                var paramWithId = task.Result;
                if (paramWithId == null)
                {
                    return;
                }
                Thread.Sleep(sleeps[i]);
                var result = new RunResult(paramWithId.ParameterSet, 0.42, true);
                sweeper.Update(paramWithId.Id, result);
                lock (mlock)
                    paramSets.Add(paramWithId.ParameterSet);
            });

            Assert.True(paramSets.Count <= sweeps);
            CheckAsyncSweeperResult(paramSets);
        }
        public async Task TestNelderMeadSweeperAsync()
        {
            var       random    = new Random(42);
            var       env       = new MLContext(42);
            const int batchSize = 5;
            const int sweeps    = 40;
            var       paramSets = new List <ParameterSet>();
            var       args      = new DeterministicSweeperAsync.Arguments();

            args.BatchSize  = batchSize;
            args.Relaxation = 0;

            args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
                environ =>
            {
                var param = new IComponentFactory <INumericValueGenerator>[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        innerEnviron => new FloatValueGenerator(new FloatParamArguments()
                    {
                        Name = "foo", Min = 1, Max = 5
                    })),
                    ComponentFactoryUtils.CreateFromFunction(
                        innerEnviron => new LongValueGenerator(new LongParamArguments()
                    {
                        Name = "bar", Min = 1, Max = 1000, LogBase = true
                    }))
                };

                var nelderMeadSweeperArgs = new NelderMeadSweeper.Arguments()
                {
                    SweptParameters   = param,
                    FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction <IValueGenerator[], ISweeper>(
                        (firstBatchSweeperEnviron, firstBatchSweeperArgs) =>
                        new RandomGridSweeper(environ, new RandomGridSweeper.Arguments()
                    {
                        SweptParameters = param
                    }))
                };

                return(new NelderMeadSweeper(environ, nelderMeadSweeperArgs));
            });

            var sweeper = new DeterministicSweeperAsync(env, args);
            var mlock   = new object();

            double[] metrics = new double[sweeps];
            for (int i = 0; i < metrics.Length; i++)
            {
                metrics[i] = random.NextDouble();
            }

            for (int i = 0; i < sweeps; i++)
            {
                var paramWithId = await sweeper.Propose();

                if (paramWithId == null)
                {
                    return;
                }
                var result = new RunResult(paramWithId.ParameterSet, metrics[i], true);
                sweeper.Update(paramWithId.Id, result);
                lock (mlock)
                    paramSets.Add(paramWithId.ParameterSet);
            }
            Assert.True(paramSets.Count <= sweeps);
            CheckAsyncSweeperResult(paramSets);
        }
        public void TestRandomGridSweeper()
        {
            var env  = new MLContext(42);
            var args = new RandomGridSweeper.Arguments()
            {
                SweptParameters = new[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new LongValueGenerator(new LongParamArguments()
                    {
                        Name = "foo", Min = 10, Max = 20, NumSteps = 3
                    })),
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new LongValueGenerator(new LongParamArguments()
                    {
                        Name = "bar", Min = 100, Max = 10000, LogBase = true, StepSize = 10
                    }))
                }
            };
            var sweeper     = new RandomGridSweeper(env, args);
            var initialList = sweeper.ProposeSweeps(5, new List <RunResult>());

            Assert.Equal(5, initialList.Length);
            var gridPoint = new bool[3][] {
                new bool[3],
                new bool[3],
                new bool[3]
            };
            int i = 0;
            int j = 0;

            foreach (var parameterSet in initialList)
            {
                foreach (var parameterValue in parameterSet)
                {
                    if (parameterValue.Name == "foo")
                    {
                        var val = long.Parse(parameterValue.ValueText);
                        Assert.True(val == 10 || val == 15 || val == 20);
                        i = (val == 10) ? 0 : (val == 15) ? 1 : 2;
                    }
                    else if (parameterValue.Name == "bar")
                    {
                        var val = long.Parse(parameterValue.ValueText);
                        Assert.True(val == 100 || val == 1000 || val == 10000);
                        j = (val == 100) ? 0 : (val == 1000) ? 1 : 2;
                    }
                    else
                    {
                        Assert.True(false, "Wrong parameter");
                    }
                }
                Assert.False(gridPoint[i][j]);
                gridPoint[i][j] = true;
            }

            var nextList = sweeper.ProposeSweeps(5, initialList.Select(p => new RunResult(p)));

            Assert.Equal(4, nextList.Length);
            foreach (var parameterSet in nextList)
            {
                foreach (var parameterValue in parameterSet)
                {
                    if (parameterValue.Name == "foo")
                    {
                        var val = long.Parse(parameterValue.ValueText);
                        Assert.True(val == 10 || val == 15 || val == 20);
                        i = (val == 10) ? 0 : (val == 15) ? 1 : 2;
                    }
                    else if (parameterValue.Name == "bar")
                    {
                        var val = long.Parse(parameterValue.ValueText);
                        Assert.True(val == 100 || val == 1000 || val == 10000);
                        j = (val == 100) ? 0 : (val == 1000) ? 1 : 2;
                    }
                    else
                    {
                        Assert.True(false, "Wrong parameter");
                    }
                }
                Assert.False(gridPoint[i][j]);
                gridPoint[i][j] = true;
            }

            gridPoint = new bool[3][] {
                new bool[3],
                new bool[3],
                new bool[3]
            };
            var lastList = sweeper.ProposeSweeps(10, null);

            Assert.Equal(9, lastList.Length);
            foreach (var parameterSet in lastList)
            {
                foreach (var parameterValue in parameterSet)
                {
                    if (parameterValue.Name == "foo")
                    {
                        var val = long.Parse(parameterValue.ValueText);
                        Assert.True(val == 10 || val == 15 || val == 20);
                        i = (val == 10) ? 0 : (val == 15) ? 1 : 2;
                    }
                    else if (parameterValue.Name == "bar")
                    {
                        var val = long.Parse(parameterValue.ValueText);
                        Assert.True(val == 100 || val == 1000 || val == 10000);
                        j = (val == 100) ? 0 : (val == 1000) ? 1 : 2;
                    }
                    else
                    {
                        Assert.True(false, "Wrong parameter");
                    }
                }
                Assert.False(gridPoint[i][j]);
                gridPoint[i][j] = true;
            }
            Assert.True(gridPoint.All(bArray => bArray.All(b => b)));
        }