public void TestOldSavingAndLoading()
        {
            // Two small rows are enough to exercise save/load of the fitted transforms.
            var items = new[]
            {
                new TestClass() { A = 1, B = 2, C = 3, },
                new TestClass() { A = 4, B = 5, C = 6 }
            };
            var dataView = ComponentCreation.CreateDataView(Env, items);

            // Four hash columns covering bit-width, ordering, seeding, and default settings.
            var columns = new[]
            {
                new HashingTransformer.ColumnInfo("A", "HashA", hashBits: 4, invertHash: -1),
                new HashingTransformer.ColumnInfo("B", "HashB", hashBits: 3, ordered: true),
                new HashingTransformer.ColumnInfo("C", "HashC", seed: 42),
                new HashingTransformer.ColumnInfo("A", "HashD"),
            };
            var estimator   = new HashingEstimator(Env, columns);
            var transformed = estimator.Fit(dataView).Transform(dataView);
            var roles       = new RoleMappedData(transformed);

            // Round-trip the fitted transforms through an in-memory model stream.
            using (var ms = new MemoryStream())
            {
                TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, roles);
                ms.Position = 0;
                var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms);
            }
        }
        /// <summary>
        /// Builds and fits the n-gram hash extraction pipeline described by <paramref name="options"/>:
        /// per-source hashing into temporary columns, n-gram hashing into the named output columns,
        /// and finally removal of the temporary columns.
        /// </summary>
        internal static ITransformer Create(IHostEnvironment env, Options options, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register(LoaderSignature);

            h.CheckValue(options, nameof(options));
            h.CheckValue(input, nameof(input));
            h.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");

            var chain = new TransformerChain <ITransformer>();

            // To each input column to the NgramHashExtractorArguments, a HashTransform using 30
            // bits (to minimize collisions) is applied first, followed by an NgramHashTransform.

            var hashColumns      = new List <HashingEstimator.ColumnOptions>();
            var ngramHashColumns = new NgramHashingEstimator.ColumnOptions[options.Columns.Length];

            var colCount = options.Columns.Length;

            // The NGramHashExtractor has a ManyToOne column type. To avoid stepping over the source
            // column name when a 'name' destination column name was specified, we use temporary column names.
            string[][] tmpColNames = new string[colCount][];
            for (int iinfo = 0; iinfo < colCount; iinfo++)
            {
                var column = options.Columns[iinfo];
                h.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name));
                h.CheckUserArg(Utils.Size(column.Source) > 0 &&
                               column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source));

                int srcCount = column.Source.Length;
                tmpColNames[iinfo] = new string[srcCount];
                for (int isrc = 0; isrc < srcCount; isrc++)
                {
                    var tmpName = input.Schema.GetTempColumnName(column.Source[isrc]);
                    tmpColNames[iinfo][isrc] = tmpName;

                    // Per-column settings fall back to the extractor-wide defaults when unset.
                    hashColumns.Add(new HashingEstimator.ColumnOptions(tmpName, column.Source[isrc],
                                                                       30, column.Seed ?? options.Seed, false, column.MaximumNumberOfInverts ?? options.MaximumNumberOfInverts));
                }

                ngramHashColumns[iinfo] =
                    new NgramHashingEstimator.ColumnOptions(column.Name, tmpColNames[iinfo],
                                                            column.NgramLength ?? options.NgramLength,
                                                            column.SkipLength ?? options.SkipLength,
                                                            column.UseAllLengths ?? options.UseAllLengths,
                                                            column.NumberOfBits ?? options.NumberOfBits,
                                                            column.Seed ?? options.Seed,
                                                            column.Ordered ?? options.Ordered,
                                                            column.MaximumNumberOfInverts ?? options.MaximumNumberOfInverts);
                ngramHashColumns[iinfo].FriendlyNames = column.FriendlyNames;
            }

            var hashing = new HashingEstimator(h, hashColumns.ToArray()).Fit(input);

            // Drop the temporary hashed columns once the n-gram stage has consumed them.
            return(chain.Append(hashing)
                   .Append(new NgramHashingEstimator(h, ngramHashColumns).Fit(hashing.Transform(input)))
                   .Append(new ColumnSelectingTransformer(h, null, tmpColNames.SelectMany(cols => cols).ToArray())));
        }
        /// <summary>
        /// Ensures <paramref name="data"/> has a column usable for stratified splitting and returns its name.
        /// If <paramref name="stratificationColumn"/> is null a random-number column is generated; otherwise
        /// the user column is hashed or normalized as needed so it can drive a <see cref="RangeFilter"/>.
        /// </summary>
        public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null)
        {
            host.CheckValue(data, nameof(data));
            host.CheckValueOrNull(stratificationColumn);

            // Pick a unique name for the stratificationColumn.
            const string stratColName = "StratificationKey";
            string       stratCol     = data.Schema.GetTempColumnName(stratColName);

            // Construct the stratification column. If user-provided stratification column exists, use HashJoin
            // of it to construct the strat column, otherwise generate a random number and use it.
            if (stratificationColumn == null)
            {
                data = new GenerateNumberTransform(host,
                                                   new GenerateNumberTransform.Options
                {
                    Columns = new[] { new GenerateNumberTransform.Column {
                                          Name = stratCol
                                      } }
                }, data);
            }
            else
            {
                var col = data.Schema.GetColumnOrNull(stratificationColumn);
                if (col == null)
                {
                    throw host.ExceptSchemaMismatch(nameof(stratificationColumn), "Stratification", stratificationColumn);
                }

                var type = col.Value.Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(host, type))
                {
                    // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
                    var itemType = type.GetItemType();
                    if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
                    {
                        // Convert unsupported time types to Int64 first so the hash stage can consume them.
                        data = new TypeConvertingTransformer(host, stratificationColumn, DataKind.Int64, stratificationColumn).Transform(data);
                    }

                    // Hash (30 bits, combined) the user column into the temporary strat column.
                    var columnOptions = new HashingEstimator.ColumnOptions(stratCol, stratificationColumn, 30, combine: true);
                    data = new HashingEstimator(host, columnOptions).Fit(data).Transform(data);
                }
                else
                {
                    // Already range-filter friendly: normalized or non-floating columns can be used as-is.
                    if (data.Schema[stratificationColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double))
                    {
                        return(stratificationColumn);
                    }

                    // Unnormalized Single/Double columns are min-max normalized into the temp column.
                    data = new NormalizingEstimator(host,
                                                    new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: true))
                           .Fit(data).Transform(data);
                }
            }

            return(stratCol);
        }
