Пример #1
0
        // Fits a CategoricalEstimator over the integer columns A, B and C, saves the fitted
        // transform chain into an in-memory model stream, rewinds the stream, and loads the
        // transforms back on top of the original data view. The test passes if save + load
        // complete without throwing; the loaded view is not inspected further.
        public void TestOldSavingAndLoading()
        {
            var data = new[]
            {
                new TestClass() { A = 1, B = 2, C = 3, },
                new TestClass() { A = 4, B = 5, C = 6 }
            };
            var dataView = ComponentCreation.CreateDataView(Env, data);
            var pipe = new CategoricalEstimator(Env, new[] {
                new CategoricalEstimator.ColumnInfo("A", "TermA"),
                new CategoricalEstimator.ColumnInfo("B", "TermB"),
                new CategoricalEstimator.ColumnInfo("C", "TermC")
            });
            var result = pipe.Fit(dataView).Transform(dataView);
            var resultRoles = new RoleMappedData(result);

            // The channel handed to SaveModel is disposable; bind it so it is released
            // when saving/loading finishes instead of leaking.
            using (var ms = new MemoryStream())
            using (var ch = Env.Start("saving"))
            {
                TrainUtils.SaveModel(Env, ch, ms, null, resultRoles);

                // Rewind so the loader reads the model that was just written.
                ms.Position = 0;
                ModelFileUtils.LoadTransforms(Env, dataView, ms);
            }
        }
Пример #2
0
        // Trains a churn classifier on in-memory data in two equivalent ways: with a dynamically
        // typed estimator pipeline, and with the statically typed pipeline built from the same
        // steps. It then evaluates the dynamic model on the training data.
        public void TrainOnAutoGeneratedData()
        {
            // Create a new environment for ML.NET operations. It can be used for exception tracking and logging,
            // as well as the source of randomness.
            var env = new LocalEnvironment();

            // Step one: read the data as an IDataView.
            // Let's assume that 'GetChurnInfo()' fetches and returns the training data from somewhere.
            IEnumerable<CustomerChurnInfo> churnData = GetChurnInfo();

            // Turn the data into the ML.NET data view.
            // We can use CreateDataView or CreateStreamingDataView, depending on whether 'churnData' is an IList,
            // or merely an IEnumerable.
            var trainData = env.CreateStreamingDataView(churnData);

            // Now note that 'trainData' is just an IDataView, so we face a choice here: either declare the static type
            // and proceed in the statically typed fashion, or keep dynamic types and build a dynamic pipeline.
            // We demonstrate both below.

            // We know that this is a binary classification task, so we create a binary classification context: it will give us the algorithms
            // we need, as well as the evaluation procedure.
            var classification = new BinaryClassificationContext(env);

            // Build the learning pipeline.
            // In our case, we will one-hot encode the demographic category, and concatenate that with the number of visits.
            // We apply our FastTree binary classifier to predict the 'HasChurned' label.

            var dynamicLearningPipeline = new CategoricalEstimator(env, "DemographicCategory")
                                          .Append(new ConcatEstimator(env, "Features", "DemographicCategory", "LastVisits"))
                                          .Append(new FastTreeBinaryClassificationTrainer(env, "HasChurned", "Features", numTrees: 20));

            var dynamicModel = dynamicLearningPipeline.Fit(trainData);

            // Build the same learning pipeline, but statically typed.
            // First, transition to the statically-typed data view.
            var staticData = trainData.AssertStatic(env, c => (
                                                        HasChurned: c.Bool.Scalar,
                                                        DemographicCategory: c.Text.Scalar,
                                                        LastVisits: c.R4.Vector));

            // Build the pipeline, same as the one above.
            var staticLearningPipeline = staticData.MakeNewEstimator()
                                         .Append(r => (
                                                     r.HasChurned,
                                                     Features: r.DemographicCategory.OneHotEncoding().ConcatWith(r.LastVisits)))
                                         .Append(r => classification.Trainers.FastTree(r.HasChurned, r.Features, numTrees: 20));

            var staticModel = staticLearningPipeline.Fit(staticData);

            // Note that dynamicModel should be the same as staticModel.AsDynamic (give or take random variance from
            // the training procedure).

            // Evaluate the dynamic model on the (training) data, using 'HasChurned' as the label column.
            var qualityMetrics = classification.Evaluate(dynamicModel.Transform(trainData), "HasChurned");
        }
Пример #3
0
        // Runs the shared estimator checks on a CategoricalEstimator that encodes column "A"
        // once for each output kind (Bag, Bin, Ind, Key), then finishes the test.
        public void CategoricalWorkout()
        {
            TestClass[] rows =
            {
                new TestClass { A = 1, B = 2, C = 3 },
                new TestClass { A = 4, B = 5, C = 6 },
            };
            var source = ComponentCreation.CreateDataView(Env, rows);

            // One output column per encoding kind, all reading the same input column.
            CategoricalEstimator.ColumnInfo[] columns =
            {
                new CategoricalEstimator.ColumnInfo("A", "CatA", CategoricalTransform.OutputKind.Bag),
                new CategoricalEstimator.ColumnInfo("A", "CatB", CategoricalTransform.OutputKind.Bin),
                new CategoricalEstimator.ColumnInfo("A", "CatC", CategoricalTransform.OutputKind.Ind),
                new CategoricalEstimator.ColumnInfo("A", "CatD", CategoricalTransform.OutputKind.Key),
            };
            var estimator = new CategoricalEstimator(Env, columns);

            TestEstimatorCore(estimator, source);
            Done();
        }
