/// <summary>
/// Ensures there is a column usable by <see cref="RangeFilter"/> for stratified splitting and
/// returns its name. If <paramref name="stratificationColumn"/> is null a random-number column
/// is generated; otherwise the user column is hashed or normalized into a temp column as needed.
/// </summary>
/// <returns>The name of the column to use for range filtering (may be the user's own column).</returns>
public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null)
{
    host.CheckValue(data, nameof(data));
    host.CheckValueOrNull(stratificationColumn);

    // Pick a unique name for the stratificationColumn.
    const string stratColName = "StratificationKey";
    string stratCol = data.Schema.GetTempColumnName(stratColName);

    // Construct the stratification column. If user-provided stratification column exists, use HashJoin
    // of it to construct the strat column, otherwise generate a random number and use it.
    if (stratificationColumn == null)
    {
        // No user column: append a generated-number column under the temp name.
        data = new GenerateNumberTransform(host,
            new GenerateNumberTransform.Options
            {
                Columns = new[] { new GenerateNumberTransform.Column { Name = stratCol } }
            }, data);
    }
    else
    {
        var col = data.Schema.GetColumnOrNull(stratificationColumn);
        if (col == null)
        {
            throw host.ExceptSchemaMismatch(nameof(stratificationColumn), "Stratification", stratificationColumn);
        }
        var type = col.Value.Type;
        if (!RangeFilter.IsValidRangeFilterColumnType(host, type))
        {
            // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
            var itemType = type.GetItemType();
            if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
            {
                // Convert the unsupported temporal column to Int64 in place before hashing.
                data = new TypeConvertingTransformer(host, stratificationColumn, DataKind.Int64, stratificationColumn).Transform(data);
            }
            // Hash the user column (30 bits, combine) into the temp column.
            var columnOptions = new HashingEstimator.ColumnOptions(stratCol, stratificationColumn, 30, combine: true);
            data = new HashingEstimator(host, columnOptions).Fit(data).Transform(data);
        }
        else
        {
            // The user column is already consumable by RangeFilter. If it is normalized,
            // or is not a floating-point column, use it directly — no temp column needed.
            if (data.Schema[stratificationColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double))
            {
                return (stratificationColumn);
            }
            // Un-normalized Single/Double: min-max normalize into the temp column.
            data = new NormalizingEstimator(host,
                new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: true))
                .Fit(data).Transform(data);
        }
    }
    return (stratCol);
}
/// <summary>
/// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary, or creates a new column if <paramref name="samplingKeyColumn"/> is null.
/// </summary>
internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
{
    Contracts.CheckValue(env, nameof(env));
    // We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
    // build a single hash of it. If it is not, we generate a random number.
    if (samplingKeyColumn == null)
    {
        // No user column: append a generated-number column. Seed priority is the explicit
        // argument, then the environment's seed (which may itself be null -> unseeded).
        samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
        data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
    }
    else
    {
        if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
        {
            throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);
        }
        var type = data.Schema[stratCol].Type;
        if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
        {
            // Hash the samplingKeyColumn.
            // REVIEW: this could currently crash, since Hash only accepts a limited set
            // of column types. It used to be HashJoin, but we should probably extend Hash
            // instead of having two hash transformations.
            var origStratCol = samplingKeyColumn;
            samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
            // Seed priority: explicit argument, then environment seed, then unseeded hash.
            HashingEstimator.ColumnOptionsInternal columnOptions;
            if (seed.HasValue)
            {
                columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)seed.Value);
            }
            else if (((ISeededEnvironment)env).Seed.HasValue)
            {
                columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)((ISeededEnvironment)env).Seed.Value);
            }
            else
            {
                columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30);
            }
            data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
        }
        else
        {
            // Already valid for RangeFilter; only un-normalized Single/Double columns are
            // min-max normalized into a fresh temp column, which becomes the sampling key.
            if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double))
            {
                var origStratCol = samplingKeyColumn;
                samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
                data = new NormalizingEstimator(env,
                    new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true))
                    .Fit(data).Transform(data);
            }
        }
    }
}
/// <summary>
/// Wraps the evaluator's per-instance view so it exposes exactly the columns Maml emits:
/// the fold index (when present), an "Instance" name column, the weight column (when present),
/// and the evaluator-specific per-instance columns.
/// </summary>
private IDataView WrapPerInstance(RoleMappedData perInst)
{
    var view = perInst.Data;
    var copyPairs = new List<(string name, string source)>();
    var keepList = new List<string>();

    // Cross-validation output carries a fold index column; preserve it when it exists.
    if (perInst.Schema.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out int foldCol))
    {
        keepList.Add(MetricKinds.ColumnNames.FoldIndex);
    }

    // Maml always outputs a name column: copy an existing name column to "Instance",
    // or synthesize a counter-based one when the schema has no name column.
    if (perInst.Schema.Name?.Name is string nameName)
    {
        copyPairs.Add(("Instance", nameName));
        keepList.Add("Instance");
    }
    else
    {
        var genArgs = new GenerateNumberTransform.Arguments();
        genArgs.Columns = new[] { new GenerateNumberTransform.Column() { Name = "Instance" } };
        genArgs.UseCounter = true;
        view = new GenerateNumberTransform(Host, genArgs, view);
        keepList.Add("Instance");
    }

    // The weight column is emitted whenever it exists.
    if (perInst.Schema.Weight?.Name is string weightName)
    {
        keepList.Add(weightName);
    }

    // Append the evaluator-specific per-instance columns.
    keepList.AddRange(GetPerInstanceColumnsToSave(perInst.Schema));

    view = new ColumnCopyingTransformer(Host, copyPairs.ToArray()).Transform(view);
    view = ColumnSelectingTransformer.CreateKeep(Host, view, keepList.ToArray());
    return GetPerInstanceMetricsCore(view, perInst.Schema);
}
/// <summary>
/// Yields a single training/validation batch. When <c>ValidationDatasetProportion</c> is
/// positive, the data is randomly split via a generated column and two complementary
/// <see cref="RangeFilter"/>s; otherwise train and test are the same data.
/// </summary>
public IEnumerable<Batch> GetBatches(IRandom rand)
{
    Host.Assert(Data != null, "Must call Initialize first!");
    Host.AssertValue(rand);
    using (var ch = Host.Start("Getting batches"))
    {
        RoleMappedData dataTest;
        RoleMappedData dataTrain;
        // Split the data, if needed.
        if (!(ValidationDatasetProportion > 0))
        {
            // No validation split requested: train and test views are identical.
            dataTest = dataTrain = Data;
        }
        else
        {
            // Split the data into train and test sets.
            string name = Data.Data.Schema.GetTempColumnName();
            var args = new GenerateNumberTransform.Arguments();
            args.Column = new[] { new GenerateNumberTransform.Column() { Name = name } };
            args.Seed = (uint)rand.Next();
            var view = new GenerateNumberTransform(Host, args, Data.Data);
            // Test gets rows whose generated value falls below the proportion;
            // train gets the complement of that range.
            var viewTest = new RangeFilter(Host, new RangeFilter.Arguments() { Column = name, Max = ValidationDatasetProportion }, view);
            var viewTrain = new RangeFilter(Host, new RangeFilter.Arguments() { Column = name, Max = ValidationDatasetProportion, Complement = true }, view);
            dataTest = new RoleMappedData(viewTest, Data.Schema.GetColumnRoleNames());
            dataTrain = new RoleMappedData(viewTrain, Data.Schema.GetColumnRoleNames());
        }
        if (BatchSize > 0)
        {
            // REVIEW: How should we carve the data into batches?
            ch.Warning("Batch support is temporarily disabled");
        }
        yield return (new Batch(dataTrain, dataTest));
        ch.Done();
    }
}
/// <summary>
/// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>,
/// hashing or normalizing it into a fresh temp column if necessary, or creates a new random-number
/// column if <paramref name="samplingKeyColumn"/> is null.
/// </summary>
internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
{
    Contracts.CheckValue(env, nameof(env));

    // Case 1: no user column — synthesize a random-number column. Seed priority is the
    // explicit argument, then the environment's seed (possibly null, i.e. unseeded).
    if (samplingKeyColumn == null)
    {
        samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
        data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
        return;
    }

    // Case 2: a user column was supplied — it must exist.
    if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int keyColIndex))
        throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);

    var keyColType = data.Schema[keyColIndex].Type;
    if (!RangeFilter.IsValidRangeFilterColumnType(env, keyColType))
    {
        var sourceColumn = samplingKeyColumn;
        samplingKeyColumn = data.Schema.GetTempColumnName(sourceColumn);

        // HashingEstimator currently handles all primitive types except for DateTime,
        // DateTimeOffset and TimeSpan — convert those to Int64 in place before hashing.
        var itemType = keyColType.GetItemType();
        if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
            data = new TypeConvertingTransformer(env, sourceColumn, DataKind.Int64, sourceColumn).Transform(data);

        // Explicit seed wins; otherwise fall back to the environment seed; otherwise unseeded.
        var effectiveSeed = seed ?? ((ISeededEnvironment)env).Seed;
        var hashOptions = effectiveSeed.HasValue
            ? new HashingEstimator.ColumnOptions(samplingKeyColumn, sourceColumn, 30, (uint)effectiveSeed.Value, combine: true)
            : new HashingEstimator.ColumnOptions(samplingKeyColumn, sourceColumn, 30, combine: true);
        data = new HashingEstimator(env, hashOptions).Fit(data).Transform(data);
    }
    else if (!data.Schema[samplingKeyColumn].IsNormalized() && (keyColType == NumberDataViewType.Single || keyColType == NumberDataViewType.Double))
    {
        // Valid for RangeFilter but an un-normalized floating-point column:
        // min-max normalize it into a fresh temp column.
        var sourceColumn = samplingKeyColumn;
        samplingKeyColumn = data.Schema.GetTempColumnName(sourceColumn);
        data = new NormalizingEstimator(env,
            new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, sourceColumn, ensureZeroUntouched: true))
            .Fit(data).Transform(data);
    }
}
/// <summary>
/// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>,
/// hashing it if necessary, or creates a new column if <paramref name="samplingKeyColumn"/> is null.
/// </summary>
private void EnsureGroupPreservationColumn(ref IDataView data, ref string samplingKeyColumn, uint? seed = null)
{
    // Case 1: no user column — synthesize a random-number column under a temp name.
    if (samplingKeyColumn == null)
    {
        samplingKeyColumn = data.Schema.GetTempColumnName("IdPreservationColumn");
        data = new GenerateNumberTransform(Environment, data, samplingKeyColumn, seed);
        return;
    }

    // Case 2: a user column was supplied — it must exist.
    if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int keyIndex))
        throw Environment.ExceptSchemaMismatch(nameof(samplingKeyColumn), "GroupPreservationColumn", samplingKeyColumn);

    var keyType = data.Schema[keyIndex].Type;
    if (RangeFilter.IsValidRangeFilterColumnType(Environment, keyType))
        return;

    // Hash the samplingKeyColumn.
    // REVIEW: this could currently crash, since Hash only accepts a limited set
    // of column types. It used to be HashJoin, but we should probably extend Hash
    // instead of having two hash transformations.
    var sourceColumn = samplingKeyColumn;

    // Probe for a column name not already present in the schema.
    int suffix = 0;
    while (data.Schema.TryGetColumnIndex(samplingKeyColumn, out _))
        samplingKeyColumn = string.Format("{0}_{1:000}", sourceColumn, ++suffix);

    var columnInfo = seed.HasValue
        ? new HashingEstimator.ColumnInfo(samplingKeyColumn, sourceColumn, 30, seed.Value)
        : new HashingEstimator.ColumnInfo(samplingKeyColumn, sourceColumn, 30);
    data = new HashingEstimator(Environment, columnInfo).Fit(data).Transform(data);
}
/// <summary>
/// Trains and evaluates one binary-classification model per fold. Rows are partitioned
/// by a generated stratification column run through complementary <see cref="RangeFilter"/>s.
/// Only the null-StratificationColumn path is implemented.
/// </summary>
public BinaryCrossValidationMetrics CrossValidate(IDataView trainData, IEstimator<ITransformer> estimator)
{
    var foldModels = new ITransformer[NumFolds];
    var foldMetrics = new BinaryClassificationMetrics[NumFolds];

    // A user-supplied stratification column is not supported yet.
    if (StratificationColumn != null)
        throw new NotImplementedException();

    StratificationColumn = "StratificationColumn";
    trainData = new GenerateNumberTransform(_env, trainData, StratificationColumn);

    var evaluator = new MyBinaryClassifierEvaluator(_env, new BinaryClassifierEvaluator.Arguments() { });
    for (int fold = 0; fold < NumFolds; fold++)
    {
        Double lo = (Double)fold / NumFolds;
        Double hi = (Double)(fold + 1) / NumFolds;
        // Train on the complement of this fold's range; test on the range itself.
        var trainFilter = new RangeFilter(_env, new RangeFilter.Arguments() { Column = StratificationColumn, Min = lo, Max = hi, Complement = true }, trainData);
        var testFilter = new RangeFilter(_env, new RangeFilter.Arguments() { Column = StratificationColumn, Min = lo, Max = hi, Complement = false }, trainData);

        foldModels[fold] = estimator.Fit(trainFilter);
        var scoredTest = foldModels[fold].Transform(testFilter);
        foldMetrics[fold] = evaluator.Evaluate(scoredTest, labelColumn: LabelColumn, probabilityColumn: "Probability");
    }
    return new BinaryCrossValidationMetrics(foldModels, foldMetrics);
}
/// <summary>
/// Creates a stratification column and returns its name. If <paramref name="stratificationColumn"/>
/// is null a random-number column is generated; otherwise a 30-bit HashJoin of the user column
/// is produced under a unique name.
/// </summary>
public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null)
{
    host.CheckValue(data, nameof(data));
    host.CheckValueOrNull(stratificationColumn);

    // Find a column name not already present in the schema: start from "StratificationKey"
    // and append a numeric suffix on collision.
    const string stratColName = "StratificationKey";
    string stratCol = stratColName;
    int suffix = 0;
    while (data.Schema.TryGetColumnIndex(stratCol, out _))
        stratCol = string.Format("{0}_{1:000}", stratColName, suffix++);

    // Construct the stratification column. If user-provided stratification column exists,
    // use HashJoin of it to construct the strat column, otherwise generate a random number.
    if (stratificationColumn == null)
    {
        var genOptions = new GenerateNumberTransform.Options
        {
            Columns = new[] { new GenerateNumberTransform.Column { Name = stratCol } }
        };
        data = new GenerateNumberTransform(host, genOptions, data);
    }
    else
    {
        var hashArgs = new HashJoiningTransform.Arguments
        {
            Columns = new[] { new HashJoiningTransform.Column { Name = stratCol, Source = stratificationColumn } },
            Join = true,
            HashBits = 30
        };
        data = new HashJoiningTransform(host, hashArgs, data);
    }
    return stratCol;
}
/// <summary>
/// Attaches a seeded random-number column and uses it to carve the data into
/// <c>Size</c> equal slices, running feature selection on each slice.
/// </summary>
public override IEnumerable<Subset> GetSubsets(Batch batch, Random rand)
{
    string splitCol = Data.Data.Schema.GetTempColumnName();
    var genOptions = new GenerateNumberTransform.Options();
    genOptions.Columns = new[] { new GenerateNumberTransform.Column() { Name = splitCol } };
    genOptions.Seed = (uint)rand.Next();
    IDataTransform numbered = new GenerateNumberTransform(Host, genOptions, Data.Data);

    // REVIEW: This won't be very efficient when Size is large.
    for (int i = 0; i < Size; i++)
    {
        // Slice i keeps rows whose generated value falls in [i/Size, (i+1)/Size].
        var slice = new RangeFilter(Host, new RangeFilter.Options() { Column = splitCol, Min = (Double)i / Size, Max = (Double)(i + 1) / Size }, numbered);
        var sliceData = new RoleMappedData(slice, Data.Schema.GetColumnRoleNames());
        yield return FeatureSelector.SelectFeatures(sliceData, rand);
    }
}
// Manual cross-validation of a sentiment model: folds are carved out of the data
// with a generated stratification column and complementary RangeFilters, and each
// fold's model is evaluated with the binary-classifier Maml evaluator.
// NOTE(review): testDataPath is computed but never used in this test.
void CrossValidation()
{
    var dataPath = GetDataPath(SentimentDataPath);
    var testDataPath = GetDataPath(SentimentTestPath);
    int numFolds = 5;
    using (var env = new TlcEnvironment(seed: 1, conc: 1))
    {
        // Pipeline.
        var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));
        var text = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);
        // Append a generated column used below to assign rows to folds.
        IDataView trans = new GenerateNumberTransform(env, text, "StratificationColumn");
        // Train.
        var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
        {
            NumThreads = 1,
            ConvergenceTolerance = 1f
        });

        var metrics = new List<BinaryClassificationMetrics>();
        for (int fold = 0; fold < numFolds; fold++)
        {
            // Training set: the complement of this fold's range of the stratification column.
            IDataView trainPipe = new RangeFilter(env, new RangeFilter.Arguments()
            {
                Column = "StratificationColumn",
                Min = (Double)fold / numFolds,
                Max = (Double)(fold + 1) / numFolds,
                Complement = true
            }, trans);
            trainPipe = new OpaqueDataView(trainPipe);
            var trainData = new RoleMappedData(trainPipe, label: "Label", feature: "Features");
            // Auto-normalization.
            NormalizeTransform.CreateIfNeeded(env, ref trainData, trainer);
            var preCachedData = trainData;
            // Auto-caching.
            if (trainer.Info.WantCaching)
            {
                var prefetch = trainData.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
                var cacheView = new CacheDataView(env, trainData.Data, prefetch);
                // Because the prefetching worked, we know that these are valid columns.
                trainData = new RoleMappedData(cacheView, trainData.Schema.GetColumnRoleNames());
            }

            var predictor = trainer.Train(new Runtime.TrainContext(trainData));
            // Test set: this fold's range itself (not the complement).
            IDataView testPipe = new RangeFilter(env, new RangeFilter.Arguments()
            {
                Column = "StratificationColumn",
                Min = (Double)fold / numFolds,
                Max = (Double)(fold + 1) / numFolds,
                Complement = false
            }, trans);
            testPipe = new OpaqueDataView(testPipe);
            // Re-apply the transforms learned on the training data to the test data.
            var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, preCachedData.Data, testPipe, trainPipe);
            var testRoles = new RoleMappedData(pipe, trainData.Schema.GetColumnRoleNames());
            IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, testRoles, env, testRoles.Schema);

            BinaryClassifierMamlEvaluator eval = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments() { });
            var dataEval = new RoleMappedData(scorer, testRoles.Schema.GetColumnRoleNames(), opt: true);
            var dict = eval.Evaluate(dataEval);
            var foldMetrics = BinaryClassificationMetrics.FromMetrics(env, dict["OverallMetrics"], dict["ConfusionMatrix"]);
            metrics.Add(foldMetrics.Single());
        }
    }
}
/// <summary>
/// Creates a temporary column usable by <see cref="RangeFilter"/> for splitting and returns
/// its name. With no <paramref name="samplingKeyColumn"/> a random-number column is generated;
/// otherwise the user column is hashed, copied, or normalized into the temp column as needed.
/// </summary>
internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false)
{
    Contracts.CheckValue(env, nameof(env));
    Contracts.CheckValueOrNull(samplingKeyColumn);

    var splitColumnName = data.Schema.GetTempColumnName("SplitColumn");

    // Resolve the seed: an explicit seed wins; otherwise optionally fall back to the
    // environment's seed; otherwise remain unseeded.
    int? seedToUse = seed ?? (fallbackInEnvSeed ? ((ISeededEnvironment)env).Seed : null);

    // With no sampling key column we simply generate a random number per row.
    if (samplingKeyColumn == null)
    {
        data = new GenerateNumberTransform(env, data, splitColumnName, (uint?)seedToUse);
        return splitColumnName;
    }

    // A sampling key column was provided: derive the split column from it under the
    // temporary name, since the user column might be dropped elsewhere in the code.
    if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int samplingColIndex))
        throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);

    var samplingColType = data.Schema[samplingColIndex].Type;
    if (!RangeFilter.IsValidRangeFilterColumnType(env, samplingColType))
    {
        var hashInput = samplingKeyColumn;
        // HashingEstimator currently handles all primitive types except for DateTime,
        // DateTimeOffset and TimeSpan; convert those to Int64 (into the split column)
        // and hash the converted column instead.
        var itemType = samplingColType.GetItemType();
        if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
        {
            data = new TypeConvertingTransformer(env, splitColumnName, DataKind.Int64, samplingKeyColumn).Transform(data);
            hashInput = splitColumnName;
        }
        var hashOptions = seedToUse.HasValue
            ? new HashingEstimator.ColumnOptions(splitColumnName, hashInput, 30, (uint)seedToUse.Value, combine: true)
            : new HashingEstimator.ColumnOptions(splitColumnName, hashInput, 30, combine: true);
        data = new HashingEstimator(env, hashOptions).Fit(data).Transform(data);
    }
    else if (samplingColType != NumberDataViewType.Single && samplingColType != NumberDataViewType.Double)
    {
        // Already valid for RangeFilter and not floating-point: a plain copy suffices.
        data = new ColumnCopyingEstimator(env, (splitColumnName, samplingKeyColumn)).Fit(data).Transform(data);
    }
    else
    {
        // Floating-point values are rescaled with min-max normalization into the split column.
        data = new NormalizingEstimator(env,
            new NormalizingEstimator.MinMaxColumnOptions(splitColumnName, samplingKeyColumn, ensureZeroUntouched: false))
            .Fit(data).Transform(data);
    }
    return splitColumnName;
}
/// <summary>
/// Trains a single binary classifier over a transformed view of the multiclass data:
/// a counter-based group id is inserted, labels are remapped and merged into the feature
/// vector, label and group are converted to key types, and roles are remapped before
/// training. Optionally trains a reclassification predictor on top.
/// </summary>
protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
{
    var data0 = data;

    #region adding group ID
    // We insert a group Id.
    string groupColumnTemp = DataViewUtils.GetTempColumnName(data.Schema.Schema) + "GR";
    var groupArgs = new GenerateNumberTransform.Options
    {
        Columns = new[] { GenerateNumberTransform.Column.Parse(groupColumnTemp) },
        UseCounter = true
    };
    var withGroup = new GenerateNumberTransform(Host, groupArgs, data.Data);
    data = new RoleMappedData(withGroup, data.Schema.GetColumnRoleNames());
    #endregion

    #region preparing the training dataset
    string dstName, labName;
    var trans = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, true, _args);
    var newFeatures = trans.Schema.GetTempColumnName() + "NF";

    // We check the label is not boolean.
    int indexLab = SchemaHelper.GetColumnIndex(trans.Schema, dstName);
    var typeLab = trans.Schema[indexLab].Type;
    if (typeLab.RawKind() == DataKind.Boolean)
    {
        throw Host.Except("Column '{0}' has an unexpected type {1}.", dstName, typeLab.RawKind());
    }

    var args3 = new DescribeTransform.Arguments { columns = new string[] { labName, dstName }, oneRowPerColumn = true };
    var desc = new DescribeTransform(Host, args3, trans);

    // Choose the view whose label column type matches what the downstream concat expects.
    // The branches differ in how the label column type is replaced (scalar Single vs a
    // Single vector sized by the key count or the observed label count).
    IDataView viewI;
    if (_args.singleColumn && data.Schema.Label.Value.Type.RawKind() == DataKind.Single)
    {
        viewI = desc;
    }
    else if (_args.singleColumn)
    {
        var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { NumberDataViewType.Single });
        viewI = new TypeReplacementDataView(desc, sch);
        #region debug