Exemple #4
0
        /// <summary>
        /// Configures count-target encoding: picks the count-table estimator (loaded from a model file,
        /// a single shared table, or one table per column) and the hashing stage that precedes it.
        /// </summary>
        internal CountTargetEncodingEstimator(IHostEnvironment env, Options options)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CountTargetEncodingEstimator));
            _host.CheckValue(options, nameof(options));
            _host.CheckUserArg(Utils.Size(options.Columns) > 0, nameof(options.Columns), "Columns must be specified");
            _host.CheckUserArg(!string.IsNullOrWhiteSpace(options.LabelColumn), nameof(options.LabelColumn), "Must specify the label column name");

            if (!string.IsNullOrEmpty(options.InitialCountsModel))
            {
                // Warm-start the counts from a previously saved CountTableTransformer model file.
                _countTableEstimator = LoadFromFile(env, options.InitialCountsModel, options.LabelColumn,
                                                    options.Columns.Select(col => new InputOutputColumnPair(col.Name)).ToArray());
                if (_countTableEstimator == null)
                {
                    throw env.Except($"The file {options.InitialCountsModel} does not contain a CountTableTransformer");
                }
            }
            else if (options.SharedTable)
            {
                // All columns share one count table; per-column settings fall back to the global options.
                var columns = new CountTableEstimator.SharedColumnOptions[options.Columns.Length];
                for (int i = 0; i < options.Columns.Length; i++)
                {
                    var column = options.Columns[i];
                    columns[i] = new CountTableEstimator.SharedColumnOptions(
                        column.Name,
                        column.Name,
                        column.PriorCoefficient ?? options.PriorCoefficient,
                        column.LaplaceScale ?? options.LaplaceScale,
                        column.Seed ?? options.Seed);
                }
                var builder = options.CountTable;
                _host.CheckValue(builder, nameof(options.CountTable));
                _countTableEstimator = new CountTableEstimator(_host, options.LabelColumn, builder.CreateComponent(_host), columns);
            }
            else
            {
                // One count table per column; each column may override the table builder too.
                var columns = new CountTableEstimator.ColumnOptions[options.Columns.Length];
                for (int i = 0; i < options.Columns.Length; i++)
                {
                    var column  = options.Columns[i];
                    var builder = column.CountTable ?? options.CountTable;
                    _host.CheckValue(builder, nameof(options.CountTable));
                    columns[i] = new CountTableEstimator.ColumnOptions(
                        column.Name,
                        column.Name,
                        builder.CreateComponent(_host),
                        column.PriorCoefficient ?? options.PriorCoefficient,
                        column.LaplaceScale ?? options.LaplaceScale,
                        column.Seed ?? options.Seed);
                }
                _countTableEstimator = new CountTableEstimator(_host, options.LabelColumn, columns);
            }

            // Inputs are hashed before counting; the column layout comes from the options.
            _hashingColumns   = InitializeHashingColumnOptions(options);
            _hashingEstimator = new HashingEstimator(_host, _hashingColumns);
        }
