예제 #1
0
        /// <summary>
        /// Builds the categorical hash transform. The requested columns are first run through a
        /// <c>HashTransform</c> configured from <paramref name="args"/>, and that hashed view is then
        /// handed to <c>CategoricalTransform.CreateTransformCore</c> to produce the returned transform.
        /// </summary>
        /// <param name="env">Host environment; must not be null. A host named "CategoricalHash" is registered on it.</param>
        /// <param name="args">Transform arguments; must specify at least one column and a hash bit count
        /// between 1 and <c>NumBitsLim - 1</c> inclusive.</param>
        /// <param name="input">The data view to transform; must not be null.</param>
        /// <returns>The transform produced by <c>CategoricalTransform.CreateTransformCore</c>.</returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CategoricalHash");

            using (var channel = host.Start("CategoricalHash"))
            {
                // Argument validation is reported through the registered host.
                host.CheckValue(args, nameof(args));
                host.CheckValue(input, nameof(input));
                host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column), "Columns must be specified");
                if (args.HashBits < 1 || args.HashBits >= NumBitsLim)
                {
                    throw host.ExceptUserArg(nameof(args.HashBits), "Number of bits must be between 1 and {0}", NumBitsLim - 1);
                }

                // Mirror every user column into an equivalent hash column, carrying over
                // the per-column overrides as they are.
                int columnCount = args.Column.Length;
                var hashColumns = new HashTransform.Column[columnCount];
                for (int index = 0; index < columnCount; index++)
                {
                    var userColumn = args.Column[index];
                    if (!userColumn.TrySanitize())
                    {
                        throw host.ExceptUserArg(nameof(Column.Name));
                    }

                    // After a successful sanitize both names are expected to be usable.
                    host.Assert(!string.IsNullOrWhiteSpace(userColumn.Name));
                    host.Assert(!string.IsNullOrWhiteSpace(userColumn.Source));

                    hashColumns[index] = new HashTransform.Column
                    {
                        HashBits   = userColumn.HashBits,
                        Seed       = userColumn.Seed,
                        Ordered    = userColumn.Ordered,
                        Name       = userColumn.Name,
                        Source     = userColumn.Source,
                        InvertHash = userColumn.InvertHash,
                    };
                }

                // Transform-wide hash settings plus the per-column entries built above.
                var hashArgs = new HashTransform.Arguments
                {
                    HashBits   = args.HashBits,
                    Seed       = args.Seed,
                    Ordered    = args.Ordered,
                    InvertHash = args.InvertHash,
                    Column     = hashColumns
                };

                var outputKinds = args.Column.Select(c => c.OutputKind).ToList();
                var hashedView  = new HashTransform(host, hashArgs, input);

                return CategoricalTransform.CreateTransformCore(
                    args.OutputKind,
                    args.Column,
                    outputKinds,
                    hashedView,
                    host,
                    env,
                    args);
            }
        }
        /// <summary>
        /// Determines the name of the column to stratify on and, when a new column has to be
        /// produced for that purpose, sets <paramref name="output"/> to a view that contains it.
        /// </summary>
        /// <remarks>
        /// Resolution order:
        /// 1. <c>Args.StratificationColumn</c> when it is not blank.
        /// 2. Otherwise the group column, if one is matched in the schema and its type reports a
        ///    key count greater than zero.
        /// 3. Otherwise a freshly generated column (via <c>GenerateNumberTransform</c>).
        /// A resolved column whose type is rejected by <c>RangeFilter.IsValidRangeFilterColumnType</c>
        /// is hashed into a new column (via <c>HashTransform</c>), and the new column's name is returned.
        /// </remarks>
        /// <param name="ch">Channel used for messages and for raising user-argument errors.</param>
        /// <param name="input">The data view whose schema is inspected.</param>
        /// <param name="output">Set to <paramref name="input"/>, or to a transform over it when a column is added.</param>
        /// <returns>The name of the column to stratify on.</returns>
        private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
        {
            // Any stratification and/or group column is expected to already be present in this schema.
            var schema = input.Schema;

            // Unless a column has to be added below, the data is passed through unchanged.
            output = input;

            string splitColumn = null;
            if (string.IsNullOrWhiteSpace(Args.StratificationColumn))
            {
                // Nothing was specified explicitly: fall back to the group column, but only
                // when its type reports a positive key count.
                string groupName = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
                int groupIndex;
                if (groupName != null
                    && schema.TryGetColumnIndex(groupName, out groupIndex)
                    && schema.GetColumnType(groupIndex).KeyCount > 0)
                {
                    splitColumn = groupName;
                }
            }
            else
            {
                splitColumn = Args.StratificationColumn;
            }

            if (string.IsNullOrEmpty(splitColumn))
            {
                // No usable column: generate one, probing for a name that is not in the schema yet.
                splitColumn = "StratificationColumn";
                int ignored;
                for (int suffix = 1; input.Schema.TryGetColumnIndex(splitColumn, out ignored); suffix++)
                {
                    splitColumn = string.Format("StratificationColumn_{0:000}", suffix);
                }

                var generateArgs = new GenerateNumberTransform.Arguments
                {
                    Column = new[] { new GenerateNumberTransform.Column { Name = splitColumn } }
                };
                output = new GenerateNumberTransform(Host, generateArgs, input);
                return splitColumn;
            }

            int columnIndex;
            if (!input.Schema.TryGetColumnIndex(splitColumn, out columnIndex))
            {
                throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", splitColumn);
            }

            var columnType = input.Schema.GetColumnType(columnIndex);
            if (RangeFilter.IsValidRangeFilterColumnType(ch, columnType))
            {
                // The column can be used as it is.
                return splitColumn;
            }

            // The column type is not accepted by the range filter, so hash it into a new
            // column whose name is derived from the original one.
            ch.Info("Hashing the stratification column");
            string sourceName = splitColumn;
            int unused;
            for (int attempt = 1; input.Schema.TryGetColumnIndex(splitColumn, out unused); attempt++)
            {
                splitColumn = string.Format("{0}_{1:000}", sourceName, attempt);
            }

            var hashArgs = new HashTransform.Arguments
            {
                Column   = new[] { new HashTransform.Column { Source = sourceName, Name = splitColumn } },
                HashBits = 30
            };
            output = new HashTransform(Host, hashArgs, input);

            return splitColumn;
        }
예제 #3
0
        /// <summary>
        /// Builds the n-gram hash extraction pipeline over <paramref name="input"/>. Every source of
        /// every requested column is (optionally) mapped through a <c>TermTransform</c>, then hashed by
        /// a <c>HashTransform</c> into a temporary column; an <c>NgramHashTransform</c> combines the
        /// temporary columns into the requested output columns, and the temporaries are dropped at the end.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Extractor arguments; must specify at least one column, each with a
        /// non-blank name and at least one non-blank source.</param>
        /// <param name="input">The data view to transform; must not be null.</param>
        /// <param name="termLoaderArgs">When not null, a <c>TermTransform</c> (and, if
        /// <c>DropUnknowns</c> is set, an <c>NADropTransform</c>) is applied before hashing.</param>
        /// <returns>A <c>DropColumnsTransform</c> wrapping the pipeline, with the temporary columns removed.</returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input,
                                            TermLoaderArguments termLoaderArgs = null)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(LoaderSignature);

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column), "Columns must be specified");

            bool useTerms    = termLoaderArgs != null;
            int  columnCount = args.Column.Length;

            // The term columns are only collected when a term loader was supplied.
            List<TermTransform.Column> termColumns = useTerms ? new List<TermTransform.Column>() : null;
            var hashColumns  = new List<HashTransform.Column>();
            var ngramColumns = new NgramHashTransform.Column[columnCount];

            // The extractor maps many sources onto one output column. Intermediate results are
            // written to temporary columns so that an output name never steps over a source name.
            var tempNames = new string[columnCount][];

            for (int c = 0; c < columnCount; c++)
            {
                var userColumn = args.Column[c];
                var sources    = userColumn.Source;
                host.CheckUserArg(!string.IsNullOrWhiteSpace(userColumn.Name), nameof(userColumn.Name));
                host.CheckUserArg(Utils.Size(sources) > 0 &&
                                  sources.All(name => !string.IsNullOrWhiteSpace(name)), nameof(userColumn.Source));

                var columnTemps = new string[sources.Length];
                tempNames[c] = columnTemps;

                for (int s = 0; s < sources.Length; s++)
                {
                    var source   = sources[s];
                    var tempName = input.Schema.GetTempColumnName(source);
                    columnTemps[s] = tempName;

                    if (useTerms)
                    {
                        termColumns.Add(new TermTransform.Column { Name = tempName, Source = source });
                    }

                    // With a term loader the hash reads (and overwrites) the temporary term column;
                    // without one it reads the original source directly.
                    // NOTE(review): the per-column HashBits here is 30 while the HashTransform
                    // arguments further down specify 31 -- confirm which value is intended.
                    hashColumns.Add(new HashTransform.Column
                    {
                        Name       = tempName,
                        Source     = useTerms ? tempName : source,
                        HashBits   = 30,
                        Seed       = userColumn.Seed,
                        Ordered    = false,
                        InvertHash = userColumn.InvertHash
                    });
                }

                ngramColumns[c] = new NgramHashTransform.Column
                {
                    Name           = userColumn.Name,
                    Source         = columnTemps,
                    AllLengths     = userColumn.AllLengths,
                    HashBits       = userColumn.HashBits,
                    NgramLength    = userColumn.NgramLength,
                    RehashUnigrams = false,
                    Seed           = userColumn.Seed,
                    SkipLength     = userColumn.SkipLength,
                    Ordered        = userColumn.Ordered,
                    InvertHash     = userColumn.InvertHash,
                    // REVIEW: internal workaround so that descriptions can surface the *original*
                    // source names rather than the temporary ones. The alternative would be
                    // something more elaborate with metadata, which may not be better.
                    FriendlyNames = userColumn.FriendlyNames
                };
            }

            IDataView view = input;

            if (useTerms)
            {
                // One term column is added per hash column in the loop above.
                host.Assert(Utils.Size(termColumns) == hashColumns.Count);

                var termArgs = new TermTransform.Arguments()
                {
                    MaxNumTerms = int.MaxValue,
                    Terms       = termLoaderArgs.Terms,
                    Term        = termLoaderArgs.Term,
                    DataFile    = termLoaderArgs.DataFile,
                    Loader      = termLoaderArgs.Loader,
                    TermsColumn = termLoaderArgs.TermsColumn,
                    Sort        = termLoaderArgs.Sort,
                    Column      = termColumns.ToArray()
                };
                view = new TermTransform(host, termArgs, view);

                if (termLoaderArgs.DropUnknowns)
                {
                    // Apply the NA-drop in place on every temporary term column.
                    var naDropArgs = new NADropTransform.Arguments
                    {
                        Column = termColumns
                                 .Select(term => new NADropTransform.Column { Name = term.Name, Source = term.Name })
                                 .ToArray()
                    };
                    view = new NADropTransform(host, naDropArgs, view);
                }
            }

            // Hash all temporary columns in a single transform.
            var hashArgs = new HashTransform.Arguments
            {
                HashBits   = 31,
                Seed       = args.Seed,
                Ordered    = false,
                Column     = hashColumns.ToArray(),
                InvertHash = args.InvertHash
            };
            view = new HashTransform(host, hashArgs, view);

            // Combine the hashed temporaries into the requested n-gram hash columns.
            var ngramArgs = new NgramHashTransform.Arguments
            {
                AllLengths     = args.AllLengths,
                HashBits       = args.HashBits,
                NgramLength    = args.NgramLength,
                SkipLength     = args.SkipLength,
                RehashUnigrams = false,
                Ordered        = args.Ordered,
                Seed           = args.Seed,
                Column         = ngramColumns,
                InvertHash     = args.InvertHash
            };
            view = new NgramHashTransform(host, ngramArgs, view);

            // Remove every temporary column from the final view.
            var dropArgs = new DropColumnsTransform.Arguments()
            {
                Column = tempNames.SelectMany(names => names).ToArray()
            };
            return new DropColumnsTransform(host, dropArgs, view);
        }