Ejemplo n.º 1
0
        /// <summary>
        /// Entry point that makes sure the label column of <paramref name="input"/> is of type R4,
        /// converting it (under the same column name) when it is some other numeric type.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="input">The input data together with the name of its label column.</param>
        /// <returns>
        /// The resulting data view and a transform model that reproduces it from <c>input.Data</c>.
        /// When the label is already R4, or is not a numeric type at all, a no-op transform is returned.
        /// </returns>
        /// <remarks>Throws (via <c>host.Except</c>) when the label column does not exist.</remarks>
        public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironment env, RegressionLabelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("PrepareRegressionLabel");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var labelCol = input.Data.Schema.GetColumnOrNull(input.LabelColumn);
            if (!labelCol.HasValue)
            {
                throw host.Except($"Column '{input.LabelColumn}' not found.");
            }
            var labelType = labelCol.Value.Type;

            // Nothing to convert: the label is already R4, or it is non-numeric and is passed through unchanged.
            if (labelType == NumberType.R4 || !(labelType is NumberType))
            {
                var nop = NopTransform.CreateIfNeeded(env, input.Data);
                return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, nop, input.Data), OutputData = nop };
            }

            // Convert the numeric label to R4, replacing the column under the same name.
            var xf = new TypeConvertingTransformer(host, new TypeConvertingTransformer.ColumnInfo(input.LabelColumn, DataKind.R4, input.LabelColumn)).Transform(input.Data);
            return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf };
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Entry point that coerces a numeric label column to <see cref="NumberDataViewType.Single"/>.
        /// </summary>
        /// <param name="env">Host environment; validated to be non-null.</param>
        /// <param name="input">Data plus the name of the label column to coerce.</param>
        /// <returns>
        /// A transform output whose data has the label converted to Single under its original name.
        /// Labels that are already Single, or are not numeric, are left alone through a no-op transform.
        /// </returns>
        public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironment env, RegressionLabelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("PrepareRegressionLabel");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var source = input.Data;
            var labelName = input.LabelColumn;

            var maybeLabel = source.Schema.GetColumnOrNull(labelName);
            if (!maybeLabel.HasValue)
            {
                throw host.Except($"Column '{input.LabelColumn}' not found.");
            }

            var labelType = maybeLabel.Value.Type;
            bool isNumeric = labelType is NumberDataViewType;
            bool needsConversion = isNumeric && !(labelType == NumberDataViewType.Single);

            if (needsConversion)
            {
                var options = new TypeConvertingEstimator.ColumnOptions(labelName, DataKind.Single, labelName);
                var converted = new TypeConvertingTransformer(host, options).Transform(source);
                return new CommonOutputs.TransformOutput
                {
                    Model = new TransformModelImpl(env, converted, source),
                    OutputData = converted
                };
            }

            // Already Single (or not numeric at all): pass the data through unchanged.
            var passThrough = NopTransform.CreateIfNeeded(env, source);
            return new CommonOutputs.TransformOutput
            {
                Model = new TransformModelImpl(env, passThrough, source),
                OutputData = passThrough
            };
        }