Exemple #5
0
        /// <summary>
        /// Wraps an already-constructed <see cref="CountTableEstimator"/>, building only the
        /// hashing stage from the given column options and hashing parameters.
        /// </summary>
        private CountTargetEncodingEstimator(IHostEnvironment env, CountTableEstimator estimator, CountTableEstimator.ColumnOptionsBase[] columns,
                                             int numberOfBits, bool combine, uint hashingSeed)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CountTargetEncodingEstimator));

            _hashingColumns      = InitializeHashingColumnOptions(columns, numberOfBits, combine, hashingSeed);
            _hashingEstimator    = new HashingEstimator(_host, _hashingColumns);
            _countTableEstimator = estimator;
        }
 internal OneHotHashEncodingTransformer(HashingEstimator hash, IEstimator <ITransformer> keyToVector, IDataView input)
 {
     // Without a key-to-vector stage the final transformer is just the fitted hash;
     // otherwise fit the full hash -> key-to-vector chain on the input.
     if (keyToVector == null)
     {
         _transformer = new TransformerChain <ITransformer>(hash.Fit(input));
     }
     else
     {
         _transformer = hash.Append(keyToVector).Fit(input);
     }
 }
Exemple #7
0
        /// <summary>
        /// Continues training from an existing <see cref="CountTargetEncodingTransformer"/>:
        /// reuses its hashing configuration and seeds the count tables from its counts.
        /// </summary>
        internal CountTargetEncodingEstimator(IHostEnvironment env, string labelColumnName, CountTargetEncodingTransformer initialCounts, params InputOutputColumnPair[] columns)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(nameof(CountTargetEncodingEstimator));
            _host.CheckValue(initialCounts, nameof(initialCounts));
            _host.CheckNonEmpty(columns, nameof(columns));
            // The requested columns must match the columns present in the initial-counts model.
            _host.Check(initialCounts.VerifyColumns(columns), nameof(columns));

            _hashingEstimator    = new HashingEstimator(_host, initialCounts.HashingTransformer.Columns.ToArray());
            // The count-table stage consumes the hashed output in place, so each pair maps the
            // output column name onto itself.
            _countTableEstimator = new CountTableEstimator(_host, labelColumnName, initialCounts.CountTable,
                                                           columns.Select(c => new InputOutputColumnPair(c.OutputColumnName, c.OutputColumnName)).ToArray());
        }
