Exemple #1
0
        /// <summary>
        /// Trains an SDCA multiclass model on the Iris training file, obtains a scorer for it,
        /// re-applies that scorer through <c>ApplyTransformUtils.ApplyAllTransformsToData</c> with the
        /// loader and the term transform, and asserts that the first 20 rows are predicted as "Iris-setosa".
        /// </summary>
        void DecomposableTrainAndPredict()
        {
            // ScoreUtils.GetScorer requires scorers to be registered in the ComponentCatalog
            using (var env = new LocalEnvironment().AddStandardComponents())
            {
                var loaderArgs = MakeIrisTextLoaderArgs();
                var irisFile = new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename));
                var loader = TextLoader.ReadFile(env, loaderArgs, irisFile);

                var term = TermTransform.Create(env, loader, "Label");
                var concatenator = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");
                var concat = concatenator.Transform(term);

                var trainer = new SdcaMultiClassTrainer(env, "Features", "Label", advancedSettings: settings =>
                {
                    settings.MaxIterations = 100;
                    settings.Shuffle = true;
                    settings.NumThreads = 1;
                });

                // Wrap the training data in a cache only when the trainer asks for it.
                IDataView trainData = concat;
                if (trainer.Info.WantCaching)
                {
                    trainData = new CacheDataView(env, concat, prefetch: null);
                }
                var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                // Cut out term transform from pipeline.
                var newScorer = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term);
                var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(newScorer);
                var model = env.CreatePredictionEngine <IrisDataNoLabel, IrisPrediction>(keyToValue);

                // Only the first 20 rows of the loader are checked.
                var firstRows = loader.AsEnumerable <IrisDataNoLabel>(env, false).Take(20);
                foreach (var row in firstRows)
                {
                    var prediction = model.Predict(row);
                    Assert.True(prediction.PredictedLabel == "Iris-setosa");
                }
            }
        }
        public void ConcatWithAliases()
        {
            string dataPath = GetDataPath("adult.test");

            var source = new MultiFileSource(dataPath);
            var loader = new TextLoader(Env, new TextLoader.Arguments
            {
                Column = new[] {
                    new TextLoader.Column("float1", DataKind.R4, 0),
                    new TextLoader.Column("float4", DataKind.R4, new[] { new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10) }),
                    new TextLoader.Column("vfloat", DataKind.R4, new[] { new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10, null)
                                                                         {
                                                                             AutoEnd = false, VariableEnd = true
                                                                         } })
                },
                Separator = ",",
                HasHeader = true
            }, new MultiFileSource(dataPath));
            var data = loader.Read(source);

            ColumnType GetType(Schema schema, string name)
            {
                Assert.True(schema.TryGetColumnIndex(name, out int cIdx), $"Could not find '{name}'");
                return(schema.GetColumnType(cIdx));
            }

            data = TakeFilter.Create(Env, data, 10);

            var concater = new ConcatTransform(Env,
                                               new ConcatTransform.ColumnInfo("f2", new[] { ("float1", "FLOAT1"), ("float1", "FLOAT2") }),
        public ITransformer Fit(IDataView input)
        {
            var h = _host;

            h.CheckValue(input, nameof(input));

            var tparams = new TransformApplierParams(this);

            string[]      textCols       = _inputColumns;
            string[]      wordTokCols    = null;
            string[]      charTokCols    = null;
            string        wordFeatureCol = null;
            string        charFeatureCol = null;
            List <string> tempCols       = new List <string>();
            IDataView     view           = input;

            if (tparams.NeedInitialSourceColumnConcatTransform && textCols.Length > 1)
            {
                var srcCols = textCols;
                textCols = new[] { GenerateColumnName(input.Schema, OutputColumn, "InitialConcat") };
                tempCols.Add(textCols[0]);
                view = new ConcatTransform(h, textCols[0], srcCols).Transform(view);
            }

            if (tparams.NeedsNormalizeTransform)
            {
                var      xfCols  = new (string input, string output)[textCols.Length];
Exemple #4
0
        /// <summary>
        /// Creates a <c>ConcatTransform</c> over <paramref name="input"/> using the stored name and
        /// sources, passes it through <c>ApplyTransformUtils.ApplyAllTransformsToData</c> together with
        /// an <c>EmptyDataView</c> that shares the input schema, and wraps the result.
        /// </summary>
        /// <param name="input">The data view whose schema the transform is built against.</param>
        /// <returns>A <c>TransformWrapper</c> around the resulting view.</returns>
        public TransformWrapper Fit(IDataView input)
        {
            var concat = new ConcatTransform(_env, input, _name, _source);
            var detached = ApplyTransformUtils.ApplyAllTransformsToData(
                _env,
                concat,
                new EmptyDataView(_env, input.Schema),
                input);

            return new TransformWrapper(_env, detached);
        }
        /// <summary>
        /// Iris scenario assembled by instantiating each component directly: text loader, feature
        /// concatenation, min-max normalization, SDCA multiclass training, scoring, evaluation,
        /// prediction-engine checks, and a check on the first entry of the predictor summary.
        /// </summary>
        public void TrainAndPredictIrisModelUsingDirectInstantiationTest()
        {
            string dataPath = GetDataPath("iris.txt");
            // The same file serves as both training and test data.
            string testDataPath = dataPath;

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline
                var loaderArgs = new TextLoader.Arguments()
                {
                    HasHeader = false,
                    Column = new[]
                    {
                        new TextLoader.Column("Label", DataKind.R4, 0),
                        new TextLoader.Column("SepalLength", DataKind.R4, 1),
                        new TextLoader.Column("SepalWidth", DataKind.R4, 2),
                        new TextLoader.Column("PetalLength", DataKind.R4, 3),
                        new TextLoader.Column("PetalWidth", DataKind.R4, 4)
                    }
                };
                var loader = new TextLoader(env, loaderArgs, new MultiFileSource(dataPath));

                IDataTransform concat = new ConcatTransform(
                    env, loader, "Features",
                    "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

                // Normalizer is not automatically added though the trainer has 'NormalizeFeatures' On/Auto
                IDataTransform trans = NormalizeTransform.CreateMinMaxNormalizer(env, concat, "Features");

                // Train
                var trainerArgs = new SdcaMultiClassTrainer.Arguments() { NumThreads = 1 };
                var trainer = new SdcaMultiClassTrainer(env, trainerArgs);

                // Explicitly adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto
                var cached = new CacheDataView(env, trans, prefetch: null);
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var pred = trainer.Train(trainRoles);

                // Get scorer and evaluate the predictions from test data
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                CompareMatrics(Evaluate(env, testDataScorer));

                // Create prediction engine and test predictions
                var model = env.CreatePredictionEngine <IrisData, IrisPrediction>(testDataScorer);
                ComparePredictions(model);

                // Get feature importance i.e. weight vector
                var lrPredictor = (MulticlassLogisticRegressionPredictor)pred;
                var summary = lrPredictor.GetSummaryInKeyValuePairs(trainRoles.Schema);
                double firstValue = Convert.ToDouble(summary[0].Value);
                Assert.Equal(7.757864, firstValue, 5);
            }
        }
        /// <summary>
        /// Entry point that takes the feature-role columns of <c>input.Data</c>, runs them through
        /// <c>ConvertFeatures</c>, <c>ApplyConvert</c> and <c>ApplyKeyToVec</c>, and concatenates the
        /// results into one column named <c>DefaultColumnNames.Features</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Entry-point arguments supplying the data view and its role mappings.</param>
        /// <returns>
        /// A <c>TransformOutput</c> whose <c>OutputData</c> is the concatenated view and whose
        /// <c>Model</c> is a <c>TransformModel</c> built from that view and the original input data.
        /// </returns>
        /// <remarks>
        /// Throws via the channel when no feature columns are specified or when
        /// <c>ConvertFeatures</c> reports one or more invalid columns.
        /// </remarks>
        public static CommonOutputs.TransformOutput PrepareFeatures(IHostEnvironment env, FeatureCombinerInput input)
        {
            const string registrationName = "FeatureCombiner";

            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(registrationName);

            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
            using (var ch = host.Start(registrationName))
            {
                var data = input.Data;
                var roleSchema = new RoleMappedSchema(data.Schema, input.GetRoles());
                var featureCols = roleSchema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
                if (Utils.Size(featureCols) == 0)
                {
                    throw ch.Except("No feature columns specified");
                }

                var featNames = new HashSet <string>();
                var concatNames = new List <KeyValuePair <string, string> >();
                List <ConvertTransform.Column> conversions;
                int invalidCount;
                var keyToVecCols = ConvertFeatures(featureCols.ToArray(), featNames, concatNames, ch, out conversions, out invalidCount);
                Contracts.Assert(featNames.Count > 0);
                Contracts.Assert(concatNames.Count == featNames.Count);
                if (invalidCount > 0)
                {
                    throw ch.Except("Encountered {0} invalid training column(s)", invalidCount);
                }

                data = ApplyConvert(conversions, data, host);
                data = ApplyKeyToVec(keyToVecCols, data, host);

                // REVIEW: What about column name conflicts? Eg, what if someone uses the group id column
                // (a key type) as a feature column. We convert that column to a vector so it is no longer valid
                // as a group id. That's just one example - you get the idea.
                string featuresName = DefaultColumnNames.Features;
                var concatArgs = new ConcatTransform.TaggedArguments();
                concatArgs.Column = new[]
                {
                    new ConcatTransform.TaggedColumn() { Name = featuresName, Source = concatNames.ToArray() }
                };
                data = new ConcatTransform(host, concatArgs, data);
                ch.Done();

                var output = new CommonOutputs.TransformOutput();
                output.Model = new TransformModel(env, data, input.Data);
                output.OutputData = data;
                return output;
            }
        }
Exemple #7
0
            /// <summary>
            /// Creates a mapper for <paramref name="parent"/> over <paramref name="inputSchema"/>,
            /// registering a child host and building one bound column per parent column via
            /// <c>MakeColumn</c>.
            /// </summary>
            /// <param name="parent">The owning transform; asserted non-null.</param>
            /// <param name="inputSchema">The schema the columns are bound against; asserted non-null.</param>
            public Mapper(ConcatTransform parent, Schema inputSchema)
            {
                Contracts.AssertValue(parent);
                Contracts.AssertValue(inputSchema);

                _host = parent._host.Register(nameof(Mapper));
                _parent = parent;
                _inputSchema = inputSchema;

                // One bound column for each column declared on the parent transform.
                int columnCount = _parent._columns.Length;
                _columns = new BoundColumn[columnCount];
                for (int index = 0; index < columnCount; index++)
                {
                    _columns[index] = MakeColumn(inputSchema, index);
                }
            }
Exemple #8
0
        /// <summary>
        /// Entry point that validates its arguments, builds a <c>ConcatTransform</c> over
        /// <c>input.Data</c>, and returns it as both the output data and a <c>TransformModel</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Concat arguments, including the data view to transform.</param>
        /// <returns>The transform output holding the model and the transformed view.</returns>
        public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, ConcatTransform.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ConcatColumns");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var concat = new ConcatTransform(env, input, input.Data);

            var output = new CommonOutputs.TransformOutput();
            output.Model = new TransformModel(env, concat, input.Data);
            output.OutputData = concat;
            return output;
        }
        /// <summary>
        /// Inserts a user-defined <c>LambdaTransform</c> mapping into the Iris pipeline, trains an
        /// SDCA multiclass model on the result, and asserts that the first 20 rows of a fresh loader
        /// are predicted with their own label.
        /// </summary>
        void Extensibility()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new LocalEnvironment())
            {
                var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));

                // Copies every field, except that PetalLength is replaced by SepalLength
                // whenever SepalLength is not greater than 3.
                Action <IrisData, IrisData> mapRow = (src, dst) =>
                {
                    dst.Label = src.Label;
                    dst.PetalLength = src.SepalLength > 3 ? src.PetalLength : src.SepalLength;
                    dst.PetalWidth = src.PetalWidth;
                    dst.SepalLength = src.SepalLength;
                    dst.SepalWidth = src.SepalWidth;
                };
                var lambda = LambdaTransform.CreateMap(env, loader, mapRow);
                var term = TermTransform.Create(env, lambda, "Label");
                var concatenator = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");
                var concat = concatenator.Transform(term);

                var trainerArgs = new SdcaMultiClassTrainer.Arguments
                {
                    MaxIterations = 100,
                    Shuffle = true,
                    NumThreads = 1
                };
                var trainer = new SdcaMultiClassTrainer(env, trainerArgs);

                // Wrap the training data in a cache only when the trainer asks for it.
                IDataView trainData = concat;
                if (trainer.Info.WantCaching)
                {
                    trainData = new CacheDataView(env, concat, prefetch: null);
                }
                var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(scorer);
                var model = env.CreatePredictionEngine <IrisData, IrisPrediction>(keyToValue);

                var testLoader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var firstRows = testLoader.AsEnumerable <IrisData>(env, false).Take(20);
                foreach (var row in firstRows)
                {
                    var prediction = model.Predict(row);
                    Assert.True(prediction.PredictedLabel == row.Label);
                }
            }
        }
        /// <summary>
        /// Trains an <c>Ova</c> meta-trainer whose inner predictor type is an
        /// <c>AveragedPerceptronTrainer</c> on the Iris training file. The method only checks that
        /// training completes; the resulting predictor is not inspected.
        /// </summary>
        public void Metacomponents()
        {
            using (var env = new LocalEnvironment())
            {
                var loader  = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename)));
                var term    = TermTransform.Create(env, loader, "Label");
                var concat  = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term);
                var trainer = new Ova(env, new Ova.Arguments
                {
                    // Build the inner trainer from the environment handed to the factory rather than
                    // capturing the outer 'env', so the factory honors whatever host its caller supplies.
                    PredictorType = ComponentFactoryUtils.CreateFromFunction(
                        e => new AveragedPerceptronTrainer(e, new AveragedPerceptronTrainer.Arguments()))
                });

                // Wrap the training data in a cache only when the trainer asks for it.
                IDataView trainData  = trainer.Info.WantCaching ? (IDataView) new CacheDataView(env, concat, prefetch: null) : concat;
                var       trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);

                // The predictor is not used afterwards; the call itself is what is being exercised.
                trainer.Train(new TrainContext(trainRoles));
            }
        }