Ejemplo n.º 3
0
 /// <summary>
 /// Applies a <see cref="TypeConvertingTransformer"/> built from <paramref name="cvt"/> to <paramref name="viewTrain"/>.
 /// </summary>
 /// <param name="cvt">Conversions to apply; may be null or empty.</param>
 /// <param name="viewTrain">The view to convert.</param>
 /// <param name="env">Environment used to construct the transformer.</param>
 /// <returns><paramref name="viewTrain"/> itself when there is nothing to convert, otherwise the converted view.</returns>
 private static IDataView ApplyConvert(List<TypeConvertingTransformer.ColumnInfo> cvt, IDataView viewTrain, IHostEnvironment env)
 {
     Contracts.AssertValueOrNull(cvt);
     Contracts.AssertValue(viewTrain);
     Contracts.AssertValue(env);

     // cvt may be null here; Utils.Size is relied upon to report 0 in that case.
     if (Utils.Size(cvt) == 0)
     {
         return viewTrain;
     }

     var converter = new TypeConvertingTransformer(env, cvt.ToArray());
     return converter.Transform(viewTrain);
 }
        /// <summary>
        /// Ensures <paramref name="data"/> has a column suitable for stratified splitting and returns its name.
        /// </summary>
        /// <remarks>
        /// <list type="bullet">
        /// <item>No <paramref name="stratificationColumn"/>: a random-number column is generated under a temporary name.</item>
        /// <item>A column not valid for <see cref="RangeFilter"/>: it is hashed into a new temporary column.
        /// DateTime/DateTimeOffset/TimeSpan columns are first converted to Int64 into that temporary column,
        /// so the caller's original column keeps its type.</item>
        /// <item>An unnormalized Single/Double column: it is min-max normalized into a new temporary column.</item>
        /// <item>Any other valid column: it is used as-is and its own name is returned.</item>
        /// </list>
        /// </remarks>
        /// <param name="host">Host used for validation and to build the transforms.</param>
        /// <param name="data">Input view; replaced with the augmented view whenever a column is added.</param>
        /// <param name="stratificationColumn">Optional user-provided column to stratify on.</param>
        /// <returns>The name of the column to stratify on.</returns>
        public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null)
        {
            host.CheckValue(data, nameof(data));
            host.CheckValueOrNull(stratificationColumn);

            // Pick a unique name for the stratification column.
            const string stratColName = "StratificationKey";
            string stratCol = data.Schema.GetTempColumnName(stratColName);

            if (stratificationColumn == null)
            {
                // No user column: generate a random number per row and stratify on that.
                data = new GenerateNumberTransform(host,
                    new GenerateNumberTransform.Options
                    {
                        Columns = new[] { new GenerateNumberTransform.Column { Name = stratCol } }
                    }, data);
                return stratCol;
            }

            var col = data.Schema.GetColumnOrNull(stratificationColumn);
            if (!col.HasValue)
            {
                throw host.ExceptSchemaMismatch(nameof(stratificationColumn), "Stratification", stratificationColumn);
            }

            var type = col.Value.Type;
            if (!RangeFilter.IsValidRangeFilterColumnType(host, type))
            {
                // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
                // Those are converted to Int64 into the temporary column (not in place), so the caller's column is untouched.
                string hashInputColumnName = stratificationColumn;
                var itemType = type.GetItemType();
                if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
                {
                    data = new TypeConvertingTransformer(host, stratCol, DataKind.Int64, stratificationColumn).Transform(data);
                    hashInputColumnName = stratCol;
                }

                var columnOptions = new HashingEstimator.ColumnOptions(stratCol, hashInputColumnName, 30, combine: true);
                data = new HashingEstimator(host, columnOptions).Fit(data).Transform(data);
                return stratCol;
            }

            // The column is already valid for RangeFilter: use it directly unless it is an unnormalized Single/Double.
            if (data.Schema[stratificationColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double))
            {
                return stratificationColumn;
            }

            data = new NormalizingEstimator(host,
                    new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: true))
                .Fit(data).Transform(data);
            return stratCol;
        }
Ejemplo n.º 5
0
        /// <summary>
        /// Entry point that scores <c>input.Data</c> with <c>input.PredictorModel</c>.
        /// </summary>
        /// <remarks>
        /// If preparing the data through the predictor model throws and the input has exactly one column,
        /// that column is converted to the item type of the model's training feature column and scored
        /// directly with the model's predictor. For inputs with any other column count, the original
        /// exception propagates instead of being swallowed.
        /// </remarks>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">The data to score and the predictor model to score it with.</param>
        /// <returns>The scored data and a transform model that reproduces it from <c>input.Data</c>.</returns>
        public static ScoringTransformOutput Score(IHostEnvironment env, ScoringTransformInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ScoreModel");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var inputData = input.Data;
            RoleMappedData data;
            IPredictor predictor;

            try
            {
                input.PredictorModel.PrepareData(host, inputData, out data, out predictor);
            }
            catch (Exception) when (inputData.Schema.Count == 1)
            {
                // This can happen in the csr_matrix case: try to use only the trainer model.
                // The fallback only makes sense for a single input column. The filter (rather than a debug-only
                // assert) lets the original exception surface for every other input.
                var inputColumnName = inputData.Schema[0].Name;
                var trainingSchema = input.PredictorModel.GetTrainingSchema(host);

                // Get the item type of the training feature vector and convert the input column to it.
                var trainingFeatureColumn = (DataViewSchema.Column)trainingSchema.Feature;
                var requiredType = trainingFeatureColumn.Type.GetItemType().RawType;
                var featuresColumnName = trainingFeatureColumn.Name;
                predictor = input.PredictorModel.Predictor;
                var xf = new TypeConvertingTransformer(host,
                    new TypeConvertingEstimator.ColumnOptions(featuresColumnName, requiredType, inputColumnName)).Transform(inputData);
                data = new RoleMappedData(xf, null, featuresColumnName);
            }

            IDataView scoredPipe;
            using (var ch = host.Start("Creating scoring pipeline"))
            {
                ch.Trace("Creating pipeline");
                var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor);
                ch.AssertValue(bindable);

                var mapper = bindable.Bind(host, data.Schema);
                var scorer = ScoreUtils.GetScorerComponent(host, mapper, input.Suffix);
                scoredPipe = scorer.CreateComponent(host, data.Data, mapper, input.PredictorModel.GetTrainingSchema(host));
            }

            return new ScoringTransformOutput
            {
                ScoredData = scoredPipe,
                ScoringTransform = new TransformModelImpl(host, scoredPipe, inputData)
            };
        }
Ejemplo n.º 6
0
        /// <summary>
        /// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing or
        /// normalizing it into a new temporary column if necessary, or creates a new random-number column if
        /// <paramref name="samplingKeyColumn"/> is null. On return, <paramref name="samplingKeyColumn"/> names the column to use.
        /// </summary>
        /// <param name="env">Host environment. When <paramref name="seed"/> is null, it is cast to
        /// <c>ISeededEnvironment</c> to read the fallback seed.</param>
        /// <param name="data">Input view; replaced with the augmented view whenever a column is added.</param>
        /// <param name="samplingKeyColumn">Column to preserve groups by, or null; updated to the column actually used.</param>
        /// <param name="seed">Optional seed for number generation / hashing; falls back to the environment's seed.</param>
        internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
        {
            Contracts.CheckValue(env, nameof(env));

            // We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
            // build a single hash of it. If it is not, we generate a random number.
            if (samplingKeyColumn == null)
            {
                samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
                data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
                return;
            }

            if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratColIndex))
            {
                throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);
            }

            var type = data.Schema[stratColIndex].Type;
            if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
            {
                var origStratCol = samplingKeyColumn;
                samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);

                // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
                // Those are converted to Int64 into the temporary column (not in place), so the caller's column is untouched.
                var hashInputColumnName = origStratCol;
                var itemType = type.GetItemType();
                if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
                {
                    data = new TypeConvertingTransformer(env, samplingKeyColumn, DataKind.Int64, origStratCol).Transform(data);
                    hashInputColumnName = samplingKeyColumn;
                }

                // Explicit seed wins; otherwise use the environment's seed (which may itself be null).
                int? localSeed = seed ?? ((ISeededEnvironment)env).Seed;
                var columnOptions =
                    localSeed.HasValue ?
                    new HashingEstimator.ColumnOptions(samplingKeyColumn, hashInputColumnName, 30, (uint)localSeed.Value, combine: true) :
                    new HashingEstimator.ColumnOptions(samplingKeyColumn, hashInputColumnName, 30, combine: true);
                data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
            }
            else if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double))
            {
                // Valid but unnormalized floating-point column: min-max normalize it into a temporary column.
                var origStratCol = samplingKeyColumn;
                samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
                data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data);
            }
        }
Ejemplo n.º 7
0
        /// <summary>
        /// Adds a temporary column to <paramref name="data"/> that can be used to split rows, and returns its name.
        /// </summary>
        /// <remarks>
        /// Without <paramref name="samplingKeyColumn"/>, the new column holds generated random numbers. Otherwise it is
        /// derived from that column: hashed when the type is not valid for <see cref="RangeFilter"/>,
        /// min-max normalized for Single/Double, or copied for other valid types.
        /// The original sampling key column is never modified.
        /// </remarks>
        /// <param name="env">Host environment.</param>
        /// <param name="data">Input view; replaced with the augmented view.</param>
        /// <param name="samplingKeyColumn">Optional column whose values decide the split.</param>
        /// <param name="seed">Optional seed for generation/hashing.</param>
        /// <param name="fallbackInEnvSeed">When <paramref name="seed"/> is null, use the environment's seed instead.</param>
        /// <returns>The name of the newly added split column.</returns>
        internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false)
        {
            Contracts.CheckValue(env, nameof(env));
            Contracts.CheckValueOrNull(samplingKeyColumn);

            var splitColumnName = data.Schema.GetTempColumnName("SplitColumn");

            // An explicit seed wins; otherwise optionally fall back to the environment's seed.
            int? seedToUse = seed ?? (fallbackInEnvSeed ? ((ISeededEnvironment)env).Seed : null);

            // No sampling key column: the split column is just a random number per row.
            if (samplingKeyColumn == null)
            {
                data = new GenerateNumberTransform(env, data, splitColumnName, (uint?)seedToUse);
                return splitColumnName;
            }

            // Derive the split column from the sampling key column under the temporary name, since it
            // might be dropped elsewhere in the code.
            if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int keyColumnIndex))
            {
                throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);
            }

            var keyType = data.Schema[keyColumnIndex].Type;

            if (RangeFilter.IsValidRangeFilterColumnType(env, keyType))
            {
                bool isFloating = keyType == NumberDataViewType.Single || keyType == NumberDataViewType.Double;
                if (isFloating)
                {
                    var minMax = new NormalizingEstimator.MinMaxColumnOptions(splitColumnName, samplingKeyColumn, ensureZeroUntouched: false);
                    data = new NormalizingEstimator(env, minMax).Fit(data).Transform(data);
                }
                else
                {
                    data = new ColumnCopyingEstimator(env, (splitColumnName, samplingKeyColumn)).Fit(data).Transform(data);
                }
                return splitColumnName;
            }

            // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan,
            // so those are first converted to Int64 into the split column and hashed from there.
            var hashSource = samplingKeyColumn;
            var itemType = keyType.GetItemType();
            if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
            {
                data = new TypeConvertingTransformer(env, splitColumnName, DataKind.Int64, samplingKeyColumn).Transform(data);
                hashSource = splitColumnName;
            }

            var hashOptions = seedToUse.HasValue
                ? new HashingEstimator.ColumnOptions(splitColumnName, hashSource, 30, (uint)seedToUse.Value, combine: true)
                : new HashingEstimator.ColumnOptions(splitColumnName, hashSource, 30, combine: true);
            data = new HashingEstimator(env, hashOptions).Fit(data).Transform(data);
            return splitColumnName;
        }
        /// <summary>
        /// Determines the column used to split <paramref name="input"/> and, when needed, produces a view that contains it.
        /// </summary>
        /// <remarks>
        /// The column comes from <c>ImplOptions.StratificationColumn</c> when set. Otherwise it is the group column,
        /// if that column is a key type with known cardinality. Failing both, a fresh random-number column is generated.
        /// A chosen column whose type is not valid for <see cref="RangeFilter"/> is hashed into a temporary column.
        /// DateTime/DateTimeOffset/TimeSpan columns are first converted to Int64 into that temporary column,
        /// so the user's column keeps its type in <paramref name="output"/>.
        /// </remarks>
        /// <param name="ch">Channel used for messages and user-argument errors.</param>
        /// <param name="input">The input view; the stratification/group column, if any, must already be present.</param>
        /// <param name="output">Receives <paramref name="input"/>, or a view augmented with the split column.</param>
        /// <returns>The name of the column to split on.</returns>
        private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
        {
            // The stratification column and/or group column, if they exist at all, must be present at this point.
            var schema = input.Schema;
            output = input;

            // If no stratification column was specified, but there is a group column of key type with
            // known cardinality, use it.
            string stratificationColumn = null;
            if (!string.IsNullOrWhiteSpace(ImplOptions.StratificationColumn))
            {
                stratificationColumn = ImplOptions.StratificationColumn;
            }
            else
            {
                string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
                if (group != null && schema.TryGetColumnIndex(group, out int groupIndex) && schema[groupIndex].Type.GetKeyCount() > 0)
                {
                    stratificationColumn = group;
                }
            }

            if (string.IsNullOrEmpty(stratificationColumn))
            {
                // Pick a name not already in the schema and generate a random number column under it.
                stratificationColumn = "StratificationColumn";
                int inc = 0;
                while (schema.TryGetColumnIndex(stratificationColumn, out _))
                {
                    stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc);
                }

                var keyGenArgs = new GenerateNumberTransform.Options
                {
                    Columns = new[] { new GenerateNumberTransform.Column { Name = stratificationColumn } }
                };
                output = new GenerateNumberTransform(Host, keyGenArgs, input);
                return stratificationColumn;
            }

            if (!schema.TryGetColumnIndex(stratificationColumn, out int col))
            {
                throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
            }

            var type = schema[col].Type;
            if (!RangeFilter.IsValidRangeFilterColumnType(ch, type))
            {
                ch.Info("Hashing the stratification column");
                var origStratCol = stratificationColumn;
                stratificationColumn = schema.GetTempColumnName("strat");

                // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
                // Convert those into the temporary column (not in place) and hash from there.
                var hashInputColumnName = origStratCol;
                var itemType = type.GetItemType();
                if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
                {
                    input = new TypeConvertingTransformer(Host, stratificationColumn, DataKind.Int64, origStratCol).Transform(input);
                    hashInputColumnName = stratificationColumn;
                }

                output = new HashingEstimator(Host, stratificationColumn, hashInputColumnName, 30).Fit(input).Transform(input);
            }

            return stratificationColumn;
        }