/// <summary>
/// Entry point that makes sure the regression label column has type R4.
/// If the label is already R4, or is not a numeric type at all, the data is forwarded unchanged
/// through a no-op transform; otherwise the label is converted to R4 under the same column name.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="input">The input data together with the name of the label column.</param>
/// <returns>The output data and the transform model that produced it from <c>input.Data</c>.</returns>
public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironment env, RegressionLabelInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("PrepareRegressionLabel");
    host.CheckValue(input, nameof(input));
    EntryPointUtils.CheckInputArgs(host, input);

    var labelCol = input.Data.Schema.GetColumnOrNull(input.LabelColumn);
    if (!labelCol.HasValue)
    {
        throw host.Except($"Column '{input.LabelColumn}' not found.");
    }

    var labelType = labelCol.Value.Type;

    // Nothing to convert: the label is already R4, or it is not numeric.
    if (labelType == NumberType.R4 || !(labelType is NumberType))
    {
        var nop = NopTransform.CreateIfNeeded(env, input.Data);
        return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, nop, input.Data), OutputData = nop };
    }

    // Convert the numeric label to R4; source and destination share the label column name.
    var xf = new TypeConvertingTransformer(host, new TypeConvertingTransformer.ColumnInfo(input.LabelColumn, DataKind.R4, input.LabelColumn)).Transform(input.Data);
    return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf };
}
/// <summary>
/// Entry point ensuring the regression label is a <see cref="NumberDataViewType.Single"/> column.
/// Numeric labels of any other type are converted to Single under the same name; labels that are
/// already Single, or that are not numeric, pass through a no-op transform.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="input">Input data and the name of its label column.</param>
/// <returns>The resulting data view and the transform model that maps <c>input.Data</c> to it.</returns>
public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironment env, RegressionLabelInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("PrepareRegressionLabel");
    host.CheckValue(input, nameof(input));
    EntryPointUtils.CheckInputArgs(host, input);

    IDataView source = input.Data;
    string labelName = input.LabelColumn;

    var column = source.Schema.GetColumnOrNull(labelName);
    if (column == null)
        throw host.Except($"Column '{input.LabelColumn}' not found.");

    var type = column.Value.Type;

    // Only numeric, non-Single labels need a conversion.
    if (type is NumberDataViewType && type != NumberDataViewType.Single)
    {
        var options = new TypeConvertingEstimator.ColumnOptions(labelName, DataKind.Single, labelName);
        var converted = new TypeConvertingTransformer(host, options).Transform(source);
        return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, converted, source), OutputData = converted };
    }

    var passThrough = NopTransform.CreateIfNeeded(env, source);
    return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, passThrough, source), OutputData = passThrough };
}
/// <summary>
/// Applies the requested type conversions to <paramref name="viewTrain"/>.
/// </summary>
/// <param name="cvt">Conversions to apply; may be null or empty, in which case nothing is done.</param>
/// <param name="viewTrain">The data view to convert.</param>
/// <param name="env">The host environment used to build the transformer.</param>
/// <returns>The converted view, or <paramref name="viewTrain"/> itself when there is nothing to convert.</returns>
private static IDataView ApplyConvert(List<TypeConvertingTransformer.ColumnInfo> cvt, IDataView viewTrain, IHostEnvironment env)
{
    Contracts.AssertValueOrNull(cvt);
    Contracts.AssertValue(viewTrain);
    Contracts.AssertValue(env);

    // A missing or empty list means no conversion was requested.
    if (cvt == null || cvt.Count == 0)
        return viewTrain;

    var converter = new TypeConvertingTransformer(env, cvt.ToArray());
    return converter.Transform(viewTrain);
}
/// <summary>
/// Prepares a column that can be used to stratify <paramref name="data"/> and returns its name.
/// </summary>
/// <remarks>
/// The column is chosen as follows:
/// <list type="bullet">
/// <item><description>When <paramref name="stratificationColumn"/> is null, a generated-number column is added under a temporary name.</description></item>
/// <item><description>When the user column's type is not valid for <see cref="RangeFilter"/>, it is hashed (30 bits, combined) into a new temporary column.
/// DateTime, DateTimeOffset and TimeSpan items are first converted to Int64 into that temporary column, so the user's column keeps its original type.</description></item>
/// <item><description>When it is a Single or Double column that is not normalized, it is min-max normalized into a new temporary column.</description></item>
/// <item><description>Otherwise the user's column name is returned unchanged.</description></item>
/// </list>
/// </remarks>
/// <param name="host">The host used for validation and for building transforms.</param>
/// <param name="data">The data view; replaced with a view containing the new column when one is created.</param>
/// <param name="stratificationColumn">Optional user-provided column to stratify on.</param>
/// <returns>The name of the column to stratify on.</returns>
public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null)
{
    host.CheckValue(data, nameof(data));
    host.CheckValueOrNull(stratificationColumn);

    // Pick a unique name for the stratification column.
    const string stratColName = "StratificationKey";
    string stratCol = data.Schema.GetTempColumnName(stratColName);

    // No user column: generate a random number and use it.
    if (stratificationColumn == null)
    {
        data = new GenerateNumberTransform(host,
            new GenerateNumberTransform.Options { Columns = new[] { new GenerateNumberTransform.Column { Name = stratCol } } },
            data);
        return stratCol;
    }

    var col = data.Schema.GetColumnOrNull(stratificationColumn);
    if (col == null)
    {
        throw host.ExceptSchemaMismatch(nameof(stratificationColumn), "Stratification", stratificationColumn);
    }

    var type = col.Value.Type;
    if (!RangeFilter.IsValidRangeFilterColumnType(host, type))
    {
        string hashInputColumn = stratificationColumn;

        // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
        var itemType = type.GetItemType();
        if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
        {
            // Convert into the temporary column rather than in place, so the user's column keeps its type.
            data = new TypeConvertingTransformer(host, stratCol, DataKind.Int64, stratificationColumn).Transform(data);
            hashInputColumn = stratCol;
        }

        var columnOptions = new HashingEstimator.ColumnOptions(stratCol, hashInputColumn, 30, combine: true);
        data = new HashingEstimator(host, columnOptions).Fit(data).Transform(data);
        return stratCol;
    }

    // Already usable as-is: normalized, or a RangeFilter-valid type other than Single/Double.
    if (data.Schema[stratificationColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double))
    {
        return stratificationColumn;
    }

    data = new NormalizingEstimator(host, new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: true))
        .Fit(data).Transform(data);
    return stratCol;
}
/// <summary>
/// Entry point that scores <c>input.Data</c> with <c>input.PredictorModel</c>.
/// </summary>
/// <remarks>
/// If the predictor model cannot prepare the input data, and the input has exactly one column, that column
/// is converted to the training feature column's item type and used as the feature column.
/// The original comment attributes this to the csr_matrix case.
/// Any other failure from the predictor model is rethrown unchanged.
/// </remarks>
/// <param name="env">The host environment.</param>
/// <param name="input">The data, the predictor model and an optional output column suffix.</param>
/// <returns>The scored data and the transform model that maps <c>input.Data</c> to it.</returns>
public static ScoringTransformOutput Score(IHostEnvironment env, ScoringTransformInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("ScoreModel");
    host.CheckValue(input, nameof(input));
    EntryPointUtils.CheckInputArgs(host, input);

    RoleMappedData data;
    IPredictor predictor;
    var inputData = input.Data;
    try
    {
        input.PredictorModel.PrepareData(host, inputData, out data, out predictor);
    }
    // The fallback only makes sense for single-column input. The old debug-only Assert let release builds
    // silently take column 0 and hide the real error; the filter lets other failures propagate untouched.
    catch (Exception) when (inputData.Schema.Count == 1)
    {
        // This can happen in csr_matrix case, try to use only trainer model.
        var inputColumnName = inputData.Schema[0].Name;
        var trainingSchema = input.PredictorModel.GetTrainingSchema(host);

        // Without a training feature column there is nothing to fall back to; surface the original failure.
        if (!trainingSchema.Feature.HasValue)
            throw;

        // Get feature vector item type.
        var trainingFeatureColumn = trainingSchema.Feature.Value;
        var requiredType = trainingFeatureColumn.Type.GetItemType().RawType;
        var featuresColumnName = trainingFeatureColumn.Name;
        predictor = input.PredictorModel.Predictor;
        var xf = new TypeConvertingTransformer(host,
            new TypeConvertingEstimator.ColumnOptions(featuresColumnName, requiredType, inputColumnName)).Transform(inputData);
        data = new RoleMappedData(xf, null, featuresColumnName);
    }

    IDataView scoredPipe;
    using (var ch = host.Start("Creating scoring pipeline"))
    {
        ch.Trace("Creating pipeline");
        var bindable = ScoreUtils.GetSchemaBindableMapper(host, predictor);
        ch.AssertValue(bindable);

        var mapper = bindable.Bind(host, data.Schema);
        var scorer = ScoreUtils.GetScorerComponent(host, mapper, input.Suffix);
        scoredPipe = scorer.CreateComponent(host, data.Data, mapper, input.PredictorModel.GetTrainingSchema(host));
    }

    return new ScoringTransformOutput
    {
        ScoredData = scoredPipe,
        ScoringTransform = new TransformModelImpl(host, scoredPipe, inputData)
    };
}
/// <summary>
/// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary,
/// or creates a new column if <paramref name="samplingKeyColumn"/> is null.
/// </summary>
/// <remarks>
/// On return <paramref name="samplingKeyColumn"/> names the column to use. It refers to a new temporary column
/// whenever one had to be generated, hashed or normalized. DateTime, DateTimeOffset and TimeSpan columns are
/// converted to Int64 into that temporary column rather than in place, so the user's column keeps its original type.
/// </remarks>
/// <param name="env">The host environment; its seed is used when <paramref name="seed"/> is null.</param>
/// <param name="data">The data view; replaced when a new column is added.</param>
/// <param name="samplingKeyColumn">The user-provided sampling key column, or null to generate one.</param>
/// <param name="seed">Optional seed for the generated or hashed column.</param>
internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
{
    Contracts.CheckValue(env, nameof(env));

    // No samplingKeyColumn: generate a random number column.
    if (samplingKeyColumn == null)
    {
        samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
        data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
        return;
    }

    if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
    {
        throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);
    }

    var type = data.Schema[stratCol].Type;
    if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
    {
        // Build a single hash of the provided column under a temporary name.
        var origStratCol = samplingKeyColumn;
        samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
        string hashInputColumn = origStratCol;

        // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
        var itemType = type.GetItemType();
        if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
        {
            // Convert into the temporary column, not in place, to leave the user's column untouched.
            data = new TypeConvertingTransformer(env, samplingKeyColumn, DataKind.Int64, origStratCol).Transform(data);
            hashInputColumn = samplingKeyColumn;
        }

        var localSeed = seed ?? ((ISeededEnvironment)env).Seed;
        var columnOptions = localSeed.HasValue
            ? new HashingEstimator.ColumnOptions(samplingKeyColumn, hashInputColumn, 30, (uint)localSeed.Value, combine: true)
            : new HashingEstimator.ColumnOptions(samplingKeyColumn, hashInputColumn, 30, combine: true);
        data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
    }
    else if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double))
    {
        // Un-normalized floating-point keys are min-max normalized into a temporary column.
        var origStratCol = samplingKeyColumn;
        samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
        data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true))
            .Fit(data).Transform(data);
    }
}
/// <summary>
/// Adds a temporary column to <paramref name="data"/> that can be used to split it, and returns its name.
/// </summary>
/// <remarks>
/// The new column is derived as follows:
/// <list type="bullet">
/// <item><description>With no <paramref name="samplingKeyColumn"/>, it holds generated numbers.</description></item>
/// <item><description>Single/Double sampling keys are min-max normalized into it.</description></item>
/// <item><description>Other RangeFilter-valid keys are copied into it.</description></item>
/// <item><description>Anything else is hashed (30 bits, combined) into it. DateTime, DateTimeOffset and TimeSpan keys are first converted to Int64 in that same column.</description></item>
/// </list>
/// The sampling key column itself is never modified.
/// </remarks>
/// <param name="env">The host environment.</param>
/// <param name="data">The data view; replaced by a view containing the split column.</param>
/// <param name="samplingKeyColumn">Optional column whose values determine the split.</param>
/// <param name="seed">Optional seed for number generation and hashing.</param>
/// <param name="fallbackInEnvSeed">When <paramref name="seed"/> is null, whether to use the environment's seed instead.</param>
/// <returns>The name of the newly added split column.</returns>
internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false)
{
    Contracts.CheckValue(env, nameof(env));
    Contracts.CheckValueOrNull(samplingKeyColumn);

    var splitColumnName = data.Schema.GetTempColumnName("SplitColumn");
    int? seedToUse = seed ?? (fallbackInEnvSeed ? ((ISeededEnvironment)env).Seed : null);

    // Without a sampling key, the split column is just random numbers.
    if (samplingKeyColumn == null)
    {
        data = new GenerateNumberTransform(env, data, splitColumnName, (uint?)seedToUse);
        return splitColumnName;
    }

    // The split column is always a new temporary column derived from the key, since it might be dropped
    // elsewhere in the code.
    if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int keyIndex))
        throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);

    var keyType = data.Schema[keyIndex].Type;

    if (RangeFilter.IsValidRangeFilterColumnType(env, keyType))
    {
        bool isFloatingPoint = keyType == NumberDataViewType.Single || keyType == NumberDataViewType.Double;
        if (isFloatingPoint)
        {
            var minMax = new NormalizingEstimator.MinMaxColumnOptions(splitColumnName, samplingKeyColumn, ensureZeroUntouched: false);
            data = new NormalizingEstimator(env, minMax).Fit(data).Transform(data);
        }
        else
        {
            data = new ColumnCopyingEstimator(env, (splitColumnName, samplingKeyColumn)).Fit(data).Transform(data);
        }
        return splitColumnName;
    }

    string hashSource = samplingKeyColumn;

    // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
    var elementType = keyType.GetItemType();
    bool needsInt64 = elementType is DateTimeDataViewType
        || elementType is DateTimeOffsetDataViewType
        || elementType is TimeSpanDataViewType;
    if (needsInt64)
    {
        data = new TypeConvertingTransformer(env, splitColumnName, DataKind.Int64, samplingKeyColumn).Transform(data);
        hashSource = splitColumnName;
    }

    HashingEstimator.ColumnOptions hashOptions;
    if (seedToUse.HasValue)
        hashOptions = new HashingEstimator.ColumnOptions(splitColumnName, hashSource, 30, (uint)seedToUse.Value, combine: true);
    else
        hashOptions = new HashingEstimator.ColumnOptions(splitColumnName, hashSource, 30, combine: true);

    data = new HashingEstimator(env, hashOptions).Fit(data).Transform(data);
    return splitColumnName;
}
/// <summary>
/// Determines the column used to stratify the data and returns its name, adding a column to
/// <paramref name="output"/> when one has to be generated or hashed.
/// </summary>
/// <remarks>
/// The column is chosen as follows:
/// <list type="bullet">
/// <item><description>The explicit stratification column option is used if set.</description></item>
/// <item><description>Otherwise the group column is used when it is a key type with a known count.</description></item>
/// <item><description>Otherwise a generated-number column is added under a fresh "StratificationColumn[_NNN]" name.</description></item>
/// </list>
/// A chosen column whose type is not valid for <see cref="RangeFilter"/> is hashed into a temporary column.
/// DateTime, DateTimeOffset and TimeSpan columns are converted to Int64 into that temporary column first,
/// so the user's column keeps its original type.
/// </remarks>
/// <param name="ch">Channel used for messages and errors.</param>
/// <param name="input">The input data.</param>
/// <param name="output">Set to <paramref name="input"/>, or to a view with the added stratification column.</param>
/// <returns>The name of the stratification column.</returns>
private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
{
    // The stratification column and/or group column, if they exist at all, must be present at this point.
    var schema = input.Schema;
    output = input;

    // If no stratification column was specified, but we have a group column of key type with a known
    // count, use it.
    string stratificationColumn = null;
    if (!string.IsNullOrWhiteSpace(ImplOptions.StratificationColumn))
    {
        stratificationColumn = ImplOptions.StratificationColumn;
    }
    else
    {
        string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(ImplOptions.GroupColumn), ImplOptions.GroupColumn, DefaultColumnNames.GroupId);
        if (group != null && schema.TryGetColumnIndex(group, out int index) && schema[index].Type.GetKeyCount() > 0)
        {
            stratificationColumn = group;
        }
    }

    if (string.IsNullOrEmpty(stratificationColumn))
    {
        // Generate a random column under a name not already present in the schema.
        stratificationColumn = "StratificationColumn";
        int inc = 0;
        while (input.Schema.TryGetColumnIndex(stratificationColumn, out _))
        {
            stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc);
        }

        var keyGenArgs = new GenerateNumberTransform.Options();
        var col = new GenerateNumberTransform.Column();
        col.Name = stratificationColumn;
        keyGenArgs.Columns = new[] { col };
        output = new GenerateNumberTransform(Host, keyGenArgs, input);
        return stratificationColumn;
    }

    if (!input.Schema.TryGetColumnIndex(stratificationColumn, out int colIndex))
    {
        throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
    }

    var type = input.Schema[colIndex].Type;
    if (!RangeFilter.IsValidRangeFilterColumnType(ch, type))
    {
        ch.Info("Hashing the stratification column");
        var origStratCol = stratificationColumn;
        stratificationColumn = input.Schema.GetTempColumnName("strat");
        string hashInputColumn = origStratCol;

        // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
        var itemType = type.GetItemType();
        if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
        {
            // Convert into the temporary column, not in place, so the user's column keeps its type in the output.
            input = new TypeConvertingTransformer(Host, stratificationColumn, DataKind.Int64, origStratCol).Transform(input);
            hashInputColumn = stratificationColumn;
        }

        // combine: true hashes a vector-valued column to a single key, as the other stratification helpers do;
        // a vector of keys would presumably not be usable by RangeFilter.
        var columnOptions = new HashingEstimator.ColumnOptions(stratificationColumn, hashInputColumn, 30, combine: true);
        output = new HashingEstimator(Host, columnOptions).Fit(input).Transform(input);
    }

    return stratificationColumn;
}