Exemple #1
0
 /// <summary>
 /// Sets <c>BasePredictors</c> to a single factory that builds a <c>MulticlassLogisticRegression</c>
 /// estimator and passes it through <c>TrainerUtils.MapTrainerEstimatorToTrainer</c>.
 /// </summary>
 public Arguments()
 {
     // This exposes a fundamental problem with the mixture of `ITrainer` and `ITrainerEstimator` in this
     // class: the estimator's options cannot be communicated to the `ITrainer` implementation, so there is
     // a fundamental disconnect if anyone ever uses the *estimator* with non-default column names.
     // Unfortunately, no temporary workaround seems any less laborious than the proper fix, which is to
     // make this "meta" component a trainer estimator itself rather than a regular trainer.
     BasePredictors = new[]
     {
         ComponentFactoryUtils.CreateFromFunction(
             env =>
             {
                 var estimator = new MulticlassLogisticRegression(env, LabelColumn, FeatureColumn);
                 return TrainerUtils.MapTrainerEstimatorToTrainer<MulticlassLogisticRegression,
                     MulticlassLogisticRegressionModelParameters, MulticlassLogisticRegressionModelParameters>(env, estimator);
             })
     };
 }
        /// <summary>
        /// Asks a <c>UniformRandomSweeper</c> for five proposals over two long-valued parameters and checks
        /// that exactly five are returned, each value lying inside the range configured for its parameter.
        /// </summary>
        public void TestRandomSweeper()
        {
            var context = new MLContext(42);

            // One generator factory per swept parameter; the factory argument is not used.
            var fooFactory = ComponentFactoryUtils.CreateFromFunction(
                host => new LongValueGenerator(new LongParamArguments { Name = "foo", Min = 10, Max = 20 }));
            var barFactory = ComponentFactoryUtils.CreateFromFunction(
                host => new LongValueGenerator(new LongParamArguments { Name = "bar", Min = 100, Max = 200 }));

            var sweeperArgs = new SweeperBase.ArgumentsBase
            {
                SweptParameters = new[] { fooFactory, barFactory }
            };

            var randomSweeper = new UniformRandomSweeper(context, sweeperArgs);
            var proposals     = randomSweeper.ProposeSweeps(5, new List<RunResult>());

            Assert.Equal(5, proposals.Length);
            foreach (var proposal in proposals)
            {
                foreach (var value in proposal)
                {
                    switch (value.Name)
                    {
                        case "foo":
                            Assert.InRange(long.Parse(value.ValueText), 10, 20);
                            break;

                        case "bar":
                            Assert.InRange(long.Parse(value.ValueText), 100, 200);
                            break;

                        default:
                            // Only the two configured parameters may appear in a proposal.
                            Assert.True(false, "Wrong parameter");
                            break;
                    }
                }
            }
        }
        /// <summary>
        /// Trains a one-versus-all (<c>Ova</c>) meta-trainer on the Iris data, with an
        /// <c>AveragedPerceptronTrainer</c> supplied through a component factory as the inner predictor.
        /// The test makes no assertions; it only requires the pipeline to train without throwing.
        /// </summary>
        public void Metacomponents()
        {
            using (var env = new LocalEnvironment())
            {
                var loader  = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename)));
                var term    = TermTransform.Create(env, loader, "Label");
                var concat  = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term);
                var trainer = new Ova(env, new Ova.Arguments
                {
                    // Build the inner trainer from the environment handed to the factory rather than
                    // capturing the outer `env`, so the factory honours whatever environment its caller supplies.
                    PredictorType = ComponentFactoryUtils.CreateFromFunction(
                        e => new AveragedPerceptronTrainer(e, new AveragedPerceptronTrainer.Arguments()))
                });

                // Wrap the data in a cache only when the trainer asks for it.
                IDataView trainData  = trainer.Info.WantCaching ? (IDataView) new CacheDataView(env, concat, prefetch: null) : concat;
                var       trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new TrainContext(trainRoles));
            }
        }
Exemple #4
0
        /// <summary>
        /// Trains a one-versus-all (<c>Ova</c>) meta-trainer on the Iris data, with a
        /// <c>FastTreeBinaryClassificationTrainer</c> created through a component factory as the inner
        /// predictor. No assertions are made; the pipeline only has to train without throwing.
        /// </summary>
        public void Metacomponents()
        {
            var irisPath = GetDataPath(IrisDataPath);

            using (var env = new TlcEnvironment())
            {
                // Data pipeline: load text -> map "Label" to keys -> concatenate the four measurements.
                var source      = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(irisPath));
                var labelToKey  = new TermTransform(env, source, "Label");
                var featurized  = new ConcatTransform(env, labelToKey, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

                var ovaArgs = new Ova.Arguments
                {
                    PredictorType = ComponentFactoryUtils.CreateFromFunction(
                        host => new FastTreeBinaryClassificationTrainer(host, new FastTreeBinaryClassificationTrainer.Arguments()))
                };
                var ova = new Ova(env, ovaArgs);

                // Cache the training data only when the trainer requests it.
                IDataView trainingView;
                if (ova.Info.WantCaching)
                {
                    trainingView = new CacheDataView(env, featurized, prefetch: null);
                }
                else
                {
                    trainingView = featurized;
                }
                var roles = new RoleMappedData(trainingView, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref roles, ova);
                var predictor = ova.Train(new TrainContext(roles));
            }
        }
        /// <summary>
        /// Trains a one-versus-all (<c>Ova</c>) meta-trainer over <c>FastTreeBinaryClassificationTrainer</c>
        /// on the Iris data, builds a prediction engine from the scored pipeline, and checks that the
        /// predicted label equals the input label for the first 20 rows of the same file.
        /// </summary>
        void Metacomponents()
        {
            var irisPath = GetDataPath(IrisDataPath);

            using (var env = new TlcEnvironment())
            {
                // Data pipeline: load text -> map "Label" to keys -> concatenate the four measurements.
                var source     = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(irisPath));
                var labelToKey = new TermTransform(env, source, "Label");
                var featurized = new ConcatTransform(env, labelToKey, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

                var ovaArgs = new Ova.Arguments
                {
                    PredictorType = ComponentFactoryUtils.CreateFromFunction(
                        (host) => new FastTreeBinaryClassificationTrainer(host, new FastTreeBinaryClassificationTrainer.Arguments()))
                };
                var ova = new Ova(env, ovaArgs);

                // Cache the training data only when the trainer requests it.
                IDataView trainingView;
                if (ova.Info.WantCaching)
                {
                    trainingView = new CacheDataView(env, featurized, prefetch: null);
                }
                else
                {
                    trainingView = featurized;
                }
                var trainRoles = new RoleMappedData(trainingView, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, ova);
                var predictor = ova.Train(new Runtime.TrainContext(trainRoles));

                // Score the (uncached) featurized data and map the predicted key back to its value.
                var scoringRoles = new RoleMappedData(featurized, label: "Label", feature: "Features");
                IDataScorerTransform scored = ScoreUtils.GetScorer(predictor, scoringRoles, env, trainRoles.Schema);
                var predictedAsValue = new KeyToValueTransform(env, scored, "PredictedLabel");
                var engine           = env.CreatePredictionEngine<IrisData, IrisPrediction>(predictedAsValue);

                // Re-read the same file as typed rows and compare predictions on the first 20.
                var evalLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(irisPath));
                var evalRows   = evalLoader.AsEnumerable<IrisData>(env, false);
                foreach (var row in evalRows.Take(20))
                {
                    var predicted = engine.Predict(row);
                    Assert.True(predicted.PredictedLabel == row.Label);
                }
            }
        }