Exemple #11
0
        /// <summary>
        /// Trains an <c>Ova</c> meta-trainer whose inner predictor type is a
        /// <c>FastTreeBinaryClassificationTrainer</c> supplied through a <c>SimpleComponentFactory</c>,
        /// then scores the Iris data and asserts that the first 20 rows are predicted with their own label.
        /// </summary>
        void Metacomponents()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new TlcEnvironment())
            {
                var loader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var term = new TermTransform(env, loader, "Label");
                var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

                var innerFactory = new SimpleComponentFactory <ITrainer <IPredictorProducing <float> > >(
                    (e) => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()));
                var trainer = new Ova(env, new Ova.Arguments { PredictorType = innerFactory });

                // Wrap the training data in a cache only when the trainer asks for it.
                IDataView trainData = concat;
                if (trainer.Info.WantCaching)
                {
                    trainData = new CacheDataView(env, concat, prefetch: null);
                }
                var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                var keyToValue = new KeyToValueTransform(env, scorer, "PredictedLabel");
                var model = env.CreatePredictionEngine <IrisData, IrisPrediction>(keyToValue);

                var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var firstRows = testLoader.AsEnumerable <IrisData>(env, false).Take(20);
                foreach (var row in firstRows)
                {
                    var prediction = model.Predict(row);
                    Assert.True(prediction.PredictedLabel == row.Label);
                }
            }
        }