Exemple #8
0
        /// <summary>
        /// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary, or creates a new column if <paramref name="samplingKeyColumn"/> is null.
        /// </summary>
        internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int?seed = null)
        {
            Contracts.CheckValue(env, nameof(env));
            // We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
            // build a single hash of it. If it is not, we generate a random number.
            if (samplingKeyColumn == null)
            {
                samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
                // Explicit seed wins; otherwise fall back to the environment's seed (may be null).
                data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
            }
            else
            {
                if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
                {
                    throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);
                }

                var type = data.Schema[stratCol].Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
                {
                    // Hash the samplingKeyColumn.
                    // REVIEW: this could currently crash, since Hash only accepts a limited set
                    // of column types. It used to be HashJoin, but we should probably extend Hash
                    // instead of having two hash transformations.
                    var origStratCol = samplingKeyColumn;
                    samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
                    HashingEstimator.ColumnOptionsInternal columnOptions;
                    // Seed precedence: explicit argument, then the environment seed, then the hashing default.
                    if (seed.HasValue)
                    {
                        columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)seed.Value);
                    }
                    else if (((ISeededEnvironment)env).Seed.HasValue)
                    {
                        columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)((ISeededEnvironment)env).Seed.Value);
                    }
                    else
                    {
                        columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30);
                    }
                    data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
                }
                else
                {
                    // Unnormalized Single/Double columns must be min-max normalized for RangeFilter.
                    if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double))
                    {
                        var origStratCol = samplingKeyColumn;
                        samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
                        data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data);
                    }
                }
            }
        }
Exemple #9
0
        /// <summary>
        /// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary, or creates a new column if <paramref name="samplingKeyColumn"/> is null.
        /// </summary>
        internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int?seed = null)
        {
            Contracts.CheckValue(env, nameof(env));

            // No column supplied: synthesize a random-number column to sample on.
            if (samplingKeyColumn == null)
            {
                samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
                data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
                return;
            }

            if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
            {
                throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);
            }

            var type = data.Schema[stratCol].Type;
            if (RangeFilter.IsValidRangeFilterColumnType(env, type))
            {
                // Range-filter friendly already; only unnormalized Single/Double columns need normalizing.
                if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double))
                {
                    var sourceCol = samplingKeyColumn;
                    samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
                    data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, sourceCol, ensureZeroUntouched: true)).Fit(data).Transform(data);
                }
                return;
            }

            // Not range-filter friendly: hash the column into a fresh temporary column.
            var origColumn = samplingKeyColumn;
            samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);

            // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
            var itemType = type.GetItemType();
            if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
            {
                data = new TypeConvertingTransformer(env, origColumn, DataKind.Int64, origColumn).Transform(data);
            }

            // Explicit seed wins over the environment seed; both may be absent.
            var effectiveSeed = seed ?? ((ISeededEnvironment)env).Seed;
            var hashOptions   = effectiveSeed.HasValue
                ? new HashingEstimator.ColumnOptions(samplingKeyColumn, origColumn, 30, (uint)effectiveSeed.Value, combine: true)
                : new HashingEstimator.ColumnOptions(samplingKeyColumn, origColumn, 30, combine: true);
            data = new HashingEstimator(env, hashOptions).Fit(data).Transform(data);
        }