Пример #4
0
        // Loads the CSV at _dataPath (header row, comma-separated) and builds the "Features" column:
        // the text columns are one-hot encoded, the numeric columns are min-max normalized, and the
        // scores of a K-means (K = 100) model trained on those features are appended.
        // It then trains a logistic regression on "Features" / "Label" and returns the trained predictor.
        public ParameterMixingCalibratedPredictor TrainKMeansAndLR()
        {
            using (var env = new ConsoleEnvironment(seed: 1))
            {
                // Shorthand for an inclusive source-column range [min, max].
                TextLoader.Range MakeRange(int min, int max) => new TextLoader.Range() { Min = min, Max = max };

                var catRanges = new[] { MakeRange(1, 1), MakeRange(3, 3), MakeRange(5, 9), MakeRange(13, 13) };
                var numRanges = new[] { MakeRange(0, 0), MakeRange(2, 2), MakeRange(4, 4), MakeRange(10, 12) };

                var loaderArgs = new TextLoader.Arguments()
                {
                    HasHeader = true,
                    Separator = ",",
                    Column = new[]
                    {
                        new TextLoader.Column("Label", DataKind.R4, 14),
                        new TextLoader.Column("CatFeatures", DataKind.TX, catRanges),
                        new TextLoader.Column("NumFeatures", DataKind.R4, numRanges)
                    }
                };
                var loader = TextLoader.ReadFile(env, loaderArgs, new MultiFileSource(_dataPath));

                // Featurize: one-hot text, normalize numbers, merge into "Features".
                IDataView featurized = new CategoricalEstimator(env, "CatFeatures").Fit(loader).Transform(loader);
                featurized = NormalizeTransform.CreateMinMaxNormalizer(env, featurized, "NumFeatures");
                featurized = new ConcatTransform(env, "Features", "NumFeatures", "CatFeatures").Transform(featurized);

                // Append K-means cluster scores to the feature vector.
                var kMeansArgs = new TrainAndScoreTransform.Arguments
                {
                    Trainer = ComponentFactoryUtils.CreateFromFunction(host =>
                        new KMeansPlusPlusTrainer(host, "Features", advancedSettings: settings => { settings.K = 100; })),
                    FeatureColumn = "Features"
                };
                featurized = TrainAndScoreTransform.Create(env, kMeansArgs, featurized);
                featurized = new ConcatTransform(env, "Features", "Features", "Score").Transform(featurized);

                // Train the final logistic regression model.
                var lr = new LogisticRegression(env, "Features", "Label", advancedSettings: args =>
                {
                    args.EnforceNonNegativity = true;
                    args.OptTol = 1e-3f;
                });
                var roles = new RoleMappedData(featurized, label: "Label", feature: "Features");
                return lr.Train(roles);
            }
        }
Пример #5
0
        // Builds three TestMeta rows with string, int and float scalar/vector columns. It one-hot
        // encodes them with the Bag, Ind, Key and Bin output kinds, then checks the metadata
        // attached to the transformed view via ValidateMetadata.
        public void TestMetadataPropagation()
        {
            TestMeta[] rows =
            {
                new TestMeta()
                {
                    A = new[] { "A", "B" }, B = "C",
                    C = new[] { 3, 5 }, D = 6,
                    E = new[] { 1.0f, 2.0f }, F = 1.0f,
                    G = new[] { "A", "D" }, H = "D"
                },
                new TestMeta()
                {
                    A = new[] { "A", "B" }, B = "C",
                    C = new[] { 5, 3 }, D = 1,
                    E = new[] { 3.0f, 4.0f }, F = -1.0f,
                    G = new[] { "E", "A" }, H = "E"
                },
                new TestMeta()
                {
                    A = new[] { "A", "B" }, B = "C",
                    C = new[] { 3, 5 }, D = 6,
                    E = new[] { 5.0f, 6.0f }, F = 1.0f,
                    G = new[] { "D", "E" }, H = "D"
                }
            };

            var source = ComponentCreation.CreateDataView(Env, rows);

            // Output columns, in the same order the estimator should produce them.
            var columns = new[]
            {
                new CategoricalEstimator.ColumnInfo("A", "CatA", CategoricalTransform.OutputKind.Bag),
                new CategoricalEstimator.ColumnInfo("B", "CatB", CategoricalTransform.OutputKind.Bag),
                new CategoricalEstimator.ColumnInfo("C", "CatC", CategoricalTransform.OutputKind.Bag),
                new CategoricalEstimator.ColumnInfo("D", "CatD", CategoricalTransform.OutputKind.Bag),
                new CategoricalEstimator.ColumnInfo("E", "CatE", CategoricalTransform.OutputKind.Ind),
                new CategoricalEstimator.ColumnInfo("F", "CatF", CategoricalTransform.OutputKind.Ind),
                new CategoricalEstimator.ColumnInfo("G", "CatG", CategoricalTransform.OutputKind.Key),
                new CategoricalEstimator.ColumnInfo("H", "CatH", CategoricalTransform.OutputKind.Key),
                new CategoricalEstimator.ColumnInfo("A", "CatI", CategoricalTransform.OutputKind.Bin),
                new CategoricalEstimator.ColumnInfo("B", "CatJ", CategoricalTransform.OutputKind.Bin),
                new CategoricalEstimator.ColumnInfo("C", "CatK", CategoricalTransform.OutputKind.Bin),
                new CategoricalEstimator.ColumnInfo("D", "CatL", CategoricalTransform.OutputKind.Bin)
            };
            var estimator = new CategoricalEstimator(Env, columns);

            var transformed = estimator.Fit(source).Transform(source);
            ValidateMetadata(transformed);
            Done();
        }