public void TestMetadata()
        {
            // Two rows: A carries 1 and 2, B carries "A" and "B".
            var rows = new[]
            {
                new MetaClass() { A = 1, B = "A" },
                new MetaClass() { A = 2, B = "B" },
            };

            // Stage 1: one-hot encode A with OutputKind.Ind (-> CatA) and B with OutputKind.Key (-> CatB).
            var encode = new OneHotEncodingEstimator(Env, new[]
            {
                new OneHotEncodingEstimator.ColumnInfo("A", "CatA", OneHotEncodingTransformer.OutputKind.Ind),
                new OneHotEncodingEstimator.ColumnInfo("B", "CatB", OneHotEncodingTransformer.OutputKind.Key),
            });

            // Stage 2: convert the encoded columns (CatA -> ConvA as R8, CatB -> ConvB as U2).
            var convert = new TypeConvertingEstimator(Env, new[]
            {
                new TypeConvertingTransformer.ColumnInfo("CatA", "ConvA", DataKind.R8),
                new TypeConvertingTransformer.ColumnInfo("CatB", "ConvB", DataKind.U2),
            });

            var estimatorChain = encode.Append(convert);
            var source = ComponentCreation.CreateDataView(Env, rows);

            // Fit on the same view that is transformed, then hand the result to the metadata checks.
            var transformed = estimatorChain.Fit(source).Transform(source);
            ValidateMetadata(transformed);
        }
        // Example no. 2
        // 0
        /// <summary>
        /// Builds a three-row view and one-hot encodes its fields with every
        /// <see cref="OneHotEncodingTransformer.OutputKind"/> (Bag, Ind, Key, Bin).
        /// The transformed view is handed to <c>ValidateMetadata</c>.
        /// </summary>
        public void TestMetadataPropagation()
        {
            // Fields A, C, E and G hold two-element arrays; B, D, F and H hold single values.
            var rows = new[]
            {
                new TestMeta()
                {
                    A = new[] { "A", "B" }, B = "C",
                    C = new[] { 3, 5 }, D = 6,
                    E = new[] { 1.0f, 2.0f }, F = 1.0f,
                    G = new[] { "A", "D" }, H = "D",
                },
                new TestMeta()
                {
                    A = new[] { "A", "B" }, B = "C",
                    C = new[] { 5, 3 }, D = 1,
                    E = new[] { 3.0f, 4.0f }, F = -1.0f,
                    G = new[] { "E", "A" }, H = "E",
                },
                new TestMeta()
                {
                    A = new[] { "A", "B" }, B = "C",
                    C = new[] { 3, 5 }, D = 6,
                    E = new[] { 5.0f, 6.0f }, F = 1.0f,
                    G = new[] { "D", "E" }, H = "D",
                },
            };

            var source = ComponentCreation.CreateDataView(Env, rows);

            // Short aliases so each column mapping below fits on one line.
            var bagKind = OneHotEncodingTransformer.OutputKind.Bag;
            var indKind = OneHotEncodingTransformer.OutputKind.Ind;
            var keyKind = OneHotEncodingTransformer.OutputKind.Key;
            var binKind = OneHotEncodingTransformer.OutputKind.Bin;

            // A-D are encoded twice: as bags (CatA-CatD) and as binary (CatI-CatL).
            var encoder = new OneHotEncodingEstimator(Env, new[]
            {
                new OneHotEncodingEstimator.ColumnInfo("A", "CatA", bagKind),
                new OneHotEncodingEstimator.ColumnInfo("B", "CatB", bagKind),
                new OneHotEncodingEstimator.ColumnInfo("C", "CatC", bagKind),
                new OneHotEncodingEstimator.ColumnInfo("D", "CatD", bagKind),
                new OneHotEncodingEstimator.ColumnInfo("E", "CatE", indKind),
                new OneHotEncodingEstimator.ColumnInfo("F", "CatF", indKind),
                new OneHotEncodingEstimator.ColumnInfo("G", "CatG", keyKind),
                new OneHotEncodingEstimator.ColumnInfo("H", "CatH", keyKind),
                new OneHotEncodingEstimator.ColumnInfo("A", "CatI", binKind),
                new OneHotEncodingEstimator.ColumnInfo("B", "CatJ", binKind),
                new OneHotEncodingEstimator.ColumnInfo("C", "CatK", binKind),
                new OneHotEncodingEstimator.ColumnInfo("D", "CatL", binKind),
            });

            var encoded = encoder.Fit(source).Transform(source);

            ValidateMetadata(encoded);
            Done();
        }
        /// <summary>
        /// Reads the comma-separated file at <c>_dataPath</c> (with header) and featurizes it.
        /// The feature vector is extended with the "Score" output of a K-means model (K = 100),
        /// and a logistic-regression model is trained on the result.
        /// </summary>
        /// <returns>The predictor produced by <c>LogisticRegression.Train</c>.</returns>
        public ParameterMixingCalibratedPredictor TrainKMeansAndLR()
        {
            // Seeded (seed 1), non-verbose environment whose console output goes to EmptyWriter.
            using (var env = new ConsoleEnvironment(seed: 1, verbose: false, sensitivity: MessageSensitivity.None, outWriter: EmptyWriter.Instance))
            {
                // Source columns 1, 3, 5-9 and 13 are loaded as text.
                var categoricalSource = new[]
                {
                    new TextLoader.Range() { Min = 1, Max = 1 },
                    new TextLoader.Range() { Min = 3, Max = 3 },
                    new TextLoader.Range() { Min = 5, Max = 9 },
                    new TextLoader.Range() { Min = 13, Max = 13 },
                };

                // Source columns 0, 2, 4 and 10-12 are loaded as R4.
                var numericSource = new[]
                {
                    new TextLoader.Range() { Min = 0, Max = 0 },
                    new TextLoader.Range() { Min = 2, Max = 2 },
                    new TextLoader.Range() { Min = 4, Max = 4 },
                    new TextLoader.Range() { Min = 10, Max = 12 },
                };

                // Source column 14 is the R4 label.
                var readerArgs = new TextLoader.Arguments()
                {
                    HasHeader = true,
                    Separator = ",",
                    Column = new[]
                    {
                        new TextLoader.Column("Label", DataKind.R4, 14),
                        new TextLoader.Column("CatFeatures", DataKind.TX, categoricalSource),
                        new TextLoader.Column("NumFeatures", DataKind.R4, numericSource),
                    },
                };
                var loader = TextLoader.ReadFile(env, readerArgs, new MultiFileSource(_dataPath));

                // Featurize: one-hot the text columns, min-max normalize the numeric ones, and join both into "Features".
                IDataView view = new OneHotEncodingEstimator(env, "CatFeatures").Fit(loader).Transform(loader);
                view = NormalizeTransform.CreateMinMaxNormalizer(env, view, "NumFeatures");
                view = new ColumnConcatenatingTransformer(env, "Features", "NumFeatures", "CatFeatures").Transform(view);

                // Train K-means (K = 100) on "Features" and append its "Score" column to the feature vector.
                var clusteringArgs = new TrainAndScoreTransformer.Arguments
                {
                    Trainer = ComponentFactoryUtils.CreateFromFunction(host =>
                        new KMeansPlusPlusTrainer(host, "Features", advancedSettings: s => s.K = 100)),
                    FeatureColumn = "Features",
                };
                view = TrainAndScoreTransformer.Create(env, clusteringArgs, view);
                view = new ColumnConcatenatingTransformer(env, "Features", "Features", "Score").Transform(view);

                // Logistic regression with EnforceNonNegativity on and OptTol = 1e-3.
                var lr = new LogisticRegression(env, "Label", "Features", advancedSettings: args =>
                {
                    args.EnforceNonNegativity = true;
                    args.OptTol = 1e-3f;
                });
                var trainingData = new RoleMappedData(view, label: "Label", feature: "Features");

                // NOTE(review): the returned predictor outlives `env`, which is disposed on exit;
                // confirm the predictor does not use the environment afterwards.
                return lr.Train(trainingData);
            }
        }
                /// <summary>
                /// Suggests one-hot encodings for text columns whose purpose is
                /// <see cref="ColumnPurpose.CategoricalFeature"/>.
                /// Columns for which <c>IsDictionaryOk</c> succeeds are encoded with
                /// <see cref="OneHotEncodingEstimator"/>; the rest use <see cref="OneHotHashEncodingEstimator"/>.
                /// Every encoding is done in place (the output column keeps the input column's name).
                /// Unless <c>ExcludeFeaturesConcatTransforms</c> is set, a further transform from
                /// <c>InferenceHelpers.GetRemainingFeatures</c> is suggested over all encoded columns,
                /// including the hash-encoded ones.
                /// </summary>
                /// <param name="columns">Candidate columns with their inferred types and purposes.</param>
                /// <returns>A lazily evaluated sequence of suggested transforms.</returns>
                public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns)
                {
                    var catColumnsNew = new List<OneHotEncodingEstimator.ColumnInfo>();
                    var catHashColumnsNew = new List<OneHotHashEncodingEstimator.ColumnInfo>();
                    var featureCols = new List<string>();

                    foreach (var column in columns)
                    {
                        if (!column.Type.ItemType().IsText() || column.Purpose != ColumnPurpose.CategoricalFeature)
                        {
                            continue;
                        }

                        // Encode in place: the encoded column replaces the original under the same name.
                        string columnName = column.ColumnName;

                        if (IsDictionaryOk(column, EstimatedSampleFraction))
                        {
                            catColumnsNew.Add(new OneHotEncodingEstimator.ColumnInfo(columnName, columnName));
                        }
                        else
                        {
                            catHashColumnsNew.Add(new OneHotHashEncodingEstimator.ColumnInfo(columnName, columnName));
                        }
                    }

                    if (catColumnsNew.Count > 0)
                    {
                        string[] outputs = catColumnsNew.Select(c => c.Output).ToArray();
                        var routingStructure = CreateCategoricalRouting(outputs);

                        var input = new OneHotEncodingEstimator(Env, catColumnsNew.ToArray());
                        featureCols.AddRange(outputs);

                        yield return new SuggestedTransform(input, routingStructure);
                    }

                    if (catHashColumnsNew.Count > 0)
                    {
                        string[] outputs = catHashColumnsNew.Select(c => c.HashInfo.Output).ToArray();
                        var routingStructure = CreateCategoricalRouting(outputs);

                        var input = new OneHotHashEncodingEstimator(Env, catHashColumnsNew.ToArray());
                        // Hash-encoded columns are categorical features too. Without this they were
                        // dropped from the feature concat (and no concat was suggested at all when
                        // every categorical column was hashed).
                        featureCols.AddRange(outputs);

                        yield return new SuggestedTransform(input, routingStructure);
                    }

                    if (!ExcludeFeaturesConcatTransforms && featureCols.Count > 0)
                    {
                        var concat = InferenceHelpers.GetRemainingFeatures(featureCols, columns, GetType(), IncludeFeaturesOverride);

                        // Set the flag before yielding. Code placed after the last `yield return` only
                        // runs if the caller asks for another element, so a caller that stops early
                        // would otherwise never see the flag set.
                        IncludeFeaturesOverride = true;

                        yield return concat;
                    }
                }

                // Pairs each named column as a non-numeric input with the same-named numeric output.
                private static ColumnRoutingStructure CreateCategoricalRouting(string[] columnNames)
                {
                    ColumnRoutingStructure.AnnotatedName[] columnsSource = columnNames
                        .Select(n => new ColumnRoutingStructure.AnnotatedName { IsNumeric = false, Name = n })
                        .ToArray();
                    ColumnRoutingStructure.AnnotatedName[] columnsDest = columnNames
                        .Select(n => new ColumnRoutingStructure.AnnotatedName { IsNumeric = true, Name = n })
                        .ToArray();
                    return new ColumnRoutingStructure(columnsSource, columnsDest);
                }