public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null)
        {
            host.CheckValue(data, nameof(data));
            host.CheckValueOrNull(stratificationColumn);

            // Pick a unique name for the stratificationColumn.
            const string stratColName = "StratificationKey";
            string       stratCol     = data.Schema.GetTempColumnName(stratColName);

            // Construct the stratification column. If a user-provided stratification column exists, hash it
            // to construct the strat column; otherwise generate a random number and use that.
            if (stratificationColumn == null)
            {
                data = new GenerateNumberTransform(host,
                                                   new GenerateNumberTransform.Options
                {
                    Columns = new[] { new GenerateNumberTransform.Column {
                                          Name = stratCol
                                      } }
                }, data);
            }
            else
            {
                var col = data.Schema.GetColumnOrNull(stratificationColumn);
                if (col == null)
                {
                    throw host.ExceptSchemaMismatch(nameof(stratificationColumn), "Stratification", stratificationColumn);
                }

                var type = col.Value.Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(host, type))
                {
                    // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
                    var itemType = type.GetItemType();
                    if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
                    {
                        data = new TypeConvertingTransformer(host, stratificationColumn, DataKind.Int64, stratificationColumn).Transform(data);
                    }

                    var columnOptions = new HashingEstimator.ColumnOptions(stratCol, stratificationColumn, 30, combine: true);
                    data = new HashingEstimator(host, columnOptions).Fit(data).Transform(data);
                }
                else
                {
                    if (data.Schema[stratificationColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double))
                    {
                        return(stratificationColumn);
                    }

                    data = new NormalizingEstimator(host,
                                                    new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: true))
                           .Fit(data).Transform(data);
                }
            }

            return(stratCol);
        }
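A minimal usage sketch, illustrative only and assuming an IHost named host and an IDataView named data are already in scope, showing how the returned column name is typically paired with RangeFilter to carve out one slice of the data:

        // Hypothetical call site; not part of the example above.
        string stratCol = CreateStratificationColumn(host, ref data, stratificationColumn: null);

        // Keep roughly the first fifth of the rows: those whose stratification value falls in [0.0, 0.2).
        IDataView firstSlice = new RangeFilter(host, new RangeFilter.Options
        {
            Column = stratCol,
            Min    = 0.0,
            Max    = 0.2
        }, data);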
Example #2
        /// <summary>
        /// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary, or creates a new column if <paramref name="samplingKeyColumn"/> is null.
        /// </summary>
        internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
        {
            Contracts.CheckValue(env, nameof(env));
            // We need to handle two cases: if samplingKeyColumn is provided, we hash it to
            // build a single key column. If it is not, we generate a random number.
            if (samplingKeyColumn == null)
            {
                samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
                data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
            }
            else
            {
                if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
                {
                    throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);
                }

                var type = data.Schema[stratCol].Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
                {
                    // Hash the samplingKeyColumn.
                    // REVIEW: this could currently crash, since Hash only accepts a limited set
                    // of column types. It used to be HashJoin, but we should probably extend Hash
                    // instead of having two hash transformations.
                    var origStratCol = samplingKeyColumn;
                    samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
                    HashingEstimator.ColumnOptionsInternal columnOptions;
                    if (seed.HasValue)
                    {
                        columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)seed.Value);
                    }
                    else if (((ISeededEnvironment)env).Seed.HasValue)
                    {
                        columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)((ISeededEnvironment)env).Seed.Value);
                    }
                    else
                    {
                        columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30);
                    }
                    data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
                }
                else
                {
                    if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double))
                    {
                        var origStratCol = samplingKeyColumn;
                        samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
                        data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data);
                    }
                }
            }
        }
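An illustrative sketch of the ref-parameter contract; the env and data variables and the "GroupId" column name here are assumptions, not part of the example:

        string samplingKey = "GroupId";                 // may be rewritten to a temporary column name below
        EnsureGroupPreservationColumn(env, ref data, ref samplingKey, seed: 42);
        // After the call, data.Schema[samplingKey] has a type RangeFilter accepts, and rows that shared the
        // same original "GroupId" value still share the same key value, so a range split keeps groups intact.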
Example #3
        private IDataView WrapPerInstance(RoleMappedData perInst)
        {
            var idv = perInst.Data;

            // Make a list of column names that Maml outputs as part of the per-instance data view, and then wrap
            // the per-instance data computed by the evaluator in a SelectColumnsTransform.
            var cols       = new List <(string name, string source)>();
            var colsToKeep = new List <string>();

            // If perInst is the result of cross-validation and contains a fold Id column, include it.
            int foldCol;

            if (perInst.Schema.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out foldCol))
            {
                colsToKeep.Add(MetricKinds.ColumnNames.FoldIndex);
            }

            // Maml always outputs a name column; if it doesn't exist, add a GenerateNumberTransform.
            if (perInst.Schema.Name?.Name is string nameName)
            {
                cols.Add(("Instance", nameName));
                colsToKeep.Add("Instance");
            }
            else
            {
                var args = new GenerateNumberTransform.Arguments();
                args.Columns = new[] { new GenerateNumberTransform.Column()
                                       {
                                           Name = "Instance"
                                       } };
                args.UseCounter = true;
                idv             = new GenerateNumberTransform(Host, args, idv);
                colsToKeep.Add("Instance");
            }

            // Maml outputs the weight column if it exists.
            if (perInst.Schema.Weight?.Name is string weightName)
            {
                colsToKeep.Add(weightName);
            }

            // Get the other columns from the evaluator.
            foreach (var col in GetPerInstanceColumnsToSave(perInst.Schema))
            {
                colsToKeep.Add(col);
            }

            idv = new ColumnCopyingTransformer(Host, cols.ToArray()).Transform(idv);
            idv = ColumnSelectingTransformer.CreateKeep(Host, idv, colsToKeep.ToArray());
            return(GetPerInstanceMetricsCore(idv, perInst.Schema));
        }
        public IEnumerable <Batch> GetBatches(IRandom rand)
        {
            Host.Assert(Data != null, "Must call Initialize first!");
            Host.AssertValue(rand);

            using (var ch = Host.Start("Getting batches"))
            {
                RoleMappedData dataTest;
                RoleMappedData dataTrain;

                // Split the data, if needed.
                if (!(ValidationDatasetProportion > 0))
                {
                    dataTest = dataTrain = Data;
                }
                else
                {
                    // Split the data into train and test sets.
                    string name = Data.Data.Schema.GetTempColumnName();
                    var    args = new GenerateNumberTransform.Arguments();
                    args.Column = new[] { new GenerateNumberTransform.Column()
                                          {
                                              Name = name
                                          } };
                    args.Seed = (uint)rand.Next();
                    var view     = new GenerateNumberTransform(Host, args, Data.Data);
                    var viewTest = new RangeFilter(Host, new RangeFilter.Arguments()
                    {
                        Column = name, Max = ValidationDatasetProportion
                    }, view);
                    var viewTrain = new RangeFilter(Host, new RangeFilter.Arguments()
                    {
                        Column = name, Max = ValidationDatasetProportion, Complement = true
                    }, view);
                    dataTest  = new RoleMappedData(viewTest, Data.Schema.GetColumnRoleNames());
                    dataTrain = new RoleMappedData(viewTrain, Data.Schema.GetColumnRoleNames());
                }

                if (BatchSize > 0)
                {
                    // REVIEW: How should we carve the data into batches?
                    ch.Warning("Batch support is temporarily disabled");
                }

                yield return(new Batch(dataTrain, dataTest));

                ch.Done();
            }
        }
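For clarity, a compact sketch of the complementary split built above, reusing the example's name and view variables; the 0.1 proportion is an assumed value for ValidationDatasetProportion:

            // Rows whose generated value falls below the proportion form the validation view;
            // Complement = true selects the remaining rows, so the two views partition the data.
            var validationView = new RangeFilter(Host, new RangeFilter.Arguments() { Column = name, Max = 0.1 }, view);
            var trainingView   = new RangeFilter(Host, new RangeFilter.Arguments() { Column = name, Max = 0.1, Complement = true }, view);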
Example #5
        /// <summary>
        /// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary, or creates a new column if <paramref name="samplingKeyColumn"/> is null.
        /// </summary>
        internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
        {
            Contracts.CheckValue(env, nameof(env));
            // We need to handle two cases: if samplingKeyColumn is provided, we hash it to
            // build a single key column. If it is not, we generate a random number.
            if (samplingKeyColumn == null)
            {
                samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
                data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
            }
            else
            {
                if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
                {
                    throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);
                }

                var type = data.Schema[stratCol].Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
                {
                    var origStratCol = samplingKeyColumn;
                    samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
                    // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
                    var itemType = type.GetItemType();
                    if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
                    {
                        data = new TypeConvertingTransformer(env, origStratCol, DataKind.Int64, origStratCol).Transform(data);
                    }

                    var localSeed     = seed.HasValue ? seed : ((ISeededEnvironment)env).Seed.HasValue ? ((ISeededEnvironment)env).Seed : null;
                    var columnOptions =
                        localSeed.HasValue ?
                        new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, (uint)localSeed.Value, combine: true) :
                        new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, combine: true);
                    data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
                }
                else
                {
                    if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double))
                    {
                        var origStratCol = samplingKeyColumn;
                        samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
                        data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data);
                    }
                }
            }
        }
Example #6
        /// <summary>
        /// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary, or creates a new column if <paramref name="samplingKeyColumn"/> is null.
        /// </summary>
        private void EnsureGroupPreservationColumn(ref IDataView data, ref string samplingKeyColumn, uint? seed = null)
        {
            // We need to handle two cases: if samplingKeyColumn is provided, we hash it to
            // build a single key column. If it is not, we generate a random number.

            if (samplingKeyColumn == null)
            {
                samplingKeyColumn = data.Schema.GetTempColumnName("IdPreservationColumn");
                data = new GenerateNumberTransform(Environment, data, samplingKeyColumn, seed);
            }
            else
            {
                if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
                {
                    throw Environment.ExceptSchemaMismatch(nameof(samplingKeyColumn), "GroupPreservationColumn", samplingKeyColumn);
                }

                var type = data.Schema[stratCol].Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(Environment, type))
                {
                    // Hash the samplingKeyColumn.
                    // REVIEW: this could currently crash, since Hash only accepts a limited set
                    // of column types. It used to be HashJoin, but we should probably extend Hash
                    // instead of having two hash transformations.
                    var origStratCol = samplingKeyColumn;
                    int tmp;
                    int inc = 0;

                    // Generate a new column with the hashed samplingKeyColumn.
                    while (data.Schema.TryGetColumnIndex(samplingKeyColumn, out tmp))
                    {
                        samplingKeyColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
                    }
                    HashingEstimator.ColumnInfo columnInfo;
                    if (seed.HasValue)
                    {
                        columnInfo = new HashingEstimator.ColumnInfo(samplingKeyColumn, origStratCol, 30, seed.Value);
                    }
                    else
                    {
                        columnInfo = new HashingEstimator.ColumnInfo(samplingKeyColumn, origStratCol, 30);
                    }
                    data = new HashingEstimator(Environment, columnInfo).Fit(data).Transform(data);
                }
            }
        }
Example #7
            public BinaryCrossValidationMetrics CrossValidate(IDataView trainData, IEstimator <ITransformer> estimator)
            {
                var models  = new ITransformer[NumFolds];
                var metrics = new BinaryClassificationMetrics[NumFolds];

                if (StratificationColumn == null)
                {
                    StratificationColumn = "StratificationColumn";
                    var random = new GenerateNumberTransform(_env, trainData, StratificationColumn);
                    trainData = random;
                }
                else
                {
                    throw new NotImplementedException();
                }

                var evaluator = new MyBinaryClassifierEvaluator(_env, new BinaryClassifierEvaluator.Arguments()
                {
                });

                for (int fold = 0; fold < NumFolds; fold++)
                {
                    var trainFilter = new RangeFilter(_env, new RangeFilter.Arguments()
                    {
                        Column     = StratificationColumn,
                        Min        = (Double)fold / NumFolds,
                        Max        = (Double)(fold + 1) / NumFolds,
                        Complement = true
                    }, trainData);
                    var testFilter = new RangeFilter(_env, new RangeFilter.Arguments()
                    {
                        Column     = StratificationColumn,
                        Min        = (Double)fold / NumFolds,
                        Max        = (Double)(fold + 1) / NumFolds,
                        Complement = false
                    }, trainData);

                    models[fold] = estimator.Fit(trainFilter);
                    var scoredTest = models[fold].Transform(testFilter);
                    metrics[fold] = evaluator.Evaluate(scoredTest, labelColumn: LabelColumn, probabilityColumn: "Probability");
                }

                return(new BinaryCrossValidationMetrics(models, metrics));
            }
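For concreteness, a small sketch of the fold boundaries the Min/Max arithmetic above produces, reusing _env, StratificationColumn, and trainData from the example; NumFolds = 5 is an assumed value:

                // fold 0: test rows fall in [0.0, 0.2); fold 2: Min = 2.0 / 5 = 0.4 and Max = 3.0 / 5 = 0.6,
                // so each test filter keeps roughly one fifth of the rows and Complement = true keeps the rest for training.
                var foldTwoTest = new RangeFilter(_env, new RangeFilter.Arguments()
                {
                    Column = StratificationColumn,
                    Min    = 0.4,
                    Max    = 0.6
                }, trainData);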
Example #8
        public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null)
        {
            host.CheckValue(data, nameof(data));
            host.CheckValueOrNull(stratificationColumn);

            // Pick a unique name for the stratificationColumn.
            const string stratColName = "StratificationKey";
            string       stratCol     = stratColName;
            int          col;
            int          j = 0;

            while (data.Schema.TryGetColumnIndex(stratCol, out col))
            {
                stratCol = string.Format("{0}_{1:000}", stratColName, j++);
            }
            // Construct the stratification column. If user-provided stratification column exists, use HashJoin
            // of it to construct the strat column, otherwise generate a random number and use it.
            if (stratificationColumn == null)
            {
                data = new GenerateNumberTransform(host,
                                                   new GenerateNumberTransform.Options
                {
                    Columns = new[] { new GenerateNumberTransform.Column {
                                          Name = stratCol
                                      } }
                }, data);
            }
            else
            {
                data = new HashJoiningTransform(host,
                                                new HashJoiningTransform.Arguments
                {
                    Columns = new[] { new HashJoiningTransform.Column {
                                          Name = stratCol, Source = stratificationColumn
                                      } },
                    Join     = true,
                    HashBits = 30
                }, data);
            }

            return(stratCol);
        }
        public override IEnumerable <Subset> GetSubsets(Batch batch, Random rand)
        {
            string name = Data.Data.Schema.GetTempColumnName();
            var    args = new GenerateNumberTransform.Options();

            args.Columns = new[] { new GenerateNumberTransform.Column()
                                   {
                                       Name = name
                                   } };
            args.Seed = (uint)rand.Next();
            IDataTransform view = new GenerateNumberTransform(Host, args, Data.Data);

            // REVIEW: This won't be very efficient when Size is large.
            for (int i = 0; i < Size; i++)
            {
                var viewTrain = new RangeFilter(Host, new RangeFilter.Options()
                {
                    Column = name, Min = (Double)i / Size, Max = (Double)(i + 1) / Size
                }, view);
                var dataTrain = new RoleMappedData(viewTrain, Data.Schema.GetColumnRoleNames());
                yield return(FeatureSelector.SelectFeatures(dataTrain, rand));
            }
        }
Example #10
        void CrossValidation()
        {
            var dataPath     = GetDataPath(SentimentDataPath);
            var testDataPath = GetDataPath(SentimentTestPath);

            int numFolds = 5;

            using (var env = new TlcEnvironment(seed: 1, conc: 1))
            {
                // Pipeline.
                var loader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(dataPath));

                var       text  = TextTransform.Create(env, MakeSentimentTextTransformArgs(false), loader);
                IDataView trans = new GenerateNumberTransform(env, text, "StratificationColumn");
                // Train.
                var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
                {
                    NumThreads           = 1,
                    ConvergenceTolerance = 1f
                });


                var metrics = new List <BinaryClassificationMetrics>();
                for (int fold = 0; fold < numFolds; fold++)
                {
                    IDataView trainPipe = new RangeFilter(env, new RangeFilter.Arguments()
                    {
                        Column     = "StratificationColumn",
                        Min        = (Double)fold / numFolds,
                        Max        = (Double)(fold + 1) / numFolds,
                        Complement = true
                    }, trans);
                    trainPipe = new OpaqueDataView(trainPipe);
                    var trainData = new RoleMappedData(trainPipe, label: "Label", feature: "Features");
                    // Auto-normalization.
                    NormalizeTransform.CreateIfNeeded(env, ref trainData, trainer);
                    var preCachedData = trainData;
                    // Auto-caching.
                    if (trainer.Info.WantCaching)
                    {
                        var prefetch  = trainData.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
                        var cacheView = new CacheDataView(env, trainData.Data, prefetch);
                        // Because the prefetching worked, we know that these are valid columns.
                        trainData = new RoleMappedData(cacheView, trainData.Schema.GetColumnRoleNames());
                    }

                    var       predictor = trainer.Train(new Runtime.TrainContext(trainData));
                    IDataView testPipe  = new RangeFilter(env, new RangeFilter.Arguments()
                    {
                        Column     = "StratificationColumn",
                        Min        = (Double)fold / numFolds,
                        Max        = (Double)(fold + 1) / numFolds,
                        Complement = false
                    }, trans);
                    testPipe = new OpaqueDataView(testPipe);
                    var pipe = ApplyTransformUtils.ApplyAllTransformsToData(env, preCachedData.Data, testPipe, trainPipe);

                    var testRoles = new RoleMappedData(pipe, trainData.Schema.GetColumnRoleNames());

                    IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, testRoles, env, testRoles.Schema);

                    BinaryClassifierMamlEvaluator eval = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()
                    {
                    });
                    var dataEval    = new RoleMappedData(scorer, testRoles.Schema.GetColumnRoleNames(), opt: true);
                    var dict        = eval.Evaluate(dataEval);
                    var foldMetrics = BinaryClassificationMetrics.FromMetrics(env, dict["OverallMetrics"], dict["ConfusionMatrix"]);
                    metrics.Add(foldMetrics.Single());
                }
            }
        }
Example #11
        internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int?seed = null, bool fallbackInEnvSeed = false)
        {
            Contracts.CheckValue(env, nameof(env));
            Contracts.CheckValueOrNull(samplingKeyColumn);

            var splitColumnName = data.Schema.GetTempColumnName("SplitColumn");
            int?seedToUse;

            if (seed.HasValue)
            {
                seedToUse = seed.Value;
            }
            else if (fallbackInEnvSeed)
            {
                ISeededEnvironment seededEnv = (ISeededEnvironment)env;
                seedToUse = seededEnv.Seed;
            }
            else
            {
                seedToUse = null;
            }

            // We need to handle two cases: if samplingKeyColumn is not provided, we generate a random number;
            // if it is provided, we derive a RangeFilter-compatible split column from it below.
            if (samplingKeyColumn == null)
            {
                data = new GenerateNumberTransform(env, data, splitColumnName, (uint?)seedToUse);
            }
            else
            {
                // If samplingKeyColumn was provided, we will make a new column based on it, but using a temporary
                // name, as it might be dropped elsewhere in the code.

                if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int samplingColIndex))
                {
                    throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);
                }

                var type = data.Schema[samplingColIndex].Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
                {
                    var hashInputColumnName = samplingKeyColumn;
                    // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
                    var itemType = type.GetItemType();
                    if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
                    {
                        data = new TypeConvertingTransformer(env, splitColumnName, DataKind.Int64, samplingKeyColumn).Transform(data);
                        hashInputColumnName = splitColumnName;
                    }

                    var columnOptions =
                        seedToUse.HasValue ?
                        new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, (uint)seedToUse.Value, combine: true) :
                        new HashingEstimator.ColumnOptions(splitColumnName, hashInputColumnName, 30, combine: true);
                    data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
                }
                else
                {
                    if (type != NumberDataViewType.Single && type != NumberDataViewType.Double)
                    {
                        data = new ColumnCopyingEstimator(env, (splitColumnName, samplingKeyColumn)).Fit(data).Transform(data);
                    }
                    else
                    {
                        data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(splitColumnName, samplingKeyColumn, ensureZeroUntouched: false)).Fit(data).Transform(data);
                    }
                }
            }

            return(splitColumnName);
        }
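A hedged usage sketch; env and data are assumed to be in scope, and the 0.2 test fraction is illustrative:

        string splitCol = CreateSplitColumn(env, ref data, samplingKeyColumn: null, seed: 123);

        // Roughly 20% of the rows land in the test view; Complement = true yields the remaining training view.
        IDataView testView  = new RangeFilter(env, new RangeFilter.Options { Column = splitCol, Max = 0.2 }, data);
        IDataView trainView = new RangeFilter(env, new RangeFilter.Options { Column = splitCol, Max = 0.2, Complement = true }, data);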
        protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
        {
            var data0 = data;

            #region adding group ID

            // We insert a group Id.
            string groupColumnTemp = DataViewUtils.GetTempColumnName(data.Schema.Schema) + "GR";
            var    groupArgs       = new GenerateNumberTransform.Options
            {
                Columns    = new[] { GenerateNumberTransform.Column.Parse(groupColumnTemp) },
                UseCounter = true
            };

            var withGroup = new GenerateNumberTransform(Host, groupArgs, data.Data);
            data = new RoleMappedData(withGroup, data.Schema.GetColumnRoleNames());

            #endregion

            #region preparing the training dataset

            string dstName, labName;
            var    trans       = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, true, _args);
            var    newFeatures = trans.Schema.GetTempColumnName() + "NF";

            // We check the label is not boolean.
            int indexLab = SchemaHelper.GetColumnIndex(trans.Schema, dstName);
            var typeLab  = trans.Schema[indexLab].Type;
            if (typeLab.RawKind() == DataKind.Boolean)
            {
                throw Host.Except("Column '{0}' has an unexpected type {1}.", dstName, typeLab.RawKind());
            }

            var args3 = new DescribeTransform.Arguments {
                columns = new string[] { labName, dstName }, oneRowPerColumn = true
            };
            var desc = new DescribeTransform(Host, args3, trans);

            IDataView viewI;
            if (_args.singleColumn && data.Schema.Label.Value.Type.RawKind() == DataKind.Single)
            {
                viewI = desc;
            }
            else if (_args.singleColumn)
            {
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { NumberDataViewType.Single });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, false);
#endif
                #endregion
            }
            else if (data.Schema.Label.Value.Type.IsKey())
            {
                ulong nb  = data.Schema.Label.Value.Type.AsKey().GetKeyCount();
                var   sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, (int)nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                int nb_;
                MinMaxLabelOverDataSet(trans, labName, out nb_);
                int count3;
                data.CheckMulticlassLabel(out count3);
                if ((ulong)count3 != nb)
                {
                    throw ch.Except("Count mismatch (KeyCount){0} != {1}", nb, count3);
                }
                DebugChecking0(viewI, labName, true);
                DebugChecking0Vfloat(viewI, labName, nb);
#endif
                #endregion
            }
            else
            {
                int nb;
                if (count <= 0)
                {
                    MinMaxLabelOverDataSet(trans, labName, out nb);
                }
                else
                {
                    nb = count;
                }
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, true);
#endif
                #endregion
            }

            ch.Info("Merging column label '{0}' with features '{1}'", labName, data.Schema.Feature.Value.Name);
            var args = string.Format("Concat{{col={0}:{1},{2}}}", newFeatures, data.Schema.Feature.Value.Name, labName);
            var after_concatenation_ = ComponentCreation.CreateTransform(Host, args, viewI);

            #endregion

            #region converting label and group into keys

            // We need to convert the label into a Key.
            var convArgs = new MulticlassConvertTransform.Arguments
            {
                column     = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", dstName)) },
                keyCount   = new KeyCount(4),
                resultType = DataKind.UInt32
            };
            IDataView after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_);

            // The group must be a key too!
            convArgs = new MulticlassConvertTransform.Arguments
            {
                column     = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", groupColumnTemp)) },
                keyCount   = new KeyCount(),
                resultType = _args.groupIsU4 ? DataKind.UInt32 : DataKind.UInt64
            };
            after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_key_label);

            #endregion

            #region preparing the RoleMapData view

            string groupColumn = groupColumnTemp + "k";
            dstName += "k";

            var roles      = data.Schema.GetColumnRoleNames();
            var rolesArray = roles.ToArray();
            roles = roles
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value)
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
                    .Where(kvp => kvp.Key.Value != groupColumn)
                    .Where(kvp => kvp.Key.Value != groupColumnTemp);
            rolesArray = roles.ToArray();
            if (rolesArray.Any() && rolesArray[0].Value == groupColumnTemp)
            {
                throw ch.Except("Duplicated group.");
            }
            roles = roles
                    .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(newFeatures))
                    .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName))
                    .Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupColumn));
            var trainer_input = new RoleMappedData(after_concatenation_key_label, roles);

            #endregion

            ch.Info("New Features: {0}:{1}", trainer_input.Schema.Feature.Value.Name, trainer_input.Schema.Feature.Value.Type);
            ch.Info("New Label: {0}:{1}", trainer_input.Schema.Label.Value.Name, trainer_input.Schema.Label.Value.Type);

            // We train the unique binary classifier.
            var trainedPredictor = trainer.Train(trainer_input);
            var predictors       = new TScalarPredictor[] { trainedPredictor };

            // We train the reclassification classifier.
            if (_args.reclassicationPredictor != null)
            {
                var pred = CreateFinalPredictor(ch, data, trans, count, _args, predictors, null);
                TrainReclassificationPredictor(data0, pred, ScikitSubComponent <ITrainer, SignatureTrainer> .AsSubComponent(_args.reclassicationPredictor));
            }

            return(CreateFinalPredictor(ch, data, trans, count, _args, predictors, _reclassPredictor));
        }
        private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
        {
            // The stratification column and/or group column, if they exist at all, must be present at this point.
            var schema = input.Schema;

            output = input;
            // If no stratification column was specified, but we have a group column of type Single, Double or
            // Key (contiguous), use it.
            string stratificationColumn = null;

            if (!string.IsNullOrWhiteSpace(Args.StratificationColumn))
            {
                stratificationColumn = Args.StratificationColumn;
            }
            else
            {
                string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
                int    index;
                if (group != null && schema.TryGetColumnIndex(group, out index))
                {
                    // Check if group column key type with known cardinality.
                    var type = schema[index].Type;
                    if (type.GetKeyCount() > 0)
                    {
                        stratificationColumn = group;
                    }
                }
            }

            if (string.IsNullOrEmpty(stratificationColumn))
            {
                stratificationColumn = "StratificationColumn";
                int tmp;
                int inc = 0;
                while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
                {
                    stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc);
                }
                var keyGenArgs = new GenerateNumberTransform.Options();
                var col        = new GenerateNumberTransform.Column();
                col.Name           = stratificationColumn;
                keyGenArgs.Columns = new[] { col };
                output             = new GenerateNumberTransform(Host, keyGenArgs, input);
            }
            else
            {
                int col;
                if (!input.Schema.TryGetColumnIndex(stratificationColumn, out col))
                {
                    throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
                }
                var type = input.Schema[col].Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(ch, type))
                {
                    ch.Info("Hashing the stratification column");
                    var origStratCol = stratificationColumn;
                    int tmp;
                    int inc = 0;
                    while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
                    {
                        stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
                    }
                    output = new HashingEstimator(Host, origStratCol, stratificationColumn, 30).Fit(input).Transform(input);
                }
            }

            return(stratificationColumn);
        }
Example #14
        IDataTransform AppendToPipeline(IDataView input)
        {
            IDataView current = input;

            if (_shuffleInput)
            {
                var args1 = new RowShufflingTransformer.Arguments()
                {
                    ForceShuffle     = false,
                    ForceShuffleSeed = _seedShuffle,
                    PoolRows         = _poolRows,
                    PoolOnly         = false,
                };
                current = new RowShufflingTransformer(Host, args1, current);
            }

            // We generate a random number.
            var columnName = current.Schema.GetTempColumnName();
            var args2      = new GenerateNumberTransform.Arguments()
            {
                Column = new GenerateNumberTransform.Column[] { new GenerateNumberTransform.Column()
                                                                {
                                                                    Name = columnName
                                                                } },
                Seed = _seed ?? 42
            };
            IDataTransform currentTr = new GenerateNumberTransform(Host, args2, current);

            // We convert this random number into a part.
            var cRatios = new float[_ratios.Length];

            cRatios[0] = 0;
            for (int i = 1; i < _ratios.Length; ++i)
            {
                cRatios[i] = cRatios[i - 1] + _ratios[i - 1];
            }

            ValueMapper <float, int> mapper = (in float src, ref int dst) =>
            {
                for (int i = cRatios.Length - 1; i > 0; --i)
                {
                    if (src >= cRatios[i])
                    {
                        dst = i;
                        return;
                    }
                }
                dst = 0;
            };

            // Get location of columnName

            int index;

            currentTr.Schema.TryGetColumnIndex(columnName, out index);
            var ct   = currentTr.Schema.GetColumnType(index);
            var view = LambdaColumnMapper.Create(Host, "Key to part mapper", currentTr,
                                                 columnName, _newColumn, ct, NumberType.I4, mapper);

            // We cache the result to prevent the pipeline from changing the random number.
            var args3 = new ExtendedCacheTransform.Arguments()
            {
                inDataFrame = string.IsNullOrEmpty(_cacheFile),
                numTheads   = _numThreads,
                cacheFile   = _cacheFile,
                reuse       = _reuse,
            };

            currentTr = new ExtendedCacheTransform(Host, args3, view);

            // Removing the temporary column.
            var finalTr     = ColumnSelectingTransformer.CreateDrop(Host, currentTr, new string[] { columnName });
            var taggedViews = new List <Tuple <string, ITaggedDataView> >();

            // filenames
            if (_filenames != null || _tags != null)
            {
                int nbf = _filenames == null ? 0 : _filenames.Length;
                if (nbf > 0 && nbf != _ratios.Length)
                {
                    throw Host.Except("Differen number of filenames and ratios.");
                }
                int nbt = _tags == null ? 0 : _tags.Length;
                if (nbt > 0 && nbt != _ratios.Length)
                {
                    throw Host.Except("Differen number of filenames and ratios.");
                }
                int nb = Math.Max(nbf, nbt);

                using (var ch = Host.Start("Split the datasets and stores each part."))
                {
                    for (int i = 0; i < nb; ++i)
                    {
                        if (_filenames == null || !_filenames.Any())
                        {
                            ch.Info("Create part {0}: {1} (tag: {2})", i + 1, _ratios[i], _tags[i]);
                        }
                        else
                        {
                            ch.Info("Create part {0}: {1} (file: {2})", i + 1, _ratios[i], _filenames[i]);
                        }
                        var ar1 = new RangeFilter.Arguments()
                        {
                            Column = _newColumn, Min = i, Max = i, IncludeMax = true
                        };
                        int pardId   = i;
                        var filtView = LambdaFilter.Create <int>(Host, string.Format("Select part {0}", i), currentTr,
                                                                 _newColumn, NumberType.I4,
                                                                 (in int part) => { return(part.Equals(pardId)); });