#if (DEBUG)
        DebugChecking0(viewI, labName, false);
#endif
        #endregion
    }
    else if (data.Schema.Label.Value.Type.IsKey())
    {
        // Key-typed label: the vector length comes from the key cardinality.
        ulong nb = data.Schema.Label.Value.Type.AsKey().GetKeyCount();
        var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, (int)nb) });
        viewI = new TypeReplacementDataView(desc, sch);
        #region debug
#if (DEBUG)
        int nb_;
        MinMaxLabelOverDataSet(trans, labName, out nb_);
        int count3;
        data.CheckMulticlassLabel(out count3);
        if ((ulong)count3 != nb)
        {
            throw ch.Except("Count mismatch (KeyCount){0} != {1}", nb, count3);
        }
        DebugChecking0(viewI, labName, true);
        DebugChecking0Vfloat(viewI, labName, nb);
#endif
        #endregion
    }
    else
    {
        // Non-key label: size the vector from the provided count, or scan the data when
        // count is not positive.
        int nb;
        if (count <= 0)
        {
            MinMaxLabelOverDataSet(trans, labName, out nb);
        }
        else
        {
            nb = count;
        }
        var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, nb) });
        viewI = new TypeReplacementDataView(desc, sch);
        #region debug
#if (DEBUG)
        DebugChecking0(viewI, labName, true);
