/// <summary>
/// Entry point that builds an <see cref="NgramCountingTransformer"/> over <c>input.Data</c>
/// and packages it as a <see cref="CommonOutputs.TransformOutput"/>.
/// </summary>
/// <param name="env">Host environment used to validate the arguments and create the host.</param>
/// <param name="input">Transform arguments; <c>input.Data</c> is the data view being transformed.</param>
/// <returns>
/// An output whose <c>Model</c> wraps the created transform and whose <c>OutputData</c> is the transform itself.
/// </returns>
public static CommonOutputs.TransformOutput NGramTransform(IHostEnvironment env, NgramCountingTransformer.Arguments input)
        {
            var host = EntryPointUtils.CheckArgsAndCreateHost(env, "NGramTransform", input);
            var transform = NgramCountingTransformer.Create(host, input, input.Data);

            // Populate the output member by member; the model is built from the same
            // input data view the transform was created on.
            var output = new CommonOutputs.TransformOutput();
            output.Model = new TransformModel(host, transform, input.Data);
            output.OutputData = transform;
            return output;
        }
// Example #2
// 0
        /// <summary>
        /// Builds an n-gram counting transform over <paramref name="input"/>.
        /// Columns whose source column exists in the input schema with a text item type are first
        /// converted to keys by a <see cref="ValueToKeyMappingTransformer"/> (optionally followed by a
        /// <see cref="MissingValueDroppingTransformer"/>); all other columns are handed to the
        /// n-gram transform unchanged.
        /// </summary>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="args">Transform arguments; at least one column must be specified.</param>
        /// <param name="input">Input data view; checked for null.</param>
        /// <param name="termLoaderArgs">
        /// Optional term-loader settings. When supplied, they configure the term mapping (with
        /// <c>MaxNumTerms</c> set to <see cref="int.MaxValue"/>), and <c>DropUnknowns</c> controls whether
        /// a missing-value dropping step is appended. When null, the term limit is taken from
        /// <c>args.MaxNumTerms</c> or the n-gram transform's default.
        /// </param>
        /// <returns>
        /// A <see cref="NgramCountingTransformer"/> applied on top of any term-mapping steps that were added.
        /// </returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input,
                                            TermLoaderArguments termLoaderArgs = null)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register(LoaderSignature);

            h.CheckValue(args, nameof(args));
            h.CheckValue(input, nameof(input));
            h.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column), "Columns must be specified");

            IDataView view = input;
            var termCols = new List<Column>();
            var isTermCol = new bool[args.Column.Length];

            // Classify each requested column: text-typed sources need the term (value-to-key) mapping first.
            for (int i = 0; i < args.Column.Length; i++)
            {
                var col = args.Column[i];

                h.CheckNonWhiteSpace(col.Name, nameof(col.Name));
                h.CheckNonWhiteSpace(col.Source, nameof(col.Source));
                int colId;
                if (input.Schema.TryGetColumnIndex(col.Source, out colId) &&
                    input.Schema.GetColumnType(colId).ItemType.IsText)
                {
                    termCols.Add(col);
                    isTermCol[i] = true;
                }
            }

            // If the column types of args.column are text, apply term transform to convert them to keys.
            // Otherwise, skip term transform and apply ngram transform directly.
            // This logic allows NgramExtractorTransform to handle both text and key input columns.
            // Note: ngram transform handles the validation of the types natively (in case the types
            // of args.column are not text nor keys).
            if (termCols.Count > 0)
            {
                ValueToKeyMappingTransformer.Arguments termArgs = null;
                MissingValueDroppingTransformer.Arguments naDropArgs = null;
                if (termLoaderArgs != null)
                {
                    // An explicit term loader was supplied: no cap on the number of terms.
                    termArgs =
                        new ValueToKeyMappingTransformer.Arguments()
                        {
                            MaxNumTerms = int.MaxValue,
                            Terms = termLoaderArgs.Terms,
                            Term = termLoaderArgs.Term,
                            DataFile = termLoaderArgs.DataFile,
                            Loader = termLoaderArgs.Loader,
                            TermsColumn = termLoaderArgs.TermsColumn,
                            Sort = termLoaderArgs.Sort,
                            Column = new ValueToKeyMappingTransformer.Column[termCols.Count]
                        };

                    if (termLoaderArgs.DropUnknowns)
                    {
                        naDropArgs = new MissingValueDroppingTransformer.Arguments
                        {
                            Column = new MissingValueDroppingTransformer.Column[termCols.Count]
                        };
                    }
                }
                else
                {
                    termArgs =
                        new ValueToKeyMappingTransformer.Arguments()
                        {
                            MaxNumTerms = Utils.Size(args.MaxNumTerms) > 0 ? args.MaxNumTerms[0] : NgramCountingTransformer.Arguments.DefaultMaxTerms,
                            Column = new ValueToKeyMappingTransformer.Column[termCols.Count]
                        };
                }

                for (int iinfo = 0; iinfo < termCols.Count; iinfo++)
                {
                    var column = termCols[iinfo];
                    termArgs.Column[iinfo] =
                        new ValueToKeyMappingTransformer.Column()
                        {
                            Name = column.Name,
                            Source = column.Source,
                            MaxNumTerms = Utils.Size(column.MaxNumTerms) > 0 ? column.MaxNumTerms[0] : default(int?)
                        };

                    if (naDropArgs != null)
                    {
                        // The drop step reads and writes the term output column (column.Name) in place.
                        naDropArgs.Column[iinfo] = new MissingValueDroppingTransformer.Column
                        {
                            Name = column.Name,
                            Source = column.Name
                        };
                    }
                }

                view = ValueToKeyMappingTransformer.Create(h, termArgs, view);
                if (naDropArgs != null)
                {
                    view = new MissingValueDroppingTransformer(h, naDropArgs, view);
                }
            }

            var ngramArgs =
                new NgramCountingTransformer.Arguments()
                {
                    MaxNumTerms = args.MaxNumTerms,
                    NgramLength = args.NgramLength,
                    SkipLength = args.SkipLength,
                    AllLengths = args.AllLengths,
                    Weighting = args.Weighting,
                    Column = new NgramCountingTransformer.Column[args.Column.Length]
                };

            for (int iinfo = 0; iinfo < args.Column.Length; iinfo++)
            {
                var column = args.Column[iinfo];
                ngramArgs.Column[iinfo] =
                    new NgramCountingTransformer.Column()
                    {
                        Name = column.Name,
                        // Term-mapped columns were written to column.Name above, so read the keys from
                        // there; other columns are consumed straight from their original source.
                        Source = isTermCol[iinfo] ? column.Name : column.Source,
                        AllLengths = column.AllLengths,
                        MaxNumTerms = column.MaxNumTerms,
                        NgramLength = column.NgramLength,
                        SkipLength = column.SkipLength,
                        Weighting = column.Weighting
                    };
            }

            return new NgramCountingTransformer(h, ngramArgs, view);
        }