Example No. 1
        public void TestOldSavingAndLoading()
        {
            var data = new[]
            {
                new TestClass() { A = 1, B = 2, C = 3 },
                new TestClass() { A = 4, B = 5, C = 6 }
            };
            var dataView = ComponentCreation.CreateDataView(Env, data);
            // Hash each input column with different settings: custom hash bits plus
            // unlimited invert hashing, ordered hashing, an explicit seed, and defaults.
            var pipe     = new HashEstimator(Env, new[] {
                new HashTransformer.ColumnInfo("A", "HashA", hashBits: 4, invertHash: -1),
                new HashTransformer.ColumnInfo("B", "HashB", hashBits: 3, ordered: true),
                new HashTransformer.ColumnInfo("C", "HashC", seed: 42),
                new HashTransformer.ColumnInfo("A", "HashD"),
            });
            var result      = pipe.Fit(dataView).Transform(dataView);
            var resultRoles = new RoleMappedData(result);

            // Round-trip the fitted transforms through a memory stream and reload them.
            using (var ms = new MemoryStream())
            {
                TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles);
                ms.Position = 0;
                var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms);
            }
        }
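
This example (and the HashWorkout example below) builds its input from a TestClass type that is not shown here. A minimal sketch of what that type could look like, inferred only from the object initializers above; the actual type in the ML.NET test suite may declare different member types or annotations:

    // Hypothetical input record matching the initializers used above (A, B, C numeric).
    public class TestClass
    {
        public int A;
        public int B;
        public int C;
    }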
Example No. 2
        public void TestMetadata()
        {
            var data = new[]
            {
                new TestMeta() { A = new float[2] { 3.5f, 2.5f }, B = 1, C = new double[2] { 5.1f, 6.1f }, D = 7 },
                new TestMeta() { A = new float[2] { 3.5f, 2.5f }, B = 1, C = new double[2] { 5.1f, 6.1f }, D = 7 },
                new TestMeta() { A = new float[2] { 3.5f, 2.5f }, B = 1, C = new double[2] { 5.1f, 6.1f }, D = 7 }
            };

            var dataView = ComponentCreation.CreateDataView(Env, data);
            // Hash column A three ways to exercise slot-name metadata: limited invert
            // hash, unlimited invert hash, and unlimited invert hash with ordered hashing.
            var pipe     = new HashEstimator(Env, new[] {
                new HashTransformer.ColumnInfo("A", "HashA", invertHash: 1, hashBits: 10),
                new HashTransformer.ColumnInfo("A", "HashAUnlim", invertHash: -1, hashBits: 10),
                new HashTransformer.ColumnInfo("A", "HashAUnlimOrdered", invertHash: -1, hashBits: 10, ordered: true)
            });
            var result = pipe.Fit(dataView).Transform(dataView);

            ValidateMetadata(result);
            Done();
        }
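
TestMeta is likewise not shown. A minimal sketch of the shape the initializers above imply, assuming the fixed-size vector columns are declared through the VectorType attribute from the ML.NET API; the real test type may differ:

    // Hypothetical input record for the metadata test: two vector columns and two scalars.
    public class TestMeta
    {
        [VectorType(2)] public float[] A;
        public int B;
        [VectorType(2)] public double[] C;
        public int D;
    }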
Example No. 3
        public void HashWorkout()
        {
            var data = new[]
            {
                new TestClass() { A = 1, B = 2, C = 3 },
                new TestClass() { A = 4, B = 5, C = 6 }
            };

            var dataView = ComponentCreation.CreateDataView(Env, data);
            var pipe     = new HashEstimator(Env, new[] {
                new HashTransformer.ColumnInfo("A", "HashA", hashBits: 4, invertHash: -1),
                new HashTransformer.ColumnInfo("B", "HashB", hashBits: 3, ordered: true),
                new HashTransformer.ColumnInfo("C", "HashC", seed: 42),
                new HashTransformer.ColumnInfo("A", "HashD"),
            });

            // Run the shared estimator contract checks against the pipeline.
            TestEstimatorCore(pipe, dataView);
            Done();
        }
Example No. 4
        private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
        {
            // The stratification column and/or group column, if they exist at all, must be present at this point.
            var schema = input.Schema;

            output = input;
            // If no stratification column was specified, but a group column with a key type
            // of known cardinality exists, use it.
            string stratificationColumn = null;

            if (!string.IsNullOrWhiteSpace(Args.StratificationColumn))
            {
                stratificationColumn = Args.StratificationColumn;
            }
            else
            {
                string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
                int    index;
                if (group != null && schema.TryGetColumnIndex(group, out index))
                {
                    // Check if group column key type with known cardinality.
                    var type = schema.GetColumnType(index);
                    if (type.KeyCount > 0)
                    {
                        stratificationColumn = group;
                    }
                }
            }

            if (string.IsNullOrEmpty(stratificationColumn))
            {
                // No usable stratification column was found: generate a random-number column
                // under a name that does not collide with any existing column.
                stratificationColumn = "StratificationColumn";
                int tmp;
                int inc = 0;
                while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
                {
                    stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc);
                }
                var keyGenArgs = new GenerateNumberTransform.Arguments();
                var col        = new GenerateNumberTransform.Column();
                col.Name          = stratificationColumn;
                keyGenArgs.Column = new[] { col };
                output            = new GenerateNumberTransform(Host, keyGenArgs, input);
            }
            else
            {
                int col;
                if (!input.Schema.TryGetColumnIndex(stratificationColumn, out col))
                {
                    throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
                }
                var type = input.Schema.GetColumnType(col);
                if (!RangeFilter.IsValidRangeFilterColumnType(ch, type))
                {
                    // RangeFilter cannot consume this column type directly, so hash it
                    // into a fresh, non-colliding column and stratify on that instead.
                    ch.Info("Hashing the stratification column");
                    var origStratCol = stratificationColumn;
                    int tmp;
                    int inc = 0;
                    while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
                    {
                        stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
                    }
                    output = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(input).Transform(input);
                }
            }

            return stratificationColumn;
        }
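
For reference, the hashing fallback above can also be applied on its own. A minimal sketch, assuming the same short HashEstimator constructor used in the code (environment, input column, output column, number of hash bits), an IHostEnvironment named env, and an IDataView named data with a column named "Strat":

    // Hash the "Strat" column into 30-bit hash keys stored in "Strat_Hashed".
    var hashed = new HashEstimator(env, "Strat", "Strat_Hashed", 30).Fit(data).Transform(data);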
Example No. 5
        public CategoricalHashEstimator(IHostEnvironment env, params ColumnInfo[] columns)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CategoricalHashEstimator));
            _hash = new HashEstimator(_host, columns.Select(x => x.HashInfo).ToArray());
            using (var ch = _host.Start(nameof(CategoricalHashEstimator)))
            {
                var binaryCols = new List <(string input, string output)>();
                var cols       = new List <(string input, string output, bool bag)>();
                for (int i = 0; i < columns.Length; i++)
                {
                    var column = columns[i];
                    CategoricalTransform.OutputKind kind = columns[i].OutputKind;
                    // Route the hashed column to the appropriate follow-up conversion.
                    switch (kind)
                    {
                    default:
                        throw _host.ExceptUserArg(nameof(column.OutputKind));

                    case CategoricalTransform.OutputKind.Key:
                        continue;

                    case CategoricalTransform.OutputKind.Bin:
                        if ((column.HashInfo.InvertHash) != 0)
                        {
                            ch.Warning("Invert hashing is being used with binary encoding.");
                        }
                        binaryCols.Add((column.HashInfo.Output, column.HashInfo.Output));
                        break;

                    case CategoricalTransform.OutputKind.Ind:
                        cols.Add((column.HashInfo.Output, column.HashInfo.Output, false));
                        break;

                    case CategoricalTransform.OutputKind.Bag:
                        cols.Add((column.HashInfo.Output, column.HashInfo.Output, true));
                        break;
                    }
                }
                IEstimator <ITransformer> toBinVector = null;
                IEstimator <ITransformer> toVector    = null;
                if (binaryCols.Count > 0)
                {
                    toBinVector = new KeyToBinaryVectorEstimator(_host, binaryCols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.input, x.output)).ToArray());
                }
                if (cols.Count > 0)
                {
                    toVector = new KeyToVectorEstimator(_host, cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.input, x.output, x.bag)).ToArray());
                }

                // Chain whichever follow-up converters were actually needed (either may be null).
                if (toBinVector != null && toVector != null)
                {
                    _toSomething = toVector.Append(toBinVector);
                }
                else
                {
                    if (toBinVector != null)
                    {
                        _toSomething = toBinVector;
                    }
                    else
                    {
                        _toSomething = toVector;
                    }
                }
            }
        }
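
The constructor consumes a ColumnInfo descriptor that is not shown here. A minimal sketch of the shape it appears to assume, based only on the two members accessed above (HashInfo and OutputKind); the real CategoricalHashEstimator.ColumnInfo may carry additional settings:

    // Hypothetical column descriptor: the hashing settings plus the requested output kind.
    public sealed class ColumnInfo
    {
        public readonly HashTransformer.ColumnInfo HashInfo;
        public readonly CategoricalTransform.OutputKind OutputKind;

        public ColumnInfo(HashTransformer.ColumnInfo hashInfo, CategoricalTransform.OutputKind outputKind)
        {
            HashInfo = hashInfo;
            OutputKind = outputKind;
        }
    }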
Example No. 6
        internal CategoricalHashTransform(HashEstimator hash, IEstimator <ITransformer> keyToVector, IDataView input)
        {
            // Compose the hashing step with the key-to-vector step and fit both at once.
            var chain = hash.Append(keyToVector);

            _transformer = chain.Fit(input);
        }
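
A minimal sketch of how the fitted chain might then be consumed, reusing only the Append/Fit/Transform calls already seen above; hash, keyToVector, and input are assumed to be built as in Example No. 5:

    // Hypothetical follow-up: fit the hash + key-to-vector chain and apply it to the data.
    var chain = hash.Append(keyToVector);
    var model = chain.Fit(input);
    var transformed = model.Transform(input);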