Exemple #12
0
        /// <summary>
        /// Trains an <c>Ova</c> meta-trainer whose inner predictor type is a
        /// <c>FastTreeBinaryClassificationTrainer</c> on the Iris data. Nothing is asserted about
        /// the resulting predictor.
        /// </summary>
        public void Metacomponents()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new TlcEnvironment())
            {
                var loader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var term = new TermTransform(env, loader, "Label");
                var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

                var ovaArgs = new Ova.Arguments
                {
                    PredictorType = ComponentFactoryUtils.CreateFromFunction(
                        e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
                };
                var trainer = new Ova(env, ovaArgs);

                // Wrap the training data in a cache only when the trainer asks for it.
                IDataView trainData = concat;
                if (trainer.Info.WantCaching)
                {
                    trainData = new CacheDataView(env, concat, prefetch: null);
                }
                var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new TrainContext(trainRoles));
            }
        }
Exemple #13
0
        /// <summary>
        /// Trains an SDCA multiclass model on the Iris data, obtains a scorer for it, re-applies that
        /// scorer through <c>ApplyTransformUtils.ApplyAllTransformsToData</c> with the loader and the
        /// term transform, and asserts that the first 20 rows are predicted as "Iris-setosa".
        /// </summary>
        void DecomposableTrainAndPredict()
        {
            var dataPath = GetDataPath(IrisDataPath);

            using (var env = new TlcEnvironment())
            {
                var loader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var term = new TermTransform(env, loader, "Label");
                var concat = new ConcatTransform(env, term, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

                var trainerArgs = new SdcaMultiClassTrainer.Arguments
                {
                    MaxIterations = 100,
                    Shuffle = true,
                    NumThreads = 1
                };
                var trainer = new SdcaMultiClassTrainer(env, trainerArgs);

                // Wrap the training data in a cache only when the trainer asks for it.
                IDataView trainData = concat;
                if (trainer.Info.WantCaching)
                {
                    trainData = new CacheDataView(env, concat, prefetch: null);
                }
                var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

                // Auto-normalization.
                NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
                var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

                var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
                IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

                // Cut out term transform from pipeline.
                var newScorer = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term);
                var keyToValue = new KeyToValueTransform(env, newScorer, "PredictedLabel");
                var model = env.CreatePredictionEngine <IrisDataNoLabel, IrisPrediction>(keyToValue);

                var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
                var firstRows = testLoader.AsEnumerable <IrisDataNoLabel>(env, false).Take(20);
                foreach (var row in firstRows)
                {
                    var prediction = model.Predict(row);
                    Assert.True(prediction.PredictedLabel == "Iris-setosa");
                }
            }
        }
