Пример #1
0
        /// <summary>
        /// Reduces the evaluator's per-instance data to the columns that are written out and passes the
        /// resulting view to <c>GetPerInstanceMetricsCore</c>.
        /// </summary>
        /// <param name="perInst">Per-instance data together with its role-mapped schema.</param>
        /// <returns>Whatever <c>GetPerInstanceMetricsCore</c> returns for the trimmed view.</returns>
        private IDataView WrapPerInstance(RoleMappedData perInst)
        {
            const string instanceColumn = "Instance";

            var view = perInst.Data;

            // "copies" lists (name, source) pairs for the column-copying step; "kept" lists the columns
            // that survive the final keep-selection, in the order they are added here.
            var copies = new List<(string name, string source)>();
            var kept = new List<string>();

            // A fold index column is kept whenever the schema has one (e.g. cross-validation output).
            if (perInst.Schema.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out int _))
            {
                kept.Add(MetricKinds.ColumnNames.FoldIndex);
            }

            // An "Instance" column is always produced: copied from the name column when the schema has one,
            // otherwise generated by a GenerateNumberTransform with UseCounter enabled.
            if (perInst.Schema.Name?.Name is string nameName)
            {
                copies.Add((instanceColumn, nameName));
            }
            else
            {
                var counterOptions = new GenerateNumberTransform.Options
                {
                    Columns = new[] { new GenerateNumberTransform.Column { Name = instanceColumn } },
                    UseCounter = true
                };
                view = new GenerateNumberTransform(Host, counterOptions, view);
            }
            kept.Add(instanceColumn);

            // The weight column is kept when the schema declares one.
            if (perInst.Schema.Weight?.Name is string weightName)
            {
                kept.Add(weightName);
            }

            // Remaining columns are whatever the evaluator asks to save.
            kept.AddRange(GetPerInstanceColumnsToSave(perInst.Schema));

            view = new ColumnCopyingTransformer(Host, copies.ToArray()).Transform(view);
            view = ColumnSelectingTransformer.CreateKeep(Host, view, kept.ToArray());
            return GetPerInstanceMetricsCore(view, perInst.Schema);
        }
        /// <summary>
        /// Lazily yields a single <c>Batch</c> made of a training part and a test part of <c>Data</c>.
        /// </summary>
        /// <remarks>
        /// When <c>ValidationDatasetProportion</c> is greater than zero the data is split with a generated
        /// number column and two complementary range filters; otherwise both parts are <c>Data</c> itself.
        /// Because this is an iterator, nothing (including the assertions) runs until enumeration starts.
        /// </remarks>
        /// <param name="rand">Random source used to seed the generated split column.</param>
        /// <returns>A sequence containing exactly one batch.</returns>
        public IEnumerable<Batch> GetBatches(Random rand)
        {
            Host.Assert(Data != null, "Must call Initialize first!");
            Host.AssertValue(rand);

            using (var ch = Host.Start("Getting batches"))
            {
                RoleMappedData trainPart;
                RoleMappedData testPart;

                if (ValidationDatasetProportion > 0)
                {
                    // Attach a generated number column, then carve the rows into two complementary ranges.
                    string splitColumn = Data.Data.Schema.GetTempColumnName();
                    var numberOptions = new GenerateNumberTransform.Options
                    {
                        Columns = new[] { new GenerateNumberTransform.Column { Name = splitColumn } },
                        Seed = (uint)rand.Next()
                    };
                    var numbered = new GenerateNumberTransform(Host, numberOptions, Data.Data);

                    var testOptions = new RangeFilter.Options
                    {
                        Column = splitColumn,
                        Max = ValidationDatasetProportion
                    };
                    var testView = new RangeFilter(Host, testOptions, numbered);

                    var trainOptions = new RangeFilter.Options
                    {
                        Column = splitColumn,
                        Max = ValidationDatasetProportion,
                        Complement = true
                    };
                    var trainView = new RangeFilter(Host, trainOptions, numbered);

                    testPart = new RoleMappedData(testView, Data.Schema.GetColumnRoleNames());
                    trainPart = new RoleMappedData(trainView, Data.Schema.GetColumnRoleNames());
                }
                else
                {
                    // No positive proportion (this also covers NaN): train and test on the same data.
                    trainPart = Data;
                    testPart = trainPart;
                }

                if (BatchSize > 0)
                {
                    // REVIEW: How should we carve the data into batches?
                    ch.Warning("Batch support is temporarily disabled");
                }

                yield return new Batch(trainPart, testPart);
            }
        }
        /// <summary>
        /// Lazily yields <c>Size</c> subsets of <c>Data</c>, one per equal-width slice of a generated
        /// number column, each passed through <c>FeatureSelector.SelectFeatures</c>.
        /// </summary>
        /// <param name="batch">Not read by this implementation; the subsets are built from <c>Data</c>.</param>
        /// <param name="rand">Random source used to seed the generated column and passed to the feature selector.</param>
        /// <returns>One subset per slice index from 0 to <c>Size</c> - 1.</returns>
        public override IEnumerable<Subset> GetSubsets(Batch batch, Random rand)
        {
            string partitionColumn = Data.Data.Schema.GetTempColumnName();
            var numberOptions = new GenerateNumberTransform.Options
            {
                Columns = new[] { new GenerateNumberTransform.Column { Name = partitionColumn } },
                Seed = (uint)rand.Next()
            };
            IDataTransform numbered = new GenerateNumberTransform(Host, numberOptions, Data.Data);

            // REVIEW: This won't be very efficient when Size is large.
            for (int part = 0; part < Size; ++part)
            {
                // Slice number "part" covers the range [part / Size, (part + 1) / Size] handed to RangeFilter.
                var sliceOptions = new RangeFilter.Options
                {
                    Column = partitionColumn,
                    Min = (Double)part / Size,
                    Max = (Double)(part + 1) / Size
                };
                var slice = new RangeFilter(Host, sliceOptions, numbered);
                var sliceData = new RoleMappedData(slice, Data.Schema.GetColumnRoleNames());
                yield return FeatureSelector.SelectFeatures(sliceData, rand);
            }
        }
        /// <summary>
        /// Trains a single scalar predictor on a reshaped dataset and wraps it into the final vector predictor.
        /// </summary>
        /// <remarks>
        /// Steps visible in this method:
        /// 1. a generated group id column is appended to <paramref name="data"/>;
        /// 2. labels are remapped by <c>MapLabelsAndInsertTransform</c> and the result goes through a
        ///    <c>DescribeTransform</c> and, in most branches, a type-replacing view on the label column;
        /// 3. the original features and that label column are concatenated into a new feature column;
        /// 4. the destination label and the group column are converted into key columns (suffix "k");
        /// 5. a new <c>RoleMappedData</c> is built with the new Feature/Label/Group roles and handed to
        ///    <paramref name="trainer"/>;
        /// 6. optionally a reclassification predictor is trained on the original, unmodified data.
        /// </remarks>
        /// <param name="ch">Channel used for informational messages and exceptions.</param>
        /// <param name="trainer">Scalar trainer fitted on the reshaped data.</param>
        /// <param name="data">Input training data with its role mapping.</param>
        /// <param name="count">Number of classes; when not positive (and the label is not a key and
        /// <c>singleColumn</c> is off) the size is computed by <c>MinMaxLabelOverDataSet</c>.</param>
        /// <returns>The predictor built by <c>CreateFinalPredictor</c>.</returns>
        protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
        {
            // Keep a reference to the untouched input; it is used later for the reclassification predictor.
            var data0 = data;

            #region adding group ID

            // We insert a group Id.
            // The column is produced by GenerateNumberTransform with UseCounter enabled, under a temporary name.
            string groupColumnTemp = DataViewUtils.GetTempColumnName(data.Schema.Schema) + "GR";
            var    groupArgs       = new GenerateNumberTransform.Options
            {
                Columns    = new[] { GenerateNumberTransform.Column.Parse(groupColumnTemp) },
                UseCounter = true
            };

            var withGroup = new GenerateNumberTransform(Host, groupArgs, data.Data);
            data = new RoleMappedData(withGroup, data.Schema.GetColumnRoleNames());

            #endregion

            #region preparing the training dataset

            // MapLabelsAndInsertTransform returns the transformed view and reports two column names:
            // dstName and labName (their exact semantics are defined by that helper, not visible here).
            string dstName, labName;
            var    trans       = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, true, _args);
            var    newFeatures = trans.Schema.GetTempColumnName() + "NF";

            // We check the label is not boolean.
            int indexLab = SchemaHelper.GetColumnIndex(trans.Schema, dstName);
            var typeLab  = trans.Schema[indexLab].Type;
            if (typeLab.RawKind() == DataKind.Boolean)
            {
                throw Host.Except("Column '{0}' has an unexpected type {1}.", dstName, typeLab.RawKind());
            }

            var args3 = new DescribeTransform.Arguments {
                columns = new string[] { labName, dstName }, oneRowPerColumn = true
            };
            var desc = new DescribeTransform(Host, args3, trans);

            // Choose the view fed to the concatenation. Except for the first branch, the type of column
            // labName is replaced in the schema: Single, or a vector of Single of a known size.
            IDataView viewI;
            if (_args.singleColumn && data.Schema.Label.Value.Type.RawKind() == DataKind.Single)
            {
                // Label is already Single: no type replacement.
                viewI = desc;
            }
            else if (_args.singleColumn)
            {
                // Single-column mode with a non-Single label: expose labName as Single.
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { NumberDataViewType.Single });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, false);
#endif
                #endregion
            }
            else if (data.Schema.Label.Value.Type.IsKey())
            {
                // Key label: the vector size is the key count of the label type.
                ulong nb  = data.Schema.Label.Value.Type.AsKey().GetKeyCount();
                var   sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, (int)nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                int nb_;
                MinMaxLabelOverDataSet(trans, labName, out nb_);
                int count3;
                data.CheckMulticlassLabel(out count3);
                if ((ulong)count3 != nb)
                {
                    throw ch.Except("Count mismatch (KeyCount){0} != {1}", nb, count3);
                }
                DebugChecking0(viewI, labName, true);
                DebugChecking0Vfloat(viewI, labName, nb);
#endif
                #endregion
            }
            else
            {
                // Non-key label: the vector size is 'count' when positive, otherwise it is computed by
                // MinMaxLabelOverDataSet over the transformed data.
                int nb;
                if (count <= 0)
                {
                    MinMaxLabelOverDataSet(trans, labName, out nb);
                }
                else
                {
                    nb = count;
                }
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, true);
#endif
                #endregion
            }

            // Build the new feature column by concatenating the original features with column labName.
            ch.Info("Merging column label '{0}' with features '{1}'", labName, data.Schema.Feature.Value.Name);
            var args = string.Format("Concat{{col={0}:{1},{2}}}", newFeatures, data.Schema.Feature.Value.Name, labName);
            var after_concatenation_ = ComponentCreation.CreateTransform(Host, args, viewI);

            #endregion

            #region converting label and group into keys

            // We need to convert the label into a Key.
            // The converted column is named dstName + "k".
            // NOTE(review): the key count is hard-coded to 4 here; the reason is not visible in this
            // method. Confirm this is intended.
            var convArgs = new MulticlassConvertTransform.Arguments
            {
                column     = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", dstName)) },
                keyCount   = new KeyCount(4),
                resultType = DataKind.UInt32
            };
            IDataView after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_);

            // The group must be a key too!
            // The converted column is named groupColumnTemp + "k"; its raw type follows _args.groupIsU4.
            convArgs = new MulticlassConvertTransform.Arguments
            {
                column     = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", groupColumnTemp)) },
                keyCount   = new KeyCount(),
                resultType = _args.groupIsU4 ? DataKind.UInt32 : DataKind.UInt64
            };
            after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_key_label);

            #endregion

            #region preparing the RoleMapData view

            // From here on dstName and groupColumn designate the key-converted columns.
            string groupColumn = groupColumnTemp + "k";
            dstName += "k";

            // Drop the existing Label and Feature roles, then prepend the new Feature/Label/Group roles.
            // NOTE(review): this first rolesArray value is overwritten below before being read.
            var roles      = data.Schema.GetColumnRoleNames();
            var rolesArray = roles.ToArray();
            // NOTE(review): the last two filters compare the role name (kvp.Key.Value) with column
            // names (groupColumn / groupColumnTemp); comparing kvp.Value may have been intended. Verify.
            roles = roles
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value)
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
                    .Where(kvp => kvp.Key.Value != groupColumn)
                    .Where(kvp => kvp.Key.Value != groupColumnTemp);
            rolesArray = roles.ToArray();
            // Only the first remaining role is inspected for a clash with the temporary group column.
            if (rolesArray.Any() && rolesArray[0].Value == groupColumnTemp)
            {
                throw ch.Except("Duplicated group.");
            }
            roles = roles
                    .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(newFeatures))
                    .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName))
                    .Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupColumn));
            var trainer_input = new RoleMappedData(after_concatenation_key_label, roles);

            #endregion

            ch.Info("New Features: {0}:{1}", trainer_input.Schema.Feature.Value.Name, trainer_input.Schema.Feature.Value.Type);
            ch.Info("New Label: {0}:{1}", trainer_input.Schema.Label.Value.Name, trainer_input.Schema.Label.Value.Type);

            // We train the unique binary classifier.
            var trainedPredictor = trainer.Train(trainer_input);
            var predictors       = new TScalarPredictor[] { trainedPredictor };

            // We train the reclassification classifier.
            // An intermediate final predictor (without reclassification) is built first and used, together
            // with the original data (data0), to train the reclassification predictor.
            // Presumably TrainReclassificationPredictor stores its result in _reclassPredictor - TODO confirm.
            if (_args.reclassicationPredictor != null)
            {
                var pred = CreateFinalPredictor(ch, data, trans, count, _args, predictors, null);
                TrainReclassificationPredictor(data0, pred, ScikitSubComponent <ITrainer, SignatureTrainer> .AsSubComponent(_args.reclassicationPredictor));
            }

            return(CreateFinalPredictor(ch, data, trans, count, _args, predictors, _reclassPredictor));
        }
        /// <summary>
        /// Determines the column used to split the data and, when a new column has to be created,
        /// returns the extended view through <paramref name="output"/>.
        /// </summary>
        /// <remarks>
        /// The column is chosen in this order: the explicitly configured stratification column; else the
        /// group column when it exists and its type has a positive key count; else a freshly generated
        /// number column. A chosen column whose type is rejected by
        /// <c>RangeFilter.IsValidRangeFilterColumnType</c> is hashed into a new, uniquely named column.
        /// </remarks>
        /// <param name="ch">Channel used for messages and user-argument exceptions.</param>
        /// <param name="input">View in which the column is looked up.</param>
        /// <param name="output">Set to <paramref name="input"/>, or to a view with the added column.</param>
        /// <returns>Name of the column to split on.</returns>
        private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output)
        {
            // The stratification column and/or group column, if they exist at all, must be present at this point.
            var inputSchema = input.Schema;
            output = input;

            string splitColumn = null;
            if (string.IsNullOrWhiteSpace(Args.StratificationColumn))
            {
                // Nothing configured: fall back to the group column, but only when its type reports a
                // positive key count.
                string groupName = TrainUtils.MatchNameOrDefaultOrNull(ch, inputSchema, nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
                if (groupName != null
                    && inputSchema.TryGetColumnIndex(groupName, out int groupIndex)
                    && inputSchema[groupIndex].Type.GetKeyCount() > 0)
                {
                    splitColumn = groupName;
                }
            }
            else
            {
                splitColumn = Args.StratificationColumn;
            }

            if (string.IsNullOrEmpty(splitColumn))
            {
                // No usable column: generate one under a name that is not already taken.
                splitColumn = "StratificationColumn";
                for (int suffix = 1; input.Schema.TryGetColumnIndex(splitColumn, out int _); ++suffix)
                {
                    splitColumn = string.Format("StratificationColumn_{0:000}", suffix);
                }

                var generatorOptions = new GenerateNumberTransform.Options();
                generatorOptions.Columns = new[] { new GenerateNumberTransform.Column { Name = splitColumn } };
                output = new GenerateNumberTransform(Host, generatorOptions, input);
                return splitColumn;
            }

            if (!input.Schema.TryGetColumnIndex(splitColumn, out int columnIndex))
            {
                throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", splitColumn);
            }

            if (!RangeFilter.IsValidRangeFilterColumnType(ch, input.Schema[columnIndex].Type))
            {
                // The column cannot be range-filtered as is: hash it into a new, uniquely named column.
                ch.Info("Hashing the stratification column");
                string sourceColumn = splitColumn;
                for (int suffix = 1; input.Schema.TryGetColumnIndex(splitColumn, out int _); ++suffix)
                {
                    splitColumn = string.Format("{0}_{1:000}", sourceColumn, suffix);
                }
                output = new HashingEstimator(Host, sourceColumn, splitColumn, 30).Fit(input).Transform(input);
            }

            return splitColumn;
        }
        /// <summary>
        /// Runs the cross-validation command: builds the data pipeline, launches the fold tasks, prints
        /// per-fold and overall metrics, sends telemetry and optionally saves per-instance results.
        /// </summary>
        /// <param name="ch">Channel used for tracing, warnings and exceptions.</param>
        /// <param name="cmd">Command string forwarded to the fold helper.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);

            // Optionally load a predictor to continue training from; failure only produces a warning.
            IPredictor inputPredictor = null;
            bool loadFailed = Args.ContinueTrain && !TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor);
            if (loadFailed)
            {
                ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
            }

            ch.Trace("Constructing data pipeline");
            IDataLoader loader = CreateRawLoader();

            // If the per-instance results are requested and there is no name column, add a GenerateNumberTransform.
            var preXf = Args.PreTransforms;
            bool wantsPerInstance = !string.IsNullOrEmpty(Args.OutputDataFile);

            if (wantsPerInstance)
            {
                string nameColumn = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
                if (nameColumn == null)
                {
                    var addNameColumn = ComponentFactoryUtils.CreateFromFunction<IDataView, IDataTransform>(
                        (env, input) =>
                        {
                            var options = new GenerateNumberTransform.Options
                            {
                                Columns = new[] { new GenerateNumberTransform.Column { Name = DefaultColumnNames.Name } },
                                UseCounter = true
                            };
                            return new GenerateNumberTransform(env, options, input);
                        });
                    var extraTransform = new KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>("", addNameColumn);
                    preXf = preXf.Concat(new[] { extraTransform }).ToArray();
                }
            }
            loader = CompositeDataLoader.Create(Host, loader, preXf);

            ch.Trace("Binding label and features columns");

            IDataView pipe = loader;
            var stratificationColumn = GetSplitColumn(ch, loader, ref pipe);
            var scorer = Args.Scorer;
            var evaluator = Args.Evaluator;

            // The validation data, when a validation file is given, is loaded through a forked command.
            Func<IDataView> validDataCreator = null;
            if (Args.ValidationFile != null)
            {
                validDataCreator = () => new CrossValidationCommand(this).CreateRawLoader(dataFile: Args.ValidationFile);
            }

            FoldHelper fold = new FoldHelper(
                Host, RegistrationName, pipe, stratificationColumn,
                Args, CreateRoleMappedData, ApplyAllTransformsToData, scorer, evaluator,
                validDataCreator, ApplyAllTransformsToData, inputPredictor, cmd, loader, wantsPerInstance);
            var tasks = fold.GetCrossValidationTasks();

            // Projects every fold task onto its metrics dictionary; a fresh array is built on each call.
            Dictionary<string, IDataView>[] CollectFoldMetrics() => tasks.Select(t => t.Result.Metrics).ToArray();

            // Use the configured evaluator if any, otherwise derive one from the first fold's score schema.
            var eval = evaluator?.CreateComponent(Host) ??
                       EvaluateUtils.GetEvaluator(Host, tasks[0].Result.ScoreSchema);

            // Print confusion matrix and fold results for each fold.
            foreach (var task in tasks)
            {
                var foldMetrics = task.Result.Metrics;
                MetricWriter.PrintWarnings(ch, foldMetrics);
                eval.PrintFoldResults(ch, foldMetrics);
            }

            // Print the overall results.
            if (!TryGetOverallMetrics(CollectFoldMetrics(), out var overallList))
            {
                throw ch.Except("No overall metrics found");
            }

            var overall = eval.GetOverallResults(overallList.ToArray());

            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds);
            eval.PrintAdditionalMetrics(ch, CollectFoldMetrics());
            Dictionary<string, IDataView>[] metricValues = CollectFoldMetrics();
            SendTelemetryMetric(metricValues);

            // Save the per-instance results. Note that this check uses IsNullOrWhiteSpace, whereas the
            // checks above use IsNullOrEmpty.
            if (string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                return;
            }

            var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(
                Host, eval, Args.CollateMetrics, Args.OutputExampleFoldIndex,
                tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames);
            if (variableSizeVectorColumnNames.Length > 0)
            {
                ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.",
                           string.Join(", ", variableSizeVectorColumnNames));
            }

            if (Args.CollateMetrics)
            {
                // Collated output is a single view written to the requested file.
                ch.Assert(perInstance.Length == 1);
                MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]);
                return;
            }

            // Otherwise each view goes to its own file, named from the output file and the view's index.
            for (int foldIndex = 0; foldIndex < perInstance.Length; foldIndex++)
            {
                MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, foldIndex), perInstance[foldIndex]);
            }
        }
        IDataTransform AppendToPipeline(IDataView input)
        {
            IDataView current = input;

            if (_shuffleInput)
            {
                var args1 = new RowShufflingTransformer.Options()
                {
                    ForceShuffle     = false,
                    ForceShuffleSeed = _seedShuffle,
                    PoolRows         = _poolRows,
                    PoolOnly         = false,
                };
                current = new RowShufflingTransformer(Host, args1, current);
            }

            // We generate a random number.
            var columnName = current.Schema.GetTempColumnName();
            var args2      = new GenerateNumberTransform.Options()
            {
                Columns = new GenerateNumberTransform.Column[] { new GenerateNumberTransform.Column()
                                                                 {
                                                                     Name = columnName
                                                                 } },
                Seed = _seed ?? 42
            };
            IDataTransform currentTr = new GenerateNumberTransform(Host, args2, current);

            // We convert this random number into a part.
            var cRatios = new float[_ratios.Length];

            cRatios[0] = 0;
            for (int i = 1; i < _ratios.Length; ++i)
            {
                cRatios[i] = cRatios[i - 1] + _ratios[i - 1];
            }

            ValueMapper <float, int> mapper = (in float src, ref int dst) =>
            {
                for (int i = cRatios.Length - 1; i > 0; --i)
                {
                    if (src >= cRatios[i])
                    {
                        dst = i;
                        return;
                    }
                }
                dst = 0;
            };

            // Get location of columnName

            int index = SchemaHelper.GetColumnIndex(currentTr.Schema, columnName);
            var ct    = currentTr.Schema[index].Type;
            var view  = LambdaColumnMapper.Create(Host, "Key to part mapper", currentTr,
                                                  columnName, _newColumn, ct, NumberDataViewType.Int32, mapper);

            // We cache the result to avoid the pipeline to change the random number.
            var args3 = new ExtendedCacheTransform.Arguments()
            {
                inDataFrame = string.IsNullOrEmpty(_cacheFile),
                numTheads   = _numThreads,
                cacheFile   = _cacheFile,
                reuse       = _reuse,
            };

            currentTr = new ExtendedCacheTransform(Host, args3, view);

            // Removing the temporary column.
            var objtr   = ColumnSelectingTransformer.CreateDrop(Host, currentTr, new string[] { columnName });
            var finalTr = objtr as IDataTransform;

            if (finalTr == null)
            {
                throw Contracts.ExceptNotSupp("Desgin change.");
            }
            var taggedViews = new List <Tuple <string, ITaggedDataView> >();

            // filenames
            if (_filenames != null || _tags != null)
            {
                int nbf = _filenames == null ? 0 : _filenames.Length;
                if (nbf > 0 && nbf != _ratios.Length)
                {
                    throw Host.Except("Differen number of filenames and ratios.");
                }
                int nbt = _tags == null ? 0 : _tags.Length;
                if (nbt > 0 && nbt != _ratios.Length)
                {
                    throw Host.Except("Differen number of filenames and ratios.");
                }
                int nb = Math.Max(nbf, nbt);

                using (var ch = Host.Start("Split the datasets and stores each part."))
                {
                    for (int i = 0; i < nb; ++i)
                    {
                        if (_filenames == null || !_filenames.Any())
                        {
                            ch.Info("Create part {0}: {1} (tag: {2})", i + 1, _ratios[i], _tags[i]);
                        }
                        else
                        {
                            ch.Info("Create part {0}: {1} (file: {2})", i + 1, _ratios[i], _filenames[i]);
                        }
                        var ar1 = new RangeFilter.Options()
                        {
                            Column = _newColumn, Min = i, Max = i, IncludeMax = true
                        };
                        int pardId   = i;
                        var filtView = LambdaFilter.Create <int>(Host, string.Format("Select part {0}", i), currentTr,
                                                                 _newColumn, NumberDataViewType.Int32,
                                                                 (in int part) => { return(part.Equals(pardId)); });