Exemple #6
0
 /// <summary>
 /// Sets <c>BasePredictorType</c> to a factory that creates a <c>FastTreeRegressionTrainer</c>
 /// on the default label and features column names.
 /// </summary>
 public Arguments()
 {
     BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
         host =>
         {
             return new FastTreeRegressionTrainer(host, DefaultColumnNames.Label, DefaultColumnNames.Features);
         });
 }
        /// <summary>
        /// Exercises <c>RandomGridSweeper</c> over a 3x3 grid ("foo" in {10, 15, 20}, "bar" in
        /// {100, 1000, 10000}):
        /// a first request for 5 sweeps must return 5 distinct grid points; a second request for 5, given
        /// the first 5 as previous runs, must return the 4 remaining points; a request for 10 with no
        /// history must return all 9 points exactly once.
        /// </summary>
        public void TestRandomGridSweeper()
        {
            var env  = new MLContext(42);
            var args = new RandomGridSweeper.Arguments()
            {
                SweptParameters = new[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new LongValueGenerator(new LongParamArguments()
                    {
                        Name = "foo", Min = 10, Max = 20, NumSteps = 3
                    })),
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new LongValueGenerator(new LongParamArguments()
                    {
                        Name = "bar", Min = 100, Max = 10000, LogBase = true, StepSize = 10
                    }))
                }
            };
            var sweeper     = new RandomGridSweeper(env, args);
            var initialList = sweeper.ProposeSweeps(5, new List<RunResult>());

            Assert.Equal(5, initialList.Length);
            var gridPoint = new bool[3][] {
                new bool[3],
                new bool[3],
                new bool[3]
            };
            MarkGridPoints(initialList, gridPoint);

            // The same grid is reused so that the second batch is checked against the first one.
            var nextList = sweeper.ProposeSweeps(5, initialList.Select(p => new RunResult(p)));

            Assert.Equal(4, nextList.Length);
            MarkGridPoints(nextList, gridPoint);

            // Fresh grid: with no history supplied the sweeper is expected to cover every point.
            gridPoint = new bool[3][] {
                new bool[3],
                new bool[3],
                new bool[3]
            };
            var lastList = sweeper.ProposeSweeps(10, null);

            Assert.Equal(9, lastList.Length);
            MarkGridPoints(lastList, gridPoint);
            Assert.True(gridPoint.All(bArray => bArray.All(b => b)));
        }

        /// <summary>
        /// Validates every parameter set in <paramref name="sweeps"/> against the 3x3 grid used by
        /// <c>TestRandomGridSweeper</c> and marks its cell in <paramref name="gridPoint"/>, failing if a
        /// value is off-grid, a parameter is unknown or missing, or the cell was already marked.
        /// </summary>
        /// <param name="sweeps">Parameter sets proposed by the sweeper.</param>
        /// <param name="gridPoint">3x3 visited map indexed as [foo index][bar index]; updated in place.</param>
        private static void MarkGridPoints(IEnumerable<ParameterSet> sweeps, bool[][] gridPoint)
        {
            foreach (var parameterSet in sweeps)
            {
                // Indices are tracked per set so a set lacking a parameter cannot reuse a stale index.
                int i = -1;
                int j = -1;

                foreach (var parameterValue in parameterSet)
                {
                    if (parameterValue.Name == "foo")
                    {
                        var val = long.Parse(parameterValue.ValueText);
                        Assert.True(val == 10 || val == 15 || val == 20);
                        i = (val == 10) ? 0 : (val == 15) ? 1 : 2;
                    }
                    else if (parameterValue.Name == "bar")
                    {
                        var val = long.Parse(parameterValue.ValueText);
                        Assert.True(val == 100 || val == 1000 || val == 10000);
                        j = (val == 100) ? 0 : (val == 1000) ? 1 : 2;
                    }
                    else
                    {
                        Assert.True(false, "Wrong parameter");
                    }
                }

                Assert.True(i >= 0 && j >= 0, "Missing parameter");
                Assert.False(gridPoint[i][j]);
                gridPoint[i][j] = true;
            }
        }
        /// <summary>
        /// Drives a <c>NelderMeadSweeper</c> (seeded by a <c>RandomGridSweeper</c> first batch) through nine
        /// rounds of propose/report. The first proposal must contain 3 sweeps; every proposed value must
        /// lie inside its parameter's range; the final proposal must contain at most 5 sweeps.
        /// </summary>
        public void TestNelderMeadSweeper()
        {
            var rng     = new Random(42);
            var context = new MLContext(42);

            var sweptParams = new IComponentFactory<INumericValueGenerator>[]
            {
                ComponentFactoryUtils.CreateFromFunction(
                    host => new FloatValueGenerator(new FloatParamArguments { Name = "foo", Min = 1, Max = 5 })),
                ComponentFactoryUtils.CreateFromFunction(
                    host => new LongValueGenerator(new LongParamArguments { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
            };

            var nelderMeadArgs = new NelderMeadSweeper.Arguments
            {
                SweptParameters   = sweptParams,
                FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction<IValueGenerator[], ISweeper>(
                    (host, firstBatchArgs) => new RandomGridSweeper(host, new RandomGridSweeper.Arguments
                    {
                        SweptParameters = sweptParams
                    }))
            };

            var nelderMead = new NelderMeadSweeper(context, nelderMeadArgs);
            var proposals  = nelderMead.ProposeSweeps(5, new List<RunResult>());

            Assert.Equal(3, proposals.Length);

            var history = new List<IRunResult>();

            // Nine rounds: validate the current proposals, report a random metric for each, ask again.
            for (int round = 0; round < 9; round++)
            {
                foreach (var proposal in proposals)
                {
                    foreach (var value in proposal)
                    {
                        switch (value.Name)
                        {
                            case "foo":
                                Assert.InRange(float.Parse(value.ValueText, CultureInfo.InvariantCulture), 1, 5);
                                break;

                            case "bar":
                                Assert.InRange(long.Parse(value.ValueText), 1, 1000);
                                break;

                            default:
                                Assert.True(false, "Wrong parameter");
                                break;
                        }
                    }

                    history.Add(new RunResult(proposal, rng.NextDouble(), true));
                }

                proposals = nelderMead.ProposeSweeps(5, history);
            }

            Assert.True(proposals.Length <= 5);
        }
        /// <summary>
        /// Consumes a <c>DeterministicSweeperAsync</c> (wrapping a <c>SmacSweeper</c>) from up to four
        /// parallel workers. Each worker waits for a proposal, sleeps for a pre-drawn interval to simulate
        /// work, reports a fixed metric, and records the parameter set. The collected sets are then passed
        /// to <c>CheckAsyncSweeperResult</c>.
        /// </summary>
        public void TestDeterministicSweeperAsyncParallel()
        {
            const int batchSize = 5;
            const int sweeps    = 20;

            var rng       = new Random(42);
            var context   = new MLContext(42);
            var collected = new List<ParameterSet>();

            var sweeperArgs = new DeterministicSweeperAsync.Arguments();
            sweeperArgs.BatchSize  = batchSize;
            sweeperArgs.Relaxation = batchSize - 2;
            sweeperArgs.Sweeper    = ComponentFactoryUtils.CreateFromFunction(
                host =>
                {
                    var smacArgs = new SmacSweeper.Arguments
                    {
                        SweptParameters = new IComponentFactory<INumericValueGenerator>[]
                        {
                            ComponentFactoryUtils.CreateFromFunction(
                                t => new FloatValueGenerator(new FloatParamArguments { Name = "foo", Min = 1, Max = 5 })),
                            ComponentFactoryUtils.CreateFromFunction(
                                t => new LongValueGenerator(new LongParamArguments { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
                        }
                    };
                    return new SmacSweeper(host, smacArgs);
                });

            var asyncSweeper = new DeterministicSweeperAsync(context, sweeperArgs);

            var gate            = new object();
            var parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = 4 };

            // Sleep randomly to simulate doing work. Durations are drawn up front, on one thread.
            var delays = new int[sweeps];
            for (int k = 0; k < delays.Length; k++)
            {
                delays[k] = rng.Next(10, 100);
            }

            Parallel.For(0, sweeps, parallelOptions, (int index) =>
            {
                var proposal = asyncSweeper.Propose();
                proposal.Wait();
                Assert.Equal(TaskStatus.RanToCompletion, proposal.Status);

                var proposed = proposal.Result;
                if (proposed == null)
                {
                    // No proposal for this worker; nothing to report.
                    return;
                }

                Thread.Sleep(delays[index]);
                sweeper_Update(asyncSweeper, proposed);

                lock (gate)
                {
                    collected.Add(proposed.ParameterSet);
                }
            });

            Assert.True(collected.Count <= sweeps);
            CheckAsyncSweeperResult(collected);

            // Reports the fixed metric 0.42 for a finished proposal.
            void sweeper_Update(DeterministicSweeperAsync target, ParameterSetWithId finished)
            {
                var outcome = new RunResult(finished.ParameterSet, 0.42, true);
                target.Update(finished.Id, outcome);
            }
        }
        /// <summary>
        /// Consumes a <c>DeterministicSweeperAsync</c> wrapping a <c>NelderMeadSweeper</c> sequentially:
        /// awaits up to 40 proposals, reports a pre-drawn random metric for each, and passes the collected
        /// parameter sets to <c>CheckAsyncSweeperResult</c>.
        /// </summary>
        /// <remarks>
        /// If the sweeper yields a null proposal the method returns immediately, so the final checks are
        /// not executed in that case.
        /// </remarks>
        public async Task TestNelderMeadSweeperAsync()
        {
            var       random    = new Random(42);
            var       env       = new MLContext(42);
            const int batchSize = 5;
            const int sweeps    = 40;
            var       paramSets = new List <ParameterSet>();
            var       args      = new DeterministicSweeperAsync.Arguments();

            args.BatchSize  = batchSize;
            args.Relaxation = 0;

            args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
                environ =>
            {
                var param = new IComponentFactory <INumericValueGenerator>[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        innerEnviron => new FloatValueGenerator(new FloatParamArguments()
                    {
                        Name = "foo", Min = 1, Max = 5
                    })),
                    ComponentFactoryUtils.CreateFromFunction(
                        innerEnviron => new LongValueGenerator(new LongParamArguments()
                    {
                        Name = "bar", Min = 1, Max = 1000, LogBase = true
                    }))
                };

                var nelderMeadSweeperArgs = new NelderMeadSweeper.Arguments()
                {
                    SweptParameters   = param,
                    // Use the environment passed to this factory instead of capturing the outer
                    // `environ`, matching how the synchronous Nelder-Mead test builds its first-batch sweeper.
                    FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction <IValueGenerator[], ISweeper>(
                        (firstBatchSweeperEnviron, firstBatchSweeperArgs) =>
                        new RandomGridSweeper(firstBatchSweeperEnviron, new RandomGridSweeper.Arguments()
                    {
                        SweptParameters = param
                    }))
                };

                return(new NelderMeadSweeper(environ, nelderMeadSweeperArgs));
            }
                );

            var sweeper = new DeterministicSweeperAsync(env, args);
            var mlock   = new object();

            // Metrics are drawn up front so each sweep index gets a fixed value from the seeded generator.
            double[] metrics = new double[sweeps];
            for (int i = 0; i < metrics.Length; i++)
            {
                metrics[i] = random.NextDouble();
            }

            for (int i = 0; i < sweeps; i++)
            {
                var paramWithId = await sweeper.Propose();

                if (paramWithId == null)
                {
                    return;
                }
                var result = new RunResult(paramWithId.ParameterSet, metrics[i], true);
                sweeper.Update(paramWithId.Id, result);
                lock (mlock)
                {
                    paramSets.Add(paramWithId.ParameterSet);
                }
            }
            Assert.True(paramSets.Count <= sweeps);
            CheckAsyncSweeperResult(paramSets);
        }
        /// <summary>
        /// Checks <c>SimpleAsyncSweeper</c> in two modes: first with an <c>Update</c> call after each of
        /// 100 proposals, then (built from <c>RandomGridSweeper.Arguments</c>) with 100 proposals and no
        /// <c>Update</c> calls at all. In both modes every <c>Propose</c> task must already be completed
        /// and the collected parameter sets are passed to <c>CheckAsyncSweeperResult</c>.
        /// </summary>
        public void TestSimpleSweeperAsync()
        {
            const int sweeps = 100;

            var rng     = new Random(42);
            var context = new MLContext(42);

            var baseArgs = new SweeperBase.ArgumentsBase
            {
                SweptParameters = new IComponentFactory<IValueGenerator>[]
                {
                    ComponentFactoryUtils.CreateFromFunction(
                        host => new FloatValueGenerator(new FloatParamArguments { Name = "foo", Min = 1, Max = 5 })),
                    ComponentFactoryUtils.CreateFromFunction(
                        host => new LongValueGenerator(new LongParamArguments { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
                }
            };
            var updatingSweeper = new SimpleAsyncSweeper(context, baseArgs);

            var collected = new List<ParameterSet>();

            for (int n = 0; n < sweeps; n++)
            {
                var proposal = updatingSweeper.Propose();
                Assert.True(proposal.IsCompleted);

                var proposed = proposal.Result;
                collected.Add(proposed.ParameterSet);

                var outcome = new RunResult(proposed.ParameterSet, rng.NextDouble(), true);
                updatingSweeper.Update(proposed.Id, outcome);
            }
            Assert.Equal(sweeps, collected.Count);
            CheckAsyncSweeperResult(collected);

            // Test consumption without ever calling Update.
            var gridArgs = new RandomGridSweeper.Arguments
            {
                SweptParameters = new IComponentFactory<INumericValueGenerator>[]
                {
                    ComponentFactoryUtils.CreateFromFunction(
                        host => new FloatValueGenerator(new FloatParamArguments { Name = "foo", Min = 1, Max = 5 })),
                    ComponentFactoryUtils.CreateFromFunction(
                        host => new LongValueGenerator(new LongParamArguments { Name = "bar", Min = 1, Max = 100, LogBase = true }))
                }
            };
            var silentSweeper = new SimpleAsyncSweeper(context, gridArgs);

            collected.Clear();
            for (int n = 0; n < sweeps; n++)
            {
                var proposal = silentSweeper.Propose();
                Assert.True(proposal.IsCompleted);
                collected.Add(proposal.Result.ParameterSet);
            }
            Assert.Equal(sweeps, collected.Count);
            CheckAsyncSweeperResult(collected);
        }
        /// <summary>
        /// Checks <c>DeterministicSweeperAsync</c> (wrapping a <c>SmacSweeper</c>) in two scenarios:
        /// single-threaded propose/update for 10 sweeps, where every proposal must complete immediately;
        /// and a two-batch scenario where the second batch's proposals must stay pending until results
        /// for the first batch are posted, after which all tasks must complete.
        /// </summary>
        public void TestDeterministicSweeperAsync()
        {
            var random = new Random(42);
            var env    = new MLContext(42);
            var args   = new DeterministicSweeperAsync.Arguments();

            args.BatchSize  = 5;
            args.Relaxation = args.BatchSize - 1;

            args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
                environ => new SmacSweeper(environ,
                                           new SmacSweeper.Arguments()
            {
                SweptParameters = new IComponentFactory <INumericValueGenerator>[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        t => new FloatValueGenerator(new FloatParamArguments()
                    {
                        Name = "foo", Min = 1, Max = 5
                    })),
                    ComponentFactoryUtils.CreateFromFunction(
                        t => new LongValueGenerator(new LongParamArguments()
                    {
                        Name = "bar", Min = 1, Max = 1000, LogBase = true
                    }))
                }
            }));

            var sweeper = new DeterministicSweeperAsync(env, args);

            // Test single-threaded consumption.
            int sweeps    = 10;
            var paramSets = new List <ParameterSet>();

            for (int i = 0; i < sweeps; i++)
            {
                var task = sweeper.Propose();
                Assert.True(task.IsCompleted);
                paramSets.Add(task.Result.ParameterSet);
                var result = new RunResult(task.Result.ParameterSet, random.NextDouble(), true);
                sweeper.Update(task.Result.Id, result);
            }
            Assert.Equal(sweeps, paramSets.Count);
            CheckAsyncSweeperResult(paramSets);

            // Create two batches and test if the 2nd batch is executed after the synchronization barrier is reached.
            object mlock = new object();
            var    tasks = new Task <ParameterSetWithId> [sweeps];

            args.Relaxation = args.Relaxation - 1;
            sweeper         = new DeterministicSweeperAsync(env, args);
            paramSets.Clear();
            var results = new List <KeyValuePair <int, IRunResult> >();

            for (int i = 0; i < args.BatchSize; i++)
            {
                var task = sweeper.Propose();
                Assert.True(task.IsCompleted);
                tasks[i] = task;
                if (task.Result == null)
                {
                    continue;
                }
                results.Add(new KeyValuePair <int, IRunResult>(task.Result.Id, new RunResult(task.Result.ParameterSet, 0.42, true)));
            }
            // Register consumers for the 2nd batch. Those consumers will await until at least one run
            // in the previous batch has been posted to the sweeper.
            for (int i = args.BatchSize; i < 2 * args.BatchSize; i++)
            {
                var task = sweeper.Propose();
                Assert.False(task.IsCompleted);
                tasks[i] = task;
            }
            // Call update to unblock the 2nd batch.
            foreach (var run in results)
            {
                sweeper.Update(run.Key, run.Value);
            }

            Task.WaitAll(tasks);
            // The result of All(...) was previously discarded, so this check asserted nothing.
            Assert.True(tasks.All(t => t.IsCompleted));
        }
Exemple #13
0
 /// <summary>
 /// Sets <c>BasePredictorType</c> to a factory that creates a
 /// <c>FastTreeBinaryClassificationTrainer</c> from a freshly constructed arguments object.
 /// </summary>
 public Arguments()
 {
     BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
         host =>
         {
             var trainerArgs = new FastTreeBinaryClassificationTrainer.Arguments();
             return new FastTreeBinaryClassificationTrainer(host, trainerArgs);
         });
 }
        // This method is called if only a datafile is specified, without a loader/term and value columns.
        // It determines the type of the Value column and returns the appropriate TextLoader component factory.
        private static IComponentFactory <IMultiStreamSource, IDataLoader> GetLoaderFactory(string filename, bool keyValues, IHost host)
        {
            Contracts.AssertValue(host);

            // If the user specified non-key values, we define the value column to be numeric.
            if (!keyValues)
            {
                return(ComponentFactoryUtils.CreateFromFunction <IMultiStreamSource, IDataLoader>(
                           (env, files) => TextLoader.Create(
                               env,
                               new TextLoader.Arguments()
                {
                    Column = new[]
                    {
                        new TextLoader.Column("Term", DataKind.TX, 0),
                        new TextLoader.Column("Value", DataKind.Num, 1)
                    }
                },
                               files)));
            }

            // If the user specified key values, we scan the values to determine the range of the key type.
            ulong min = ulong.MaxValue;
            ulong max = ulong.MinValue;

            try
            {
                var  txtArgs = new TextLoader.Arguments();
                bool parsed  = CmdParser.ParseArguments(host, "col=Term:TX:0 col=Value:TX:1", txtArgs);
                host.Assert(parsed);
                var data = TextLoader.ReadFile(host, txtArgs, new MultiFileSource(filename));
                using (var cursor = data.GetRowCursor(c => true))
                {
                    var getTerm = cursor.GetGetter <ReadOnlyMemory <char> >(0);
                    var getVal  = cursor.GetGetter <ReadOnlyMemory <char> >(1);
                    ReadOnlyMemory <char> txt = default;

                    using (var ch = host.Start("Creating Text Lookup Loader"))
                    {
                        long countNonKeys = 0;
                        while (cursor.MoveNext())
                        {
                            getVal(ref txt);
                            ulong res;
                            // Try to parse the text as a key value between 1 and ulong.MaxValue. If this succeeds and res>0,
                            // we update max and min accordingly. If res==0 it means the value is missing, in which case we ignore it for
                            // computing max and min.
                            if (Runtime.Data.Conversion.Conversions.Instance.TryParseKey(in txt, 1, ulong.MaxValue, out res))
                            {
                                if (res < min && res != 0)
                                {
                                    min = res;
                                }
                                if (res > max)
                                {
                                    max = res;
                                }
                            }
                            // If parsing as key did not succeed, the value can still be 0, so we try parsing it as a ulong. If it succeeds,
                            // then the value is 0, and we update min accordingly.
                            else if (Runtime.Data.Conversion.Conversions.Instance.TryParse(in txt, out res))
                            {
                                ch.Assert(res == 0);
                                min = 0;
                            }
Exemple #15
0
        /// <summary>
        /// Builds and trains a two-stage pipeline on the file at <c>_dataPath</c>: categorical and
        /// min-max-normalized numeric columns are concatenated into "Features", a
        /// <c>KMeansPlusPlusTrainer</c> (K = 100) is trained and scored on them, its "Score" column is
        /// appended to "Features", and a <c>LogisticRegression</c> trainer is fit on the result.
        /// </summary>
        /// <returns>The predictor returned by the logistic regression trainer.</returns>
        public ParameterMixingCalibratedPredictor TrainKMeansAndLR()
        {
            using (var env = new ConsoleEnvironment(seed: 1))
            {
                // Pipeline
                // Column layout of the comma-separated input (with header): label in column 14,
                // text features in columns 1, 3, 5-9, 13 and float features in columns 0, 2, 4, 10-12.
                var categoricalRanges = new []
                {
                    new TextLoader.Range() { Min = 1, Max = 1 },
                    new TextLoader.Range() { Min = 3, Max = 3 },
                    new TextLoader.Range() { Min = 5, Max = 9 },
                    new TextLoader.Range() { Min = 13, Max = 13 }
                };
                var numericRanges = new []
                {
                    new TextLoader.Range() { Min = 0, Max = 0 },
                    new TextLoader.Range() { Min = 2, Max = 2 },
                    new TextLoader.Range() { Min = 4, Max = 4 },
                    new TextLoader.Range() { Min = 10, Max = 12 }
                };
                var loaderArgs = new TextLoader.Arguments()
                {
                    HasHeader = true,
                    Separator = ",",
                    Column    = new[]
                    {
                        new TextLoader.Column("Label", DataKind.R4, 14),
                        new TextLoader.Column("CatFeatures", DataKind.TX, categoricalRanges),
                        new TextLoader.Column("NumFeatures", DataKind.R4, numericRanges)
                    }
                };
                var loader = TextLoader.ReadFile(env, loaderArgs, new MultiFileSource(_dataPath));

                IDataView encoded    = new CategoricalEstimator(env, "CatFeatures").Fit(loader).Transform(loader);
                IDataView normalized = NormalizeTransform.CreateMinMaxNormalizer(env, encoded, "NumFeatures");
                IDataView featurized = new ConcatTransform(env, "Features", "NumFeatures", "CatFeatures").Transform(normalized);

                // Train k-means on "Features" and score the data with it.
                var kMeansStage = new TrainAndScoreTransform.Arguments
                {
                    Trainer = ComponentFactoryUtils.CreateFromFunction(
                        host => new KMeansPlusPlusTrainer(host, "Features", advancedSettings: s =>
                        {
                            s.K = 100;
                        })),
                    FeatureColumn = "Features"
                };
                IDataView clustered = TrainAndScoreTransform.Create(env, kMeansStage, featurized);

                // Fold the k-means "Score" column back into "Features".
                IDataView augmented = new ConcatTransform(env, "Features", "Features", "Score").Transform(clustered);

                // Train
                var logistic = new LogisticRegression(env, "Features", "Label", advancedSettings: args =>
                {
                    args.EnforceNonNegativity = true;
                    args.OptTol = 1e-3f;
                });
                var roles = new RoleMappedData(augmented, label: "Label", feature: "Features");
                return logistic.Train(roles);
            }
        }
        /// <summary>
        /// Proposes 20 sweeps from a <c>DeterministicSweeperAsync</c> (batch size 5, relaxation 1) backed by
        /// a <c>KdoSweeper</c>, cancels the sweeper, and checks that the proposals that completed
        /// immediately plus the pending ones that ended with a non-null result total twice the batch size.
        /// </summary>
        public void TestDeterministicSweeperAsyncCancellation()
        {
            var rng = new Random(42);

            using (var env = new ConsoleEnvironment(42))
            {
                var sweeperArgs = new DeterministicSweeperAsync.Arguments
                {
                    BatchSize = 5,
                    Relaxation = 1,
                    Sweeper = ComponentFactoryUtils.CreateFromFunction(
                        e => new KdoSweeper(e, new KdoSweeper.Arguments()
                        {
                            SweptParameters = new IComponentFactory<INumericValueGenerator>[]
                            {
                                ComponentFactoryUtils.CreateFromFunction(
                                    unused => new FloatValueGenerator(new FloatParamArguments()
                                    {
                                        Name = "foo", Min = 1, Max = 5
                                    })),
                                ComponentFactoryUtils.CreateFromFunction(
                                    unused => new LongValueGenerator(new LongParamArguments()
                                    {
                                        Name = "bar", Min = 1, Max = 1000, LogBase = true
                                    }))
                            }
                        }))
                };

                var sweeper = new DeterministicSweeperAsync(env, sweeperArgs);

                const int sweepCount = 20;
                var pending = new List<Task<ParameterSetWithId>>();
                int completed = 0;
                for (int i = 0; i < sweepCount; i++)
                {
                    var proposal = sweeper.Propose();
                    if (i >= sweeperArgs.BatchSize - sweeperArgs.Relaxation)
                    {
                        // Past the first (BatchSize - Relaxation) proposals: keep the task for later.
                        pending.Add(proposal);
                        continue;
                    }

                    // The first proposals must already be complete; report a random metric for each.
                    Assert.True(proposal.IsCompleted);
                    var proposed = proposal.Result;
                    sweeper.Update(proposed.Id, new RunResult(proposed.ParameterSet, rng.NextDouble(), true));
                    completed++;
                }

                // Cancel after the first barrier and check if the number of registered actions
                // is indeed 2 * batchSize.
                sweeper.Cancel();
                var pendingArray = pending.ToArray();
                Task.WaitAll(pendingArray);
                for (int j = 0; j < pendingArray.Length; j++)
                {
                    if (pendingArray[j].Result != null)
                    {
                        completed++;
                    }
                }
                Assert.Equal(2 * sweeperArgs.BatchSize, completed);
            }
        }
        // This method is called if only a datafile is specified, without a loader/term and value columns.
        // It determines the type of the Value column and returns the appropriate TextLoader component factory.
        //
        // filename:  path of the lookup file; column 0 is read as the term, column 1 as the value.
        // keyValues: when false, the value column is simply declared numeric and the file is not scanned;
        //            when true, the file is scanned once to find the smallest and largest key value.
        // host:      used for argument parsing, channel logging and for wrapping exceptions.
        //
        // Any exception raised while scanning the file is rethrown through host.Except with the file name
        // in the message.
        private static IComponentFactory <IMultiStreamSource, IDataLoader> GetLoaderFactory(string filename, bool keyValues, IHost host)
        {
            Contracts.AssertValue(host);

            // If the user specified non-key values, we define the value column to be numeric.
            if (!keyValues)
            {
                return(ComponentFactoryUtils.CreateFromFunction <IMultiStreamSource, IDataLoader>(
                           (env, files) => TextLoader.Create(
                               env,
                               new TextLoader.Arguments()
                {
                    Column = new[]
                    {
                        new TextLoader.Column("Term", DataKind.TX, 0),
                        new TextLoader.Column("Value", DataKind.Num, 1)
                    }
                },
                               files)));
            }

            // If the user specified key values, we scan the values to determine the range of the key type.
            // The sentinels start inverted (min at the largest ulong, max at the smallest) so that
            // "min > max" after the scan means no usable key value was seen.
            ulong min = ulong.MaxValue;
            ulong max = ulong.MinValue;

            try
            {
                // Read both columns as text so that the values can be parsed by hand below.
                var  txtArgs = new TextLoader.Arguments();
                bool parsed  = CmdParser.ParseArguments(host, "col=Term:TX:0 col=Value:TX:1", txtArgs);
                host.Assert(parsed);
                var data = TextLoader.ReadFile(host, txtArgs, new MultiFileSource(filename));
                using (var cursor = data.GetRowCursor(c => true))
                {
                    var    getTerm = cursor.GetGetter <DvText>(0);
                    var    getVal  = cursor.GetGetter <DvText>(1);
                    DvText txt     = default(DvText);

                    using (var ch = host.Start("Creating Text Lookup Loader"))
                    {
                        // Number of rows whose value could not be interpreted as a key or as zero.
                        long countNonKeys = 0;
                        while (cursor.MoveNext())
                        {
                            getVal(ref txt);
                            ulong res;
                            // Try to parse the text as a key value between 1 and ulong.MaxValue. If this succeeds and res>0,
                            // we update max and min accordingly. If res==0 it means the value is missing, in which case we ignore it for
                            // computing max and min.
                            if (Conversions.Instance.TryParseKey(ref txt, 1, ulong.MaxValue, out res))
                            {
                                if (res < min && res != 0)
                                {
                                    min = res;
                                }
                                if (res > max)
                                {
                                    max = res;
                                }
                            }
                            // If parsing as key did not succeed, the value can still be 0, so we try parsing it as a ulong. If it succeeds,
                            // then the value is 0, and we update min accordingly.
                            else if (Conversions.Instance.TryParse(ref txt, out res))
                            {
                                ch.Assert(res == 0);
                                min = 0;
                            }
                            // If parsing as a ulong fails, we increment the counter for the non-key values.
                            else
                            {
                                var term = default(DvText);
                                getTerm(ref term);
                                // Only the first five offending rows are reported individually; the total is logged after the loop.
                                if (countNonKeys < 5)
                                {
                                    ch.Warning("Term '{0}' in mapping file is mapped to non key value '{1}'", term, txt);
                                }
                                countNonKeys++;
                            }
                        }
                        if (countNonKeys > 0)
                        {
                            ch.Warning("Found {0} non key values in the file '{1}'", countNonKeys, filename);
                        }
                        // Sentinels still inverted: nothing usable was found, so fall back to the range
                        // 0 .. uint.MaxValue - 1.
                        if (min > max)
                        {
                            min = 0;
                            max = uint.MaxValue - 1;
                            ch.Warning("did not find any valid key values in the file '{0}'", filename);
                        }
                        else
                        {
                            ch.Info("Found key values in the range {0} to {1} in the file '{2}'", min, max, filename);
                        }
                        ch.Done();
                    }
                }
            }
            catch (Exception e)
            {
                throw host.Except(e, "Failed to parse the lookup file '{0}' in TermLookupTransform", filename);
            }

            // Pick the value column's key type from the width of the observed range:
            //   - narrower than int.MaxValue:  U4 with an explicit (min, max) key range;
            //   - narrower than uint.MaxValue: U4 with only the minimum given;
            //   - otherwise:                   U8 with only the minimum given.
            // NOTE(review): KeyRange(min) presumably leaves the upper bound open - confirm against KeyRange.
            TextLoader.Column valueColumn = new TextLoader.Column("Value", DataKind.U4, 1);
            if (max - min < (ulong)int.MaxValue)
            {
                valueColumn.KeyRange = new KeyRange(min, max);
            }
            else if (max - min < (ulong)uint.MaxValue)
            {
                valueColumn.KeyRange = new KeyRange(min);
            }
            else
            {
                valueColumn.Type     = DataKind.U8;
                valueColumn.KeyRange = new KeyRange(min);
            }

            // The returned factory captures valueColumn and builds the loader on demand.
            return(ComponentFactoryUtils.CreateFromFunction <IMultiStreamSource, IDataLoader>(
                       (env, files) => TextLoader.Create(
                           env,
                           new TextLoader.Arguments()
            {
                Column = new[]
                {
                    new TextLoader.Column("Term", DataKind.TX, 0),
                    valueColumn
                }
            },
                           files)));
        }
        /// <summary>
        /// Wraps each sweepable hyperparameter description in an <see cref="IComponentFactory"/> that
        /// creates the matching value generator (discrete, float or long) when invoked.
        /// </summary>
        /// <param name="hps">The sweepable parameter attributes to convert.</param>
        /// <returns>
        /// An array with one factory per input attribute, in the same order. A slot is left null when the
        /// attribute is none of the three handled kinds.
        /// </returns>
        internal static IComponentFactory <IValueGenerator>[] ConvertToComponentFactories(TlcModule.SweepableParamAttribute[] hps)
        {
            var factories = new IComponentFactory<IValueGenerator>[hps.Length];

            for (int index = 0; index < hps.Length; index++)
            {
                var attribute = hps[index];

                // The checks run in the same order as the original case labels: discrete, float, long.
                if (attribute is TlcModule.SweepableDiscreteParamAttribute discrete)
                {
                    // The option values are stringified only when the factory is invoked.
                    factories[index] = ComponentFactoryUtils.CreateFromFunction(env =>
                        new DiscreteValueGenerator(new DiscreteParamArguments()
                        {
                            Name = discrete.Name,
                            Values = discrete.Options.Select(o => o.ToString()).ToArray()
                        }));
                }
                else if (attribute is TlcModule.SweepableFloatParamAttribute floatParam)
                {
                    factories[index] = ComponentFactoryUtils.CreateFromFunction(env =>
                    {
                        var generatorArgs = new FloatParamArguments();
                        generatorArgs.Name = floatParam.Name;
                        generatorArgs.Min = floatParam.Min;
                        generatorArgs.Max = floatParam.Max;
                        generatorArgs.LogBase = floatParam.IsLogScale;

                        // NumSteps and StepSize are copied only when the attribute specifies them.
                        if (floatParam.NumSteps.HasValue)
                        {
                            generatorArgs.NumSteps = floatParam.NumSteps.Value;
                        }
                        if (floatParam.StepSize.HasValue)
                        {
                            generatorArgs.StepSize = floatParam.StepSize.Value;
                        }
                        return new FloatValueGenerator(generatorArgs);
                    });
                }
                else if (attribute is TlcModule.SweepableLongParamAttribute longParam)
                {
                    factories[index] = ComponentFactoryUtils.CreateFromFunction(env =>
                    {
                        var generatorArgs = new LongParamArguments();
                        generatorArgs.Name = longParam.Name;
                        generatorArgs.Min = longParam.Min;
                        generatorArgs.Max = longParam.Max;
                        generatorArgs.LogBase = longParam.IsLogScale;

                        // NumSteps and StepSize are copied only when the attribute specifies them.
                        if (longParam.NumSteps.HasValue)
                        {
                            generatorArgs.NumSteps = longParam.NumSteps.Value;
                        }
                        if (longParam.StepSize.HasValue)
                        {
                            generatorArgs.StepSize = longParam.StepSize.Value;
                        }
                        return new LongValueGenerator(generatorArgs);
                    });
                }
            }

            return factories;
        }
        // Reads the comma-separated file at s_dataPath, featurizes it (categorical transform on
        // "CatFeatures", min-max normalization of "NumFeatures", K-means "Score" appended to "Features"),
        // then trains logistic regression on the "Features"/"Label" columns and returns the predictor.
        private static IPredictor TrainKMeansAndLRCore()
        {
            string inputPath = s_dataPath;

            using (var env = new TlcEnvironment(seed: 1))
            {
                // Source column indices that feed the text-typed "CatFeatures" vector.
                var catRanges = new[]
                {
                    new TextLoader.Range() { Min = 1, Max = 1 },
                    new TextLoader.Range() { Min = 3, Max = 3 },
                    new TextLoader.Range() { Min = 5, Max = 9 },
                    new TextLoader.Range() { Min = 13, Max = 13 }
                };

                // Source column indices that feed the R4-typed "NumFeatures" vector.
                var numRanges = new[]
                {
                    new TextLoader.Range() { Min = 0, Max = 0 },
                    new TextLoader.Range() { Min = 2, Max = 2 },
                    new TextLoader.Range() { Min = 4, Max = 4 },
                    new TextLoader.Range() { Min = 10, Max = 12 }
                };

                // The label is read from source column 14; the file has a header row.
                var loaderArgs = new TextLoader.Arguments()
                {
                    HasHeader = true,
                    Separator = ",",
                    Column = new[]
                    {
                        new TextLoader.Column("Label", DataKind.R4, 14),
                        new TextLoader.Column("CatFeatures", DataKind.TX, catRanges),
                        new TextLoader.Column("NumFeatures", DataKind.R4, numRanges)
                    }
                };
                var loader = TextLoader.ReadFile(env, loaderArgs, new MultiFileSource(inputPath));

                // Pipeline: each stage is typed as IDataTransform, exactly like the single variable it replaces.
                var categoricalArgs = new CategoricalTransform.Arguments
                {
                    Column = new[]
                    {
                        new CategoricalTransform.Column { Name = "CatFeatures", Source = "CatFeatures" }
                    }
                };
                IDataTransform encoded = CategoricalTransform.Create(env, categoricalArgs, loader);
                IDataTransform normalized = NormalizeTransform.CreateMinMaxNormalizer(env, encoded, "NumFeatures");
                IDataTransform features = new ConcatTransform(env, normalized, "Features", "NumFeatures", "CatFeatures");

                // Train K-means (K = 100) on "Features" and score the data with it.
                var kMeansArgs = new TrainAndScoreTransform.Arguments
                {
                    Trainer = ComponentFactoryUtils.CreateFromFunction(h =>
                        new KMeansPlusPlusTrainer(h, new KMeansPlusPlusTrainer.Arguments() { K = 100 })),
                    FeatureColumn = "Features"
                };
                IDataTransform clustered = TrainAndScoreTransform.Create(env, kMeansArgs, features);

                // Fold the K-means "Score" column back into "Features".
                IDataTransform augmented = new ConcatTransform(env, clustered, "Features", "Features", "Score");

                // Train
                var lrArgs = new LogisticRegression.Arguments()
                {
                    EnforceNonNegativity = true,
                    OptTol = 1e-3f
                };
                var lrTrainer = new LogisticRegression(env, lrArgs);
                var roles = new RoleMappedData(augmented, label: "Label", feature: "Features");
                return lrTrainer.Train(roles);
            }
        }
        // Factory method for SignatureDataTransform.
        // Builds the transform either from a previously trained model file (args.TrainedModelFile) or,
        // when no file is given, by training args.Trainer on the input. If both are supplied the model
        // file is used and a warning is logged. At least one of the two, and args.FeatureColumn, are required.
        private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("Tree Featurizer Transform");

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));

            bool useModelFile = !string.IsNullOrWhiteSpace(args.TrainedModelFile);
            host.CheckUserArg(useModelFile || args.Trainer != null, nameof(args.TrainedModelFile),
                              "Please specify either a trainer or an input model file.");
            host.CheckUserArg(!string.IsNullOrEmpty(args.FeatureColumn), nameof(args.FeatureColumn), "Transform needs an input features column");

            using (var ch = host.Start("Create Tree Ensemble Scorer"))
            {
                var scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments();
                scorerArgs.Suffix = args.Suffix;

                if (!useModelFile)
                {
                    // No model file: train the supplied trainer and score with the featurizer mapper.
                    ch.AssertValue(args.Trainer);

                    ch.Trace("Creating TrainAndScoreTransform");

                    var trainScoreArgs = new TrainAndScoreTransformer.Arguments();
                    args.CopyTo(trainScoreArgs);
                    trainScoreArgs.Trainer = args.Trainer;

                    trainScoreArgs.Scorer = ComponentFactoryUtils.CreateFromFunction<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>(
                        (e, view, mapper, trainSchema) => Create(e, scorerArgs, view, mapper, trainSchema));

                    var mapperFactory = ComponentFactoryUtils.CreateFromFunction<IPredictor, ISchemaBindableMapper>(
                        (e, pred) => new TreeEnsembleFeaturizerBindableMapper(e, scorerArgs, pred));

                    var labelInput = AppendLabelTransform(host, ch, input, trainScoreArgs.LabelColumn, args.LabelPermutationSeed);
                    var scoreXf = TrainAndScoreTransformer.Create(host, trainScoreArgs, labelInput, mapperFactory);

                    // When AppendLabelTransform returned the input itself, the score transform is the result;
                    // otherwise the transforms are re-applied on top of the original input.
                    if (input == labelInput)
                    {
                        return scoreXf;
                    }
                    return (IDataTransform)ApplyTransformUtils.ApplyAllTransformsToData(host, scoreXf, input, labelInput);
                }

                if (args.Trainer != null)
                {
                    ch.Warning("Both an input model and a trainer were specified. Using the model file.");
                }

                ch.Trace("Loading model");
                IPredictor loaded;
                using (Stream strm = new FileStream(args.TrainedModelFile, FileMode.Open, FileAccess.Read))
                {
                    using (var rep = RepositoryReader.Open(strm, ch))
                    {
                        ModelLoadContext.LoadModel<IPredictor, SignatureLoadModel>(host, out loaded, rep, ModelFileUtils.DirPredictor);
                    }
                }

                ch.Trace("Creating scorer");
                var roleMapped = TrainAndScoreTransformer.CreateDataFromArgs(ch, input, args);
                Contracts.Assert(roleMapped.Schema.Feature.HasValue);

                // Make sure that the given predictor has the correct number of input features.
                // A calibrated predictor is unwrapped to its sub-predictor first.
                if (loaded is CalibratedPredictorBase calibrated)
                {
                    loaded = calibrated.SubPredictor;
                }

                // Predictor should be a TreeEnsembleModelParameters, which implements IValueMapper, so this should
                // be non-null.
                var valueMapper = loaded as IValueMapper;
                ch.CheckUserArg(valueMapper != null, nameof(args.TrainedModelFile), "Predictor in model file does not have compatible type");

                var expectedSize = valueMapper.InputType.VectorSize;
                var actualSize = roleMapped.Schema.Feature.Value.Type.VectorSize;
                if (expectedSize != actualSize)
                {
                    throw ch.ExceptUserArg(nameof(args.TrainedModelFile),
                                           "Predictor in model file expects {0} features, but data has {1} features",
                                           expectedSize, actualSize);
                }

                ISchemaBindableMapper bindable = new TreeEnsembleFeaturizerBindableMapper(env, scorerArgs, loaded);
                var bound = bindable.Bind(env, roleMapped.Schema);
                return new GenericScorer(env, scorerArgs, input, bound, roleMapped.Schema);
            }
        }
Exemple #21
0
        /// <summary>
        /// Runs the cross-validation command: builds the data pipeline, obtains the per-fold tasks from
        /// a <c>FoldHelper</c>, prints per-fold and overall metrics, sends the metrics to telemetry and,
        /// when <c>Args.OutputDataFile</c> is set, saves the per-instance results.
        /// </summary>
        /// <param name="ch">Channel used for tracing, warnings and for raising errors.</param>
        /// <param name="cmd">Passed through unchanged to the <c>FoldHelper</c> constructor.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);

            IPredictor inputPredictor = null;

            // Only attempt to load an initial predictor when continued training was requested; a failed
            // load is reported as a warning and inputPredictor is whatever TryLoadPredictor left in it.
            if (Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor))
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing data pipeline");
            IDataLoader loader = CreateRawLoader();

            // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
            var preXf = Args.PreTransform;

            if (!string.IsNullOrEmpty(Args.OutputDataFile))
            {
                string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
                if (name == null)
                {
                    // Append one more pre-transform (with an empty tag) that generates the default name
                    // column using a counter.
                    // NOTE(review): this dereferences preXf, so it assumes Args.PreTransform is non-null - confirm.
                    preXf = preXf.Concat(
                        new[]
                    {
                        new KeyValuePair <string, IComponentFactory <IDataView, IDataTransform> >(
                            "", ComponentFactoryUtils.CreateFromFunction <IDataView, IDataTransform>(
                                (env, input) =>
                        {
                            var args    = new GenerateNumberTransform.Arguments();
                            args.Column = new[] { new GenerateNumberTransform.Column()
                                                  {
                                                      Name = DefaultColumnNames.Name
                                                  }, };
                            args.UseCounter = true;
                            return(new GenerateNumberTransform(env, args, input));
                        }))
                    }).ToArray();
                }
            }
            loader = CompositeDataLoader.Create(Host, loader, preXf);

            ch.Trace("Binding label and features columns");

            // GetSplitColumn receives pipe by reference, so it may replace the view.
            IDataView pipe = loader;
            var       stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
            var       scorer    = Args.Scorer;
            var       evaluator = Args.Evaluator;

            // Left null unless a validation file was given.
            Func <IDataView> validDataCreator = null;

            if (Args.ValidationFile != null)
            {
                validDataCreator =
                    () =>
                {
                    // Fork the command.
                    var impl = new CrossValidationCommand(this);
                    return(impl.CreateRawLoader(dataFile: Args.ValidationFile));
                };
            }

            // The last argument tells the fold helper whether per-instance results were requested
            // (same IsNullOrEmpty test as above).
            FoldHelper fold = new FoldHelper(Host, RegistrationName, pipe, stratificationColumn,
                                             Args, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
                                             validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, !string.IsNullOrEmpty(Args.OutputDataFile));
            var tasks = fold.GetCrossValidationTasks();

            // Use the configured evaluator if there is one; otherwise pick one from the first fold's score schema.
            // NOTE(review): tasks[i].Result presumably blocks until that fold's task has finished - confirm the task type.
            var eval = evaluator?.CreateComponent(Host) ??
                       EvaluateUtils.GetEvaluator(Host, tasks[0].Result.ScoreSchema);

            // Print confusion matrix and fold results for each fold.
            for (int i = 0; i < tasks.Length; i++)
            {
                var dict = tasks[i].Result.Metrics;
                MetricWriter.PrintWarnings(ch, dict);
                eval.PrintFoldResults(ch, dict);
            }

            // Print the overall results.
            // NOTE(review): the same per-fold metrics projection is materialized three times below
            // (here, for PrintAdditionalMetrics and for telemetry); it could be shared if none of the
            // callees modifies the array it receives - verify before hoisting.
            if (!TryGetOverallMetrics(tasks.Select(t => t.Result.Metrics).ToArray(), out var overallList))
            {
                throw ch.Except("No overall metrics found");
            }

            var overall = eval.GetOverallResults(overallList.ToArray());

            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds);
            eval.PrintAdditionalMetrics(ch, tasks.Select(t => t.Result.Metrics).ToArray());
            Dictionary <string, IDataView>[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray();
            SendTelemetryMetric(metricValues);

            // Save the per-instance results.
            // NOTE(review): this test uses IsNullOrWhiteSpace while the two earlier tests on
            // Args.OutputDataFile use IsNullOrEmpty, so a whitespace-only value is treated differently
            // here - confirm whether that is intended.
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(Host, eval, Args.CollateMetrics,
                                                                                Args.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
                if (variableSizeVectorColumnNames.Length > 0)
                {
                    ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.",
                               string.Join(", ", variableSizeVectorColumnNames));
                }
                if (Args.CollateMetrics)
                {
                    // Collated output is expected to be a single data view, written to the requested file.
                    ch.Assert(perInstance.Length == 1);
                    MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]);
                }
                else
                {
                    // One output per data view, each saved under a name derived from its index.
                    int i = 0;
                    foreach (var idv in perInstance)
                    {
                        MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, i), idv);
                        i++;
                    }
                }
            }
        }