Exemple #10
0
        /// <summary>
        /// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary, or creates a new column if <paramref name="samplingKeyColumn"/> is null.
        /// </summary>
        private void EnsureGroupPreservationColumn(ref IDataView data, ref string samplingKeyColumn, uint?seed = null)
        {
            // We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
            // build a single hash of it. If it is not, we generate a random number.

            if (samplingKeyColumn == null)
            {
                samplingKeyColumn = data.Schema.GetTempColumnName("IdPreservationColumn");
                data = new GenerateNumberTransform(Environment, data, samplingKeyColumn, seed);
            }
            else
            {
                if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
                {
                    throw Environment.ExceptSchemaMismatch(nameof(samplingKeyColumn), "GroupPreservationColumn", samplingKeyColumn);
                }

                var type = data.Schema[stratCol].Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(Environment, type))
                {
                    // Hash the samplingKeyColumn.
                    // REVIEW: this could currently crash, since Hash only accepts a limited set
                    // of column types. It used to be HashJoin, but we should probably extend Hash
                    // instead of having two hash transformations.
                    var origStratCol = samplingKeyColumn;
                    int tmp;
                    int inc = 0;

                    // Generate a new column with the hashed samplingKeyColumn.
                    // Probe numeric suffixes until an unused column name is found.
                    while (data.Schema.TryGetColumnIndex(samplingKeyColumn, out tmp))
                    {
                        samplingKeyColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
                    }
                    HashingEstimator.ColumnInfo columnInfo;
                    // Only pass a hash seed when one was supplied; otherwise use the hashing default.
                    if (seed.HasValue)
                    {
                        columnInfo = new HashingEstimator.ColumnInfo(samplingKeyColumn, origStratCol, 30, seed.Value);
                    }
                    else
                    {
                        columnInfo = new HashingEstimator.ColumnInfo(samplingKeyColumn, origStratCol, 30);
                    }
                    data = new HashingEstimator(Environment, columnInfo).Fit(data).Transform(data);
                }
            }
        }
        public void TestMetadata()
        {
            // Three identical rows; metadata validation only needs repeated values.
            // A local factory keeps each row a distinct instance, as in the original array literal.
            TestMeta MakeRow() => new TestMeta()
            {
                A = new float[2] { 3.5f, 2.5f },
                B = 1,
                C = new double[2] { 5.1f, 6.1f },
                D = 7
            };
            var data = new[] { MakeRow(), MakeRow(), MakeRow() };

            var dataView = ComponentCreation.CreateDataView(Env, data);

            // Hash "A" three ways: capped invert hash, unlimited invert hash, and unlimited + ordered.
            var hashCols = new[]
            {
                new HashingTransformer.ColumnInfo("A", "HashA", invertHash: 1, hashBits: 10),
                new HashingTransformer.ColumnInfo("A", "HashAUnlim", invertHash: -1, hashBits: 10),
                new HashingTransformer.ColumnInfo("A", "HashAUnlimOrdered", invertHash: -1, hashBits: 10, ordered: true)
            };
            var result = new HashingEstimator(Env, hashCols).Fit(dataView).Transform(dataView);

            ValidateMetadata(result);
            Done();
        }
        public void HashWorkout()
        {
            // Two rows exercising several hashing configurations in one estimator.
            var rows = new[]
            {
                new TestClass() { A = 1, B = 2, C = 3, },
                new TestClass() { A = 4, B = 5, C = 6 }
            };
            var dataView = ComponentCreation.CreateDataView(Env, rows);

            // Cover bit-width, ordering, custom seed, and all-defaults columns.
            var estimator = new HashingEstimator(Env, new[]
            {
                new HashingTransformer.ColumnInfo("A", "HashA", hashBits: 4, invertHash: -1),
                new HashingTransformer.ColumnInfo("B", "HashB", hashBits: 3, ordered: true),
                new HashingTransformer.ColumnInfo("C", "HashC", seed: 42),
                new HashingTransformer.ColumnInfo("A", "HashD"),
            });

            TestEstimatorCore(estimator, dataView);
            Done();
        }