#endif
        #endregion
    }

    ch.Info("Merging column label '{0}' with features '{1}'", labName, data.Schema.Feature.Value.Name);
    // Concatenate the (remapped) label column into the feature vector under a new name.
    var args = string.Format("Concat{{col={0}:{1},{2}}}", newFeatures, data.Schema.Feature.Value.Name, labName);
    var after_concatenation_ = ComponentCreation.CreateTransform(Host, args, viewI);
    #endregion

    #region converting label and group into keys
    // We need to convert the label into a Key.
    var convArgs = new MulticlassConvertTransform.Arguments
    {
        column = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", dstName)) },
        keyCount = new KeyCount(4),
        resultType = DataKind.UInt32
    };
    IDataView after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_);

    // The group must be a key too!
    convArgs = new MulticlassConvertTransform.Arguments
    {
        column = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", groupColumnTemp)) },
        keyCount = new KeyCount(),
        resultType = _args.groupIsU4 ? DataKind.UInt32 : DataKind.UInt64
    };
    after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_key_label);
    #endregion

    #region preparing the RoleMapData view
    string groupColumn = groupColumnTemp + "k";
    dstName += "k";

    // Strip the old label/feature/group roles, then bind the new key-typed columns.
    var roles = data.Schema.GetColumnRoleNames();
    var rolesArray = roles.ToArray();
    roles = roles
        .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value)
        .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
        .Where(kvp => kvp.Key.Value != groupColumn)
        .Where(kvp => kvp.Key.Value != groupColumnTemp);
    rolesArray = roles.ToArray();
    if (rolesArray.Any() && rolesArray[0].Value == groupColumnTemp)
    {
        throw ch.Except("Duplicated group.");
    }
    roles = roles
        .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(newFeatures))
        .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName))
        .Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupColumn));
    var trainer_input = new RoleMappedData(after_concatenation_key_label, roles);
    #endregion

    ch.Info("New Features: {0}:{1}", trainer_input.Schema.Feature.Value.Name, trainer_input.Schema.Feature.Value.Type);
    ch.Info("New Label: {0}:{1}", trainer_input.Schema.Label.Value.Name, trainer_input.Schema.Label.Value.Type);

    // We train the unique binary classifier.
    var trainedPredictor = trainer.Train(trainer_input);
    var predictors = new TScalarPredictor[] { trainedPredictor };

    // We train the reclassification classifier.
    if (_args.reclassicationPredictor != null)
    {
        var pred = CreateFinalPredictor(ch, data, trans, count, _args, predictors, null);
        TrainReclassificationPredictor(data0, pred, ScikitSubComponent<ITrainer, SignatureTrainer>.AsSubComponent(_args.reclassicationPredictor));
    }
    return (CreateFinalPredictor(ch, data, trans, count, _args, predictors, _reclassPredictor));
}
/// <summary>
/// Picks or creates the column used for range-based splitting and returns its name.
/// Preference order: the user-specified stratification column; a key-typed group column
/// with known cardinality; otherwise a freshly generated random-number column. A column
/// that RangeFilter cannot consume is hashed into a new column first.
/// </summary>
private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
{
    // The stratification column and/or group column, if they exist at all, must be present at this point.
    var schema = input.Schema;
    output = input;
    // If no stratification column was specified, but we have a group column of type Single, Double or
    // Key (contiguous) use it.
    string stratificationColumn = null;
    if (!string.IsNullOrWhiteSpace(Args.StratificationColumn))
    {
        stratificationColumn = Args.StratificationColumn;
    }
    else
    {
        string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
        int index;
        if (group != null && schema.TryGetColumnIndex(group, out index))
        {
            // Check if group column key type with known cardinality.
            var type = schema[index].Type;
            if (type.GetKeyCount() > 0)
            {
                stratificationColumn = group;
            }
        }
    }
    if (string.IsNullOrEmpty(stratificationColumn))
    {
        // No usable column found: probe for an unused name and append a
        // generated-number column under it.
        stratificationColumn = "StratificationColumn";
        int tmp;
        int inc = 0;
        while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
        {
            stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc);
        }
        var keyGenArgs = new GenerateNumberTransform.Options();
        var col = new GenerateNumberTransform.Column();
        col.Name = stratificationColumn;
        keyGenArgs.Columns = new[] { col };
        output = new GenerateNumberTransform(Host, keyGenArgs, input);
    }
    else
    {
        int col;
        if (!input.Schema.TryGetColumnIndex(stratificationColumn, out col))
        {
            throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
        }
        var type = input.Schema[col].Type;
        if (!RangeFilter.IsValidRangeFilterColumnType(ch, type))
        {
            // The column cannot feed RangeFilter directly: hash it (30 bits) into a fresh
            // column whose name does not collide with an existing one.
            ch.Info("Hashing the stratification column");
            var origStratCol = stratificationColumn;
            int tmp;
            int inc = 0;
            while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
            {
                stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
            }
            // BUGFIX: the output (new) column name must come before the input (original)
            // column name. The previous call passed (origStratCol, stratificationColumn),
            // which named the hash output after the pre-existing column and read from the
            // not-yet-existing new name — inverted relative to every other hashing call
            // site in this file, all of which pass (newName, origName, 30).
            output = new HashingEstimator(Host, stratificationColumn, origStratCol, 30).Fit(input).Transform(input);
        }
    }
    return (stratificationColumn);
}
IDataTransform AppendToPipeline(IDataView input) { IDataView current = input; if (_shuffleInput) { var args1 = new RowShufflingTransformer.Arguments() { ForceShuffle = false, ForceShuffleSeed = _seedShuffle, PoolRows = _poolRows, PoolOnly = false, }; current = new RowShufflingTransformer(Host, args1, current); } // We generate a random number. var columnName = current.Schema.GetTempColumnName(); var args2 = new GenerateNumberTransform.Arguments() { Column = new GenerateNumberTransform.Column[] { new GenerateNumberTransform.Column() { Name = columnName } }, Seed = _seed ?? 42 }; IDataTransform currentTr = new GenerateNumberTransform(Host, args2, current); // We convert this random number into a part. var cRatios = new float[_ratios.Length]; cRatios[0] = 0; for (int i = 1; i < _ratios.Length; ++i) { cRatios[i] = cRatios[i - 1] + _ratios[i - 1]; } ValueMapper <float, int> mapper = (in float src, ref int dst) => { for (int i = cRatios.Length - 1; i > 0; --i) { if (src >= cRatios[i]) { dst = i; return; } } dst = 0; }; // Get location of columnName int index; currentTr.Schema.TryGetColumnIndex(columnName, out index); var ct = currentTr.Schema.GetColumnType(index); var view = LambdaColumnMapper.Create(Host, "Key to part mapper", currentTr, columnName, _newColumn, ct, NumberType.I4, mapper); // We cache the result to avoid the pipeline to change the random number. var args3 = new ExtendedCacheTransform.Arguments() { inDataFrame = string.IsNullOrEmpty(_cacheFile), numTheads = _numThreads, cacheFile = _cacheFile, reuse = _reuse, }; currentTr = new ExtendedCacheTransform(Host, args3, view); // Removing the temporary column. var finalTr = ColumnSelectingTransformer.CreateDrop(Host, currentTr, new string[] { columnName }); var taggedViews = new List <Tuple <string, ITaggedDataView> >(); // filenames if (_filenames != null || _tags != null) { int nbf = _filenames == null ? 
0 : _filenames.Length; if (nbf > 0 && nbf != _ratios.Length) { throw Host.Except("Differen number of filenames and ratios."); } int nbt = _tags == null ? 0 : _tags.Length; if (nbt > 0 && nbt != _ratios.Length) { throw Host.Except("Differen number of filenames and ratios."); } int nb = Math.Max(nbf, nbt); using (var ch = Host.Start("Split the datasets and stores each part.")) { for (int i = 0; i < nb; ++i) { if (_filenames == null || !_filenames.Any()) { ch.Info("Create part {0}: {1} (tag: {2})", i + 1, _ratios[i], _tags[i]); } else { ch.Info("Create part {0}: {1} (file: {2})", i + 1, _ratios[i], _filenames[i]); } var ar1 = new RangeFilter.Arguments() { Column = _newColumn, Min = i, Max = i, IncludeMax = true }; int pardId = i; var filtView = LambdaFilter.Create <int>(Host, string.Format("Select part {0}", i), currentTr, _newColumn, NumberType.I4, (in int part) => { return(part.Equals(pardId)); });