Exemple #14
0
        /// <summary>
        /// Validates each many-to-one column and, for those with more than one source, applies a
        /// single <c>ConcatTransform</c> that concatenates the sources into the column's name.
        /// </summary>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="columns">Column specifications; each must have a non-blank name and at least one non-blank source.</param>
        /// <param name="input">The data view to transform.</param>
        /// <returns>
        /// The concatenated view, or <paramref name="input"/> itself when no column has more than one source.
        /// </returns>
        public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(columns, nameof(columns));
            env.CheckValue(input, nameof(input));

            var pending = new List <ConcatTransform.Column>();
            for (int i = 0; i < columns.Length; i++)
            {
                var column = columns[i];
                env.CheckUserArg(column != null, nameof(WordBagTransform.Arguments.Column));
                env.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name));
                env.CheckUserArg(Utils.Size(column.Source) > 0, nameof(column.Source));
                env.CheckUserArg(column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source));

                // A single-source column needs no concatenation.
                if (column.Source.Length <= 1)
                {
                    continue;
                }

                var concatColumn = new ConcatTransform.Column();
                concatColumn.Source = column.Source;
                concatColumn.Name = column.Name;
                pending.Add(concatColumn);
            }

            if (pending.Count == 0)
            {
                return input;
            }

            var concatArgs = new ConcatTransform.Arguments();
            concatArgs.Column = pending.ToArray();
            return ConcatTransform.Create(env, concatArgs, input);
        }
        /// <summary>
        /// Loads the comma-separated file at <c>s_dataPath</c>, applies <c>CategoricalTransform</c> to
        /// "CatFeatures", min-max normalizes "NumFeatures", concatenates both into "Features", appends
        /// the "Score" column of a KMeans (k=100) train-and-score step to "Features", and trains
        /// logistic regression on the result.
        /// </summary>
        /// <returns>The predictor produced by the logistic regression trainer.</returns>
        private static IPredictor TrainKMeansAndLRCore()
        {
            string dataPath = s_dataPath;

            using (var env = new TlcEnvironment(seed: 1))
            {
                // Pipeline
                var labelColumn = new TextLoader.Column()
                {
                    Name = "Label",
                    Source = new [] { new TextLoader.Range() { Min = 14, Max = 14 } },
                    Type = DataKind.R4
                };
                var catColumn = new TextLoader.Column()
                {
                    Name = "CatFeatures",
                    Source = new []
                    {
                        new TextLoader.Range() { Min = 1, Max = 1 },
                        new TextLoader.Range() { Min = 3, Max = 3 },
                        new TextLoader.Range() { Min = 5, Max = 9 },
                        new TextLoader.Range() { Min = 13, Max = 13 }
                    },
                    Type = DataKind.TX
                };
                var numColumn = new TextLoader.Column()
                {
                    Name = "NumFeatures",
                    Source = new []
                    {
                        new TextLoader.Range() { Min = 0, Max = 0 },
                        new TextLoader.Range() { Min = 2, Max = 2 },
                        new TextLoader.Range() { Min = 4, Max = 4 },
                        new TextLoader.Range() { Min = 10, Max = 12 }
                    },
                    Type = DataKind.R4
                };
                var loaderArgs = new TextLoader.Arguments()
                {
                    HasHeader = true,
                    Separator = ",",
                    Column = new[] { labelColumn, catColumn, numColumn }
                };
                var loader = new TextLoader(env, loaderArgs, new MultiFileSource(dataPath));

                var categoricalArgs = new CategoricalTransform.Arguments
                {
                    Column = new[]
                    {
                        new CategoricalTransform.Column { Name = "CatFeatures", Source = "CatFeatures" }
                    }
                };
                IDataTransform trans = CategoricalTransform.Create(env, categoricalArgs, loader);

                trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "NumFeatures");
                trans = new ConcatTransform(env, trans, "Features", "NumFeatures", "CatFeatures");

                var kMeansArgs = new TrainAndScoreTransform.Arguments
                {
                    Trainer = new SubComponent <ITrainer, SignatureTrainer>("KMeans", "k=100"),
                    FeatureColumn = "Features"
                };
                trans = TrainAndScoreTransform.Create(env, kMeansArgs, trans);

                // Widen "Features" with the KMeans "Score" column.
                trans = new ConcatTransform(env, trans, "Features", "Features", "Score");

                // Train
                var lrArgs = new LogisticRegression.Arguments()
                {
                    EnforceNonNegativity = true,
                    OptTol = 1e-3f
                };
                var trainer = new LogisticRegression(env, lrArgs);
                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                return trainer.Train(trainRoles);
            }
        }
Exemple #16
0
        /// <summary>
        /// Loads the comma-separated file at <c>_dataPath</c>, fits and applies a
        /// <c>CategoricalEstimator</c> on "CatFeatures", min-max normalizes "NumFeatures", concatenates
        /// both into "Features", appends the "Score" column of a KMeans++ (K = 100) train-and-score
        /// step to "Features", and trains logistic regression on the result.
        /// </summary>
        /// <returns>The predictor produced by the logistic regression trainer.</returns>
        public ParameterMixingCalibratedPredictor TrainKMeansAndLR()
        {
            using (var env = new ConsoleEnvironment(seed: 1))
            {
                // Pipeline
                var catRanges = new []
                {
                    new TextLoader.Range() { Min = 1, Max = 1 },
                    new TextLoader.Range() { Min = 3, Max = 3 },
                    new TextLoader.Range() { Min = 5, Max = 9 },
                    new TextLoader.Range() { Min = 13, Max = 13 }
                };
                var numRanges = new []
                {
                    new TextLoader.Range() { Min = 0, Max = 0 },
                    new TextLoader.Range() { Min = 2, Max = 2 },
                    new TextLoader.Range() { Min = 4, Max = 4 },
                    new TextLoader.Range() { Min = 10, Max = 12 }
                };
                var loaderArgs = new TextLoader.Arguments()
                {
                    HasHeader = true,
                    Separator = ",",
                    Column = new[]
                    {
                        new TextLoader.Column("Label", DataKind.R4, 14),
                        new TextLoader.Column("CatFeatures", DataKind.TX, catRanges),
                        new TextLoader.Column("NumFeatures", DataKind.R4, numRanges)
                    }
                };
                var loader = TextLoader.ReadFile(env, loaderArgs, new MultiFileSource(_dataPath));

                IDataView trans = new CategoricalEstimator(env, "CatFeatures").Fit(loader).Transform(loader);

                trans = NormalizeTransform.CreateMinMaxNormalizer(env, trans, "NumFeatures");
                trans = new ConcatTransform(env, "Features", "NumFeatures", "CatFeatures").Transform(trans);

                var kMeansArgs = new TrainAndScoreTransform.Arguments
                {
                    Trainer = ComponentFactoryUtils.CreateFromFunction(
                        host => new KMeansPlusPlusTrainer(host, "Features", advancedSettings: s => { s.K = 100; })),
                    FeatureColumn = "Features"
                };
                trans = TrainAndScoreTransform.Create(env, kMeansArgs, trans);

                // Widen "Features" with the KMeans "Score" column.
                trans = new ConcatTransform(env, "Features", "Features", "Score").Transform(trans);

                // Train
                var trainer = new LogisticRegression(env, "Features", "Label", advancedSettings: settings =>
                {
                    settings.EnforceNonNegativity = true;
                    settings.OptTol = 1e-3f;
                });
                var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");
                return trainer.Train(trainRoles);
            }
        }
        public void TensorFlowTransformMNISTConvTest()
        {
            var model_location = "mnist_model/frozen_saved_model.pb";

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                var dataPath     = GetDataPath("Train-Tiny-28x28.txt");
                var testDataPath = GetDataPath("MNIST.Test.tiny.txt");

                // Pipeline
                var loader = TextLoader.ReadFile(env,
                                                 new TextLoader.Arguments()
                {
                    Separator = "tab",
                    HasHeader = true,
                    Column    = new[]
                    {
                        new TextLoader.Column()
                        {
                            Name   = "Label",
                            Source = new [] { new TextLoader.Range()
                                              {
                                                  Min = 0, Max = 0
                                              } },
                            Type = DataKind.Num
                        },

                        new TextLoader.Column()
                        {
                            Name   = "Placeholder",
                            Source = new [] { new TextLoader.Range()
                                              {
                                                  Min = 1, Max = 784
                                              } },
                            Type = DataKind.Num
                        }
                    }
                }, new MultiFileSource(dataPath));

                IDataView trans = TensorFlowTransform.Create(env, loader, model_location, "Softmax", "Placeholder");
                trans = new ConcatTransform(env, trans, "reshape_input", "Placeholder");
                trans = TensorFlowTransform.Create(env, trans, model_location, "dense/Relu", "reshape_input");
                trans = new ConcatTransform(env, trans, "Features", "Softmax", "dense/Relu");

                var trainer = new LightGbmMulticlassTrainer(env, new LightGbmArguments());

                var cached     = new CacheDataView(env, trans, prefetch: null);
                var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
                var pred       = trainer.Train(trainRoles);

                // Get scorer and evaluate the predictions from test data
                IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
                var metrics = Evaluate(env, testDataScorer);

                Assert.Equal(0.99, metrics.AccuracyMicro, 2);
                Assert.Equal(0.99, metrics.AccuracyMicro, 2);

                // Create prediction engine and test predictions
                var model = env.CreatePredictionEngine <MNISTData, MNISTPrediction>(testDataScorer);

                var sample1 = new MNISTData()
                {
                    Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 18, 18, 18, 126, 136, 175, 26,
                                                166, 255, 247, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253, 253, 253, 253, 253,
                                                225, 172, 253, 242, 195, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49, 238, 253, 253, 253, 253, 253, 253, 253,
                                                253, 251, 93, 82, 82, 56, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253, 253, 253, 198,
                                                182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253, 205, 11, 0,
                                                43, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 1, 154, 253, 90, 0, 0, 0, 0, 0, 0,
                                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }
                };

                var prediction = model.Predict(sample1);

                float max      = -1;
                int   maxIndex = -1;
                for (int i = 0; i < prediction.PredictedLabels.Length; i++)
                {
                    if (prediction.PredictedLabels[i] > max)
                    {
                        max      = prediction.PredictedLabels[i];
                        maxIndex = i;
                    }
                }

                Assert.Equal(5, maxIndex);
            }
        }
Exemple #18
0
        /// <summary>
        /// Builds a chain of transforms that replaces missing values in the requested columns and,
        /// for columns that ask for it, concatenates a missing-value indicator onto the replaced values.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transform arguments; must contain at least one column.</param>
        /// <param name="input">The data view the transforms are applied to.</param>
        /// <returns>The last transform of the chain that was built.</returns>
        /// <remarks>
        /// Throws (via the host's <c>Except</c>) when a source column does not exist in the input schema, or when
        /// the boolean indicator cannot be converted to the source column's item type.
        /// </remarks>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register("Categorical");

            h.CheckValue(args, nameof(args));
            h.CheckValue(input, nameof(input));
            h.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column));

            var replaceCols          = new List <NAReplaceTransform.ColumnInfo>();
            var naIndicatorCols      = new List <NAIndicatorTransform.Column>();
            var naConvCols           = new List <ConvertingTransform.ColumnInfo>();
            var concatCols           = new List <ConcatTransform.TaggedColumn>();
            var dropCols             = new List <string>();
            var tmpIsMissingColNames = input.Schema.GetTempColumnNames(args.Column.Length, "IsMissing");
            var tmpReplaceColNames   = input.Schema.GetTempColumnNames(args.Column.Length, "Replace");

            for (int i = 0; i < args.Column.Length; i++)
            {
                var column = args.Column[i];

                // Per-column settings override the transform-wide defaults.
                var replacementMode = (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith);
                var imputeBySlot    = column.ImputeBySlot ?? args.ImputeBySlot;

                var addInd = column.ConcatIndicator ?? args.Concat;
                if (!addInd)
                {
                    // No indicator requested: replace straight into the final output column.
                    replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, column.Name, replacementMode, imputeBySlot));
                    continue;
                }

                // Check that the indicator column has a type that can be converted to the NAReplaceTransform output type,
                // so that they can be concatenated.
                if (!input.Schema.TryGetColumnIndex(column.Source, out int inputCol))
                {
                    throw h.Except("Column '{0}' does not exist", column.Source);
                }
                var replaceType = input.Schema.GetColumnType(inputCol);
                if (!Runtime.Data.Conversion.Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out Delegate conv, out bool identity))
                {
                    throw h.Except("Cannot concatenate indicator column of type '{0}' to input column of type '{1}'",
                                   BoolType.Instance, replaceType.ItemType);
                }

                // Find a temporary name for the NAReplaceTransform and NAIndicatorTransform output columns.
                var tmpIsMissingColName   = tmpIsMissingColNames[i];
                var tmpReplacementColName = tmpReplaceColNames[i];

                // Add an NAIndicatorTransform column.
                naIndicatorCols.Add(new NAIndicatorTransform.Column()
                {
                    Name = tmpIsMissingColName, Source = column.Source
                });

                // Add a ConvertTransform column if necessary.
                if (!identity)
                {
                    naConvCols.Add(new ConvertingTransform.ColumnInfo(tmpIsMissingColName, tmpIsMissingColName, replaceType.ItemType.RawKind));
                }

                // Add the NAReplaceTransform column.
                replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, tmpReplacementColName, replacementMode, imputeBySlot));

                // Add the ConcatTransform column. The key of each pair differs between
                // vector and non-vector sources; the value is the temporary column to concatenate.
                if (replaceType.IsVector)
                {
                    concatCols.Add(new ConcatTransform.TaggedColumn()
                    {
                        Name   = column.Name,
                        Source = new[] {
                            new KeyValuePair <string, string>(tmpReplacementColName, tmpReplacementColName),
                            new KeyValuePair <string, string>("IsMissing", tmpIsMissingColName)
                        }
                    });
                }
                else
                {
                    concatCols.Add(new ConcatTransform.TaggedColumn()
                    {
                        Name   = column.Name,
                        Source = new[]
                        {
                            new KeyValuePair <string, string>(column.Source, tmpReplacementColName),
                            new KeyValuePair <string, string>(string.Format("IsMissing.{0}", column.Source), tmpIsMissingColName),
                        }
                    });
                }

                // Add the temp column to the list of columns to drop at the end.
                dropCols.Add(tmpIsMissingColName);
                dropCols.Add(tmpReplacementColName);
            }

            IDataTransform output = null;

            // Create the indicator columns.
            if (naIndicatorCols.Count > 0)
            {
                output = NAIndicatorTransform.Create(h, new NAIndicatorTransform.Arguments()
                {
                    Column = naIndicatorCols.ToArray()
                }, input);
            }

            // Convert the indicator columns to the correct type so that they can be concatenated to the NAReplace outputs.
            if (naConvCols.Count > 0)
            {
                h.AssertValue(output);
                //REVIEW: all this need to be converted to estimatorChain as soon as we done with dropcolumns.
                // Direct cast rather than 'as': a failed 'as' would yield null, and the 'output ?? input'
                // below would then silently discard the indicator and conversion steps.
                output = (IDataTransform) new ConvertingTransform(h, naConvCols.ToArray()).Transform(output);
            }

            // Create the NAReplace transform; when no indicator was requested the chain starts at the raw input.
            output = NAReplaceTransform.Create(env, output ?? input, replaceCols.ToArray());

            // Concat the NAReplaceTransform output and the NAIndicatorTransform output.
            if (naIndicatorCols.Count > 0)
            {
                output = ConcatTransform.Create(h, new ConcatTransform.TaggedArguments()
                {
                    Column = concatCols.ToArray()
                }, output);
            }

            // Finally, drop the temporary indicator columns.
            if (dropCols.Count > 0)
            {
                output = SelectColumnsTransform.CreateDrop(h, output, dropCols.ToArray());
            }

            return(output);
        }