Exemple #13
0
        public void HashWorkout()
        {
            // Two rows exercising several hashing configurations in one estimator.
            var rows = new[]
            {
                new TestClass() { A = 1, B = 2, C = 3, },
                new TestClass() { A = 4, B = 5, C = 6 }
            };
            var dataView = ML.Data.ReadFromEnumerable(rows);

            // Cover bit-width, ordering, custom seed, and all-defaults columns.
            var estimator = new HashingEstimator(Env, new[]
            {
                new HashingTransformer.ColumnInfo("HashA", "A", hashBits: 4, invertHash: -1),
                new HashingTransformer.ColumnInfo("HashB", "B", hashBits: 3, ordered: true),
                new HashingTransformer.ColumnInfo("HashC", "C", seed: 42),
                new HashingTransformer.ColumnInfo("HashD", "A"),
            });

            TestEstimatorCore(estimator, dataView);
            Done();
        }
        /// <summary>
        /// Sets up the one-hot hash encoding pipeline: a <see cref="HashingEstimator"/> over every
        /// column, followed (per each column's <see cref="CategoricalTransform.OutputKind"/>) by a
        /// key-to-binary-vector and/or key-to-vector mapping stage.
        /// </summary>
        public OneHotHashEncodingEstimator(IHostEnvironment env, params ColumnInfo[] columns)
        {
            Contracts.CheckValue(env, nameof(env));
            // BUGFIX: was nameof(ValueToKeyMappingEstimator) — a copy-paste slip that registered
            // this component's host under the wrong name in logs and diagnostics.
            _host = env.Register(nameof(OneHotHashEncodingEstimator));
            _hash = new HashingEstimator(_host, columns.Select(x => x.HashInfo).ToArray());
            using (var ch = _host.Start(nameof(OneHotHashEncodingEstimator)))
            {
                var binaryCols = new List <(string input, string output)>();
                var cols       = new List <(string input, string output, bool bag)>();
                for (int i = 0; i < columns.Length; i++)
                {
                    var column = columns[i];
                    CategoricalTransform.OutputKind kind = columns[i].OutputKind;
                    switch (kind)
                    {
                    default:
                        throw _host.ExceptUserArg(nameof(column.OutputKind));

                    case CategoricalTransform.OutputKind.Key:
                        // The hashed key column is the final result; no vectorization stage needed.
                        continue;

                    case CategoricalTransform.OutputKind.Bin:
                        if (column.HashInfo.InvertHash != 0)
                        {
                            ch.Warning("Invert hashing is being used with binary encoding.");
                        }
                        // Vectorization consumes the hash stage's output column in place.
                        binaryCols.Add((column.HashInfo.Output, column.HashInfo.Output));
                        break;

                    case CategoricalTransform.OutputKind.Ind:
                        cols.Add((column.HashInfo.Output, column.HashInfo.Output, false));
                        break;

                    case CategoricalTransform.OutputKind.Bag:
                        cols.Add((column.HashInfo.Output, column.HashInfo.Output, true));
                        break;
                    }
                }

                // Build the optional vectorization stages and combine whichever ones exist.
                IEstimator <ITransformer> toBinVector = null;
                IEstimator <ITransformer> toVector    = null;
                if (binaryCols.Count > 0)
                {
                    toBinVector = new KeyToBinaryVectorMappingEstimator(_host, binaryCols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.input, x.output)).ToArray());
                }
                if (cols.Count > 0)
                {
                    toVector = new KeyToVectorMappingEstimator(_host, cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.input, x.output, x.bag)).ToArray());
                }

                if (toBinVector != null && toVector != null)
                {
                    _toSomething = toVector.Append(toBinVector);
                }
                else
                {
                    // At most one stage exists here (possibly none, when every column was OutputKind.Key).
                    _toSomething = toBinVector ?? toVector;
                }
            }
        }
        internal CategoricalHashTransform(HashingEstimator hash, IEstimator <ITransformer> keyToVector, IDataView input)
        {
            // Fit the hashing stage followed by the key-to-vector stage as a single chain.
            _transformer = hash.Append(keyToVector).Fit(input);
        }
Exemple #16
0
        internal OneHotHashEncoding(HashingEstimator hash, IEstimator <ITransformer> keyToVector, IDataView input)
        {
            // Fit the hashing stage followed by the key-to-vector stage as a single chain.
            _transformer = hash.Append(keyToVector).Fit(input);
        }
Exemple #17
0
        /// <summary>
        /// Adds a temporary split column to <paramref name="data"/> and returns its name.
        /// When <paramref name="samplingKeyColumn"/> is null a random-number column is generated;
        /// otherwise the user column is hashed, copied, or normalized into the temp column so it
        /// can drive a <see cref="RangeFilter"/>.
        /// </summary>
        internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int?seed = null, bool fallbackInEnvSeed = false)
        {
            Contracts.CheckValue(env, nameof(env));
            Contracts.CheckValueOrNull(samplingKeyColumn);

            var splitColumnName = data.Schema.GetTempColumnName("SplitColumn");
            int?seedToUse;

            // Seed precedence: explicit argument, then (optionally) the environment's seed, else none.
            if (seed.HasValue)
            {
                seedToUse = seed.Value;
            }
            else if (fallbackInEnvSeed)
            {
                ISeededEnvironment seededEnv = (ISeededEnvironment)env;
                seedToUse = seededEnv.Seed;
            }
            else
            {
                seedToUse = null;
            }

            // We need to handle two cases: if samplingKeyColumn is not provided, we generate a random number.
            if (samplingKeyColumn == null)
            {
                data = new GenerateNumberTransform(env, data, splitColumnName, (uint?)seedToUse);
            }
            else
            {
                // If samplingKeyColumn was provided we will make a new column based on it, but using a temporary
                // name, as it might be dropped elsewhere in the code

                if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int samplingColIndex))
                {
                    throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);
                }

                var type = data.Schema[samplingColIndex].Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
                {
                    var hashInputColumnName = samplingKeyColumn;
                    // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
                    var itemType = type.GetItemType();
                    if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
                    {
                        // Convert to Int64 into the temp column first, then hash that instead.
                        data = new TypeConvertingTransformer(env, splitColumnName, DataKind.Int64, samplingKeyColumn).Transform(data);
                        hashInputColumnName = splitColumnName;
                    }

                    var columnOptions =
                        seedToUse.HasValue ?
                        new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, (uint)seedToUse.Value, combine: true) :
                        new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, combine: true);
                    data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
                }
                else
                {
                    // Range-filter friendly: non-float columns are copied as-is, floats are min-max normalized.
                    if (type != NumberDataViewType.Single && type != NumberDataViewType.Double)
                    {
                        data = new ColumnCopyingEstimator(env, (splitColumnName, samplingKeyColumn)).Fit(data).Transform(data);
                    }
                    else
                    {
                        data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(splitColumnName, samplingKeyColumn, ensureZeroUntouched: false)).Fit(data).Transform(data);
                    }
                }
            }

            return(splitColumnName);
        }
        /// <summary>
        /// Resolves the column used to stratify cross-validation splits, creating or hashing one
        /// if needed, and returns its name. <paramref name="output"/> receives the (possibly
        /// augmented) data view.
        /// </summary>
        private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
        {
            // The stratification column and/or group column, if they exist at all, must be present at this point.
            var schema = input.Schema;

            output = input;
            // If no stratification column was specified, but we have a group column of type Single, Double or
            // Key (contiguous) use it.
            string stratificationColumn = null;

            if (!string.IsNullOrWhiteSpace(Args.StratificationColumn))
            {
                stratificationColumn = Args.StratificationColumn;
            }
            else
            {
                string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
                int    index;
                if (group != null && schema.TryGetColumnIndex(group, out index))
                {
                    // Check if group column key type with known cardinality.
                    var type = schema[index].Type;
                    if (type.GetKeyCount() > 0)
                    {
                        stratificationColumn = group;
                    }
                }
            }

            if (string.IsNullOrEmpty(stratificationColumn))
            {
                // Nothing usable found: generate a fresh random-number column under an unused name.
                stratificationColumn = "StratificationColumn";
                int tmp;
                int inc = 0;
                while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
                {
                    stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc);
                }
                var keyGenArgs = new GenerateNumberTransform.Options();
                var col        = new GenerateNumberTransform.Column();
                col.Name           = stratificationColumn;
                keyGenArgs.Columns = new[] { col };
                output             = new GenerateNumberTransform(Host, keyGenArgs, input);
            }
            else
            {
                int col;
                if (!input.Schema.TryGetColumnIndex(stratificationColumn, out col))
                {
                    throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
                }
                var type = input.Schema[col].Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(ch, type))
                {
                    // The chosen column's type cannot drive a RangeFilter directly: hash it (30 bits)
                    // into a new column whose name does not collide with an existing one.
                    ch.Info("Hashing the stratification column");
                    var origStratCol = stratificationColumn;
                    int tmp;
                    int inc = 0;
                    while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
                    {
                        stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
                    }
                    output = new HashingEstimator(Host, origStratCol, stratificationColumn, 30).Fit(input).Transform(input);
                }
            }

            return(stratificationColumn);
        }