/// <summary>
/// Applies a <see cref="ConvertingTransform"/> built from <paramref name="cvt"/> to <paramref name="viewTrain"/>.
/// </summary>
/// <param name="cvt">Columns to convert; may be null or empty, in which case nothing is applied.</param>
/// <param name="viewTrain">The data view to transform.</param>
/// <param name="env">Host environment used to construct the transform.</param>
/// <returns>The converted view, or <paramref name="viewTrain"/> itself when there is nothing to convert.</returns>
private static IDataView ApplyConvert(List<ConvertingTransform.ColumnInfo> cvt, IDataView viewTrain, IHostEnvironment env)
{
    // Contract assertions: cvt is optional, the view and environment are required.
    Contracts.AssertValueOrNull(cvt);
    Contracts.AssertValue(viewTrain);
    Contracts.AssertValue(env);

    // Guard clause: no conversion columns requested, so hand the input view back unchanged.
    if (Utils.Size(cvt) == 0)
    {
        return viewTrain;
    }

    var columns = cvt.ToArray();
    var converter = new ConvertingTransform(env, columns);
    return converter.Transform(viewTrain);
}
        /// <summary>
        /// Entry point that makes the label column of <c>input.Data</c> R4 when it is numeric but of another numeric type.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">Entry-point input carrying the data view (<c>Data</c>) and the label column name (<c>LabelColumn</c>).</param>
        /// <returns>
        /// The transformed data together with a <see cref="TransformModel"/> that replays the transformation on <c>input.Data</c>.
        /// </returns>
        /// <remarks>
        /// Throws (via the registered host) when the label column does not exist in the input schema.
        /// Labels that are already R4, or that are not numeric at all, are not converted here.
        /// </remarks>
        public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironment env, RegressionLabelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("PrepareRegressionLabel");

            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            if (!input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out int labelCol))
            {
                throw host.Except($"Column '{input.LabelColumn}' not found.");
            }
            var labelType = input.Data.Schema.GetColumnType(labelCol);

            // Already R4, or not a numeric type: no conversion is applied and the data is routed through
            // NopTransform.CreateIfNeeded. NOTE(review): non-numeric labels raise no error at this point.
            if (labelType == NumberType.R4 || !labelType.IsNumber)
            {
                var nop = NopTransform.CreateIfNeeded(env, input.Data);
                return new CommonOutputs.TransformOutput
                {
                    Model = new TransformModel(env, nop, input.Data),
                    OutputData = nop
                };
            }

            // Numeric but not R4: convert the label column to R4, keeping the same column name.
            var labelToR4 = new ConvertingTransform.ColumnInfo(input.LabelColumn, input.LabelColumn, DataKind.R4);
            var xf = new ConvertingTransform(host, labelToR4).Transform(input.Data);

            return new CommonOutputs.TransformOutput
            {
                Model = new TransformModel(env, xf, input.Data),
                OutputData = xf
            };
        }
Beispiel #3
0
        /// <summary>
        /// Builds the missing-value handling transform chain for the columns listed in <paramref name="args"/>.
        /// </summary>
        /// <remarks>
        /// Per-column settings (<c>ConcatIndicator</c>, <c>Kind</c>, <c>ImputeBySlot</c>) take precedence over the
        /// argument-level defaults (<c>Concat</c>, <c>ReplaceWith</c>, <c>ImputeBySlot</c>).
        /// <para>
        /// For a column without an indicator, an NAReplaceTransform column maps <c>Source</c> directly to <c>Name</c>.
        /// </para>
        /// <para>
        /// For a column with an indicator, the chain is:
        /// <list type="number">
        /// <item>NAIndicatorTransform writes <c>Source</c> to a temporary "IsMissing" column.</item>
        /// <item>If needed, ConvertingTransform converts that column to the source's item type.</item>
        /// <item>NAReplaceTransform writes <c>Source</c> to a temporary "Replace" column.</item>
        /// <item>ConcatTransform joins the two into <c>Name</c>.</item>
        /// <item>The temporary columns are dropped.</item>
        /// </list>
        /// </para>
        /// <para>
        /// Throws (via the host) when an indicator column's source does not exist, or when a boolean indicator
        /// cannot be converted to the source column's item type.
        /// </para>
        /// </remarks>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transform arguments; must specify at least one column.</param>
        /// <param name="input">The input data view; must not be null.</param>
        /// <returns>The last transform in the constructed chain.</returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            // NOTE(review): the host name "Categorical" looks copied from another transform; kept as-is because
            // it may surface in messages and logs.
            var h = env.Register("Categorical");

            h.CheckValue(args, nameof(args));
            h.CheckValue(input, nameof(input));
            h.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column));

            var replaceCols = new List<NAReplaceTransform.ColumnInfo>();
            var naIndicatorCols = new List<NAIndicatorTransform.Column>();
            var naConvCols = new List<ConvertingTransform.ColumnInfo>();
            var concatCols = new List<ConcatTransform.TaggedColumn>();
            var dropCols = new List<string>();

            // One temporary name per configured column, obtained from the input schema
            // (presumably chosen so they do not clash with existing columns).
            var tmpIsMissingColNames = input.Schema.GetTempColumnNames(args.Column.Length, "IsMissing");
            var tmpReplaceColNames = input.Schema.GetTempColumnNames(args.Column.Length, "Replace");

            for (int i = 0; i < args.Column.Length; i++)
            {
                var column = args.Column[i];

                // Per-column settings override the argument-level defaults.
                var addInd = column.ConcatIndicator ?? args.Concat;
                var mode = (NAReplaceTransform.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith);
                var imputeBySlot = column.ImputeBySlot ?? args.ImputeBySlot;

                if (!addInd)
                {
                    // No indicator requested: replace missing values straight into the output column.
                    replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, column.Name, mode, imputeBySlot));
                    continue;
                }

                // Check that the indicator column has a type that can be converted to the NAReplaceTransform output type,
                // so that they can be concatenated.
                if (!input.Schema.TryGetColumnIndex(column.Source, out int inputCol))
                {
                    throw h.Except("Column '{0}' does not exist", column.Source);
                }
                var replaceType = input.Schema.GetColumnType(inputCol);
                if (!Runtime.Data.Conversion.Conversions.Instance.TryGetStandardConversion(BoolType.Instance, replaceType.ItemType, out Delegate _, out bool identity))
                {
                    throw h.Except("Cannot concatenate indicator column of type '{0}' to input column of type '{1}'",
                                   BoolType.Instance, replaceType.ItemType);
                }

                // Temporary names for the NAIndicatorTransform and NAReplaceTransform output columns.
                var tmpIsMissingColName = tmpIsMissingColNames[i];
                var tmpReplacementColName = tmpReplaceColNames[i];

                // Add an NAIndicatorTransform column.
                naIndicatorCols.Add(new NAIndicatorTransform.Column()
                {
                    Name = tmpIsMissingColName,
                    Source = column.Source
                });

                // Add a ConvertingTransform column only when bool -> item type is not an identity conversion.
                if (!identity)
                {
                    naConvCols.Add(new ConvertingTransform.ColumnInfo(tmpIsMissingColName, tmpIsMissingColName, replaceType.ItemType.RawKind));
                }

                // Add the NAReplaceTransform column, writing into the temporary replacement column.
                replaceCols.Add(new NAReplaceTransform.ColumnInfo(column.Source, tmpReplacementColName, mode, imputeBySlot));

                // Add the ConcatTransform column.
                // In each pair, the value is the column to read. The key appears to be a tag/alias for that
                // source (NOTE(review): likely used when naming the concatenated slots — confirm).
                if (replaceType.IsVector)
                {
                    concatCols.Add(new ConcatTransform.TaggedColumn()
                    {
                        Name = column.Name,
                        Source = new[]
                        {
                            new KeyValuePair<string, string>(tmpReplacementColName, tmpReplacementColName),
                            new KeyValuePair<string, string>("IsMissing", tmpIsMissingColName)
                        }
                    });
                }
                else
                {
                    concatCols.Add(new ConcatTransform.TaggedColumn()
                    {
                        Name = column.Name,
                        Source = new[]
                        {
                            new KeyValuePair<string, string>(column.Source, tmpReplacementColName),
                            new KeyValuePair<string, string>(string.Format("IsMissing.{0}", column.Source), tmpIsMissingColName),
                        }
                    });
                }

                // Add the temp columns to the list of columns to drop at the end.
                dropCols.Add(tmpIsMissingColName);
                dropCols.Add(tmpReplacementColName);
            }

            // The stages before NAReplaceTransform are threaded as a plain IDataView. An `as IDataTransform` cast on
            // the ConvertingTransform result could yield null, which would then silently fall back to the raw input
            // and drop the indicator columns.
            IDataView view = input;

            // Create the indicator columns.
            if (naIndicatorCols.Count > 0)
            {
                view = NAIndicatorTransform.Create(h, new NAIndicatorTransform.Arguments()
                {
                    Column = naIndicatorCols.ToArray()
                }, view);
            }

            // Convert the indicator columns to the correct type so that they can be concatenated to the NAReplace outputs.
            if (naConvCols.Count > 0)
            {
                // Conversion columns are only ever added together with indicator columns.
                h.Assert(naIndicatorCols.Count > 0);
                // REVIEW: all of this needs to be converted to an estimator chain once drop-columns support is done.
                view = new ConvertingTransform(h, naConvCols.ToArray()).Transform(view);
            }

            // Create the NAReplace transform.
            IDataTransform output = NAReplaceTransform.Create(h, view, replaceCols.ToArray());

            // Concat the NAReplaceTransform output and the NAIndicatorTransform output.
            if (naIndicatorCols.Count > 0)
            {
                output = ConcatTransform.Create(h, new ConcatTransform.TaggedArguments()
                {
                    Column = concatCols.ToArray()
                }, output);
            }

            // Finally, drop the temporary indicator and replacement columns.
            if (dropCols.Count > 0)
            {
                output = SelectColumnsTransform.CreateDrop(h, output, dropCols.ToArray());
            }

            return output;
        }