/// <summary>
        /// Compute row count, list of labels, weights and group counts of the dataset.
        /// </summary>
        private void GetMetainfo(IChannel ch, FloatLabelCursor.Factory factory,
            out int numRow, out float[] labels, out float[] weights, out int[] groups)
        {
            ch.Check(factory.Data.Schema.Label != null, "The data must have a label column.");
            List<float> labelList = new List<float>();
            bool hasWeights = factory.Data.Schema.Weight != null;
            bool hasGroup = false;
            if (PredictionKind == PredictionKind.Ranking)
            {
                ch.Check(factory.Data.Schema.Group != null, "Data for a ranking task must have a group column.");
                hasGroup = true;
            }
            List<float> weightList = hasWeights ? new List<float>() : null;
            List<ulong> cursorGroups = hasGroup ? new List<ulong>() : null;

            using (var cursor = factory.Create())
            {
                while (cursor.MoveNext())
                {
                    if (labelList.Count == Utils.ArrayMaxSize)
                        throw ch.Except($"Dataset row count exceeded the maximum count of {Utils.ArrayMaxSize}");
                    labelList.Add(cursor.Label);
                    if (hasWeights)
                    {
                        // Default weight = 1.
                        if (float.IsNaN(cursor.Weight))
                            weightList.Add(1);
                        else
                            weightList.Add(cursor.Weight);
                    }
                    if (hasGroup)
                        cursorGroups.Add(cursor.Group);
                }
            }
            labels = labelList.ToArray();
            ConvertNaNLabels(ch, factory.Data, labels);
            numRow = labels.Length;
            ch.Check(numRow > 0, "Cannot use empty dataset.");
            weights = hasWeights ? weightList.ToArray() : null;
            groups = null;
            if (hasGroup)
            {
                List<int> groupList = new List<int>();
                int lastGroup = -1;
                for (int i = 0; i < numRow; ++i)
                {
                    if (i == 0 || cursorGroups[i] != cursorGroups[i - 1])
                    {
                        groupList.Add(1);
                        ++lastGroup;
                    }
                    else
                        ++groupList[lastGroup];
                }
                groups = groupList.ToArray();
            }
        }
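        // A minimal standalone sketch of the group run-length encoding above,
        // assuming only System.Collections.Generic; the helper name is hypothetical.
        // Consecutive identical group ids collapse into one per-group row count,
        // which is the layout ranking trainers such as LightGBM expect.
        private static int[] CountGroupSizes(IReadOnlyList<ulong> groupIds)
        {
            // E.g. group ids [3, 3, 3, 7, 7, 1] yield counts [3, 2, 1].
            var counts = new List<int>();
            for (int i = 0; i < groupIds.Count; i++)
            {
                if (i == 0 || groupIds[i] != groupIds[i - 1])
                    counts.Add(1);              // A new group starts here.
                else
                    counts[counts.Count - 1]++; // The current group continues.
            }
            return counts.ToArray();
        }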
        /// <summary>
        /// This method compares two pipelines to make sure they are identical. The first pipeline is passed
        /// as a <see cref="RoleMappedData"/>, and the second as an array of byte arrays plus a string array, both
        /// obtained by calling <see cref="SerializeRoleMappedData"/> on the second pipeline.
        /// The comparison is done by saving <paramref name="dataToCompare"/> as an in-memory <see cref="ZipArchive"/>
        /// and, for each entry in it, comparing its name and byte sequence to the corresponding entries in
        /// <paramref name="dataZipEntryNames"/> and <paramref name="dataSerialized"/>.
        /// This method throws if, for any entry, the name or byte sequence is not identical.
        /// </summary>
        public static void CheckSamePipeline(IHostEnvironment env, IChannel ch,
                                             RoleMappedData dataToCompare, byte[][] dataSerialized, string[] dataZipEntryNames)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(dataToCompare, nameof(dataToCompare));
            ch.CheckValue(dataSerialized, nameof(dataSerialized));
            ch.CheckValue(dataZipEntryNames, nameof(dataZipEntryNames));
            if (dataZipEntryNames.Length != dataSerialized.Length)
            {
                throw ch.ExceptParam(nameof(dataSerialized),
                                     $"The length of {nameof(dataSerialized)} must be equal to the length of {nameof(dataZipEntryNames)}");
            }

            using (var ms = new MemoryStream())
            {
                // REVIEW: This can be done more efficiently by adding a custom type of repository that
                // doesn't actually save the data, but upon stream closure compares the results to the given repository
                // and then discards it. Currently, however, this cannot be done because ModelSaveContext does not use
                // an abstract class/interface, but rather the RepositoryWriter class.
                TrainUtils.SaveModel(env, ch, ms, null, dataToCompare);

                string errorMsg = "Models contain different pipelines, cannot ensemble them.";
                var    zip      = new ZipArchive(ms);
                var    entries  = zip.Entries.OrderBy(e => e.FullName).ToArray();
                ch.Check(dataSerialized.Length == Utils.Size(entries));
                byte[] buffer = null;
                for (int i = 0; i < dataSerialized.Length; i++)
                {
                    ch.Check(dataZipEntryNames[i] == entries[i].FullName, errorMsg);
                    int len = dataSerialized[i].Length;
                    if (Utils.Size(buffer) < len)
                    {
                        buffer = new byte[len];
                    }
                    using (var s = entries[i].Open())
                    {
                        int bytesRead = s.Read(buffer, 0, len);
                        ch.Check(bytesRead == len, errorMsg);
                        for (int j = 0; j < len; j++)
                        {
                            ch.Check(buffer[j] == dataSerialized[i][j], errorMsg);
                        }
                        if (s.Read(buffer, 0, 1) > 0)
                        {
                            throw env.Except(errorMsg);
                        }
                    }
                }
            }
        }
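        // A hedged sketch of the per-entry byte comparison above, using only
        // System.IO. Note that Stream.Read may legally return fewer bytes than
        // requested before end-of-stream, so a robust variant loops until the
        // buffer is full; the helper name is hypothetical.
        private static bool StreamContentEquals(Stream s, byte[] expected)
        {
            var buffer = new byte[expected.Length];
            int read = 0;
            while (read < buffer.Length)
            {
                int n = s.Read(buffer, read, buffer.Length - read);
                if (n == 0)
                    return false; // Stream ended early: the entry is shorter.
                read += n;
            }
            if (s.ReadByte() >= 0)
                return false; // The entry is longer than expected.
            for (int i = 0; i < expected.Length; i++)
            {
                if (buffer[i] != expected[i])
                    return false;
            }
            return true;
        }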
        public static void SaveDataView(IChannel ch, IDataSaver saver, IDataView view, Stream stream, bool keepHidden = false)
        {
            Contracts.CheckValue(ch, nameof(ch));
            ch.CheckValue(saver, nameof(saver));
            ch.CheckValue(view, nameof(view));
            ch.CheckValue(stream, nameof(stream));

            var cols = new List <int>();

            for (int i = 0; i < view.Schema.ColumnCount; i++)
            {
                if (!keepHidden && view.Schema.IsHidden(i))
                {
                    continue;
                }
                var type = view.Schema.GetColumnType(i);
                if (saver.IsColumnSavable(type))
                {
                    cols.Add(i);
                }
                else
                {
                    ch.Info(MessageSensitivity.Schema, "The column '{0}' will not be written as its type is not savable.", view.Schema.GetColumnName(i));
                }
            }

            ch.Check(cols.Count > 0, "No valid columns to save");
            saver.SaveData(stream, view, cols.ToArray());
        }
            protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, OnlineLinearTrainer <TTransformer, TModel> parent)
            {
                Contracts.CheckValue(ch, nameof(ch));
                ch.Check(numFeatures > 0, "Cannot train with zero features!");
                ch.AssertValueOrNull(predictor);
                ch.AssertValue(parent);
                ch.Assert(Iteration == 0);
                ch.Assert(Bias == 0);

                ParentHost = parent.Host;

                ch.Trace("{0} Initializing {1} on {2} features", DateTime.UtcNow, parent.Name, numFeatures);

                // We want a dense vector, to prevent memory creation during training
                // unless we have a lot of features.
                if (predictor != null)
                {
                    ((IHaveFeatureWeights)predictor).GetFeatureWeights(ref Weights);
                    VBufferUtils.Densify(ref Weights);
                    Bias = predictor.Bias;
                }
                else if (!string.IsNullOrWhiteSpace(parent.OnlineLinearTrainerOptions.InitialWeights))
                {
                    ch.Info("Initializing weights and bias to " + parent.OnlineLinearTrainerOptions.InitialWeights);
                    string[] weightStr = parent.OnlineLinearTrainerOptions.InitialWeights.Split(',');
                    if (weightStr.Length != numFeatures + 1)
                    {
                        throw ch.Except(
                                  "Could not initialize weights from 'initialWeights': expecting {0} values to initialize {1} weights and the intercept",
                                  numFeatures + 1, numFeatures);
                    }

                    var weightValues = new float[numFeatures];
                    for (int i = 0; i < numFeatures; i++)
                    {
                        weightValues[i] = float.Parse(weightStr[i], CultureInfo.InvariantCulture);
                    }
                    Weights = new VBuffer <float>(numFeatures, weightValues);
                    Bias    = float.Parse(weightStr[numFeatures], CultureInfo.InvariantCulture);
                }
                else if (parent.OnlineLinearTrainerOptions.InitialWeightsDiameter > 0)
                {
                    var weightValues = new float[numFeatures];
                    for (int i = 0; i < numFeatures; i++)
                    {
                        weightValues[i] = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
                    }
                    Weights = new VBuffer <float>(numFeatures, weightValues);
                    Bias    = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
                }
                else if (numFeatures <= 1000)
                {
                    Weights = VBufferUtils.CreateDense <float>(numFeatures);
                }
                else
                {
                    Weights = VBufferUtils.CreateEmpty <float>(numFeatures);
                }
                WeightsScale = 1;
            }
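            // A standalone sketch of the InitialWeights parsing above: a
            // comma-separated string of numFeatures weights followed by one bias
            // term, parsed with CultureInfo.InvariantCulture so "1.5" reads the
            // same under every locale. The helper name is hypothetical.
            private static (float[] Weights, float Bias) ParseInitialWeights(string text, int numFeatures)
            {
                string[] parts = text.Split(',');
                if (parts.Length != numFeatures + 1)
                    throw new FormatException(
                        $"Expected {numFeatures + 1} values ({numFeatures} weights plus the intercept), got {parts.Length}.");
                var weights = new float[numFeatures];
                for (int i = 0; i < numFeatures; i++)
                    weights[i] = float.Parse(parts[i], CultureInfo.InvariantCulture);
                return (weights, float.Parse(parts[numFeatures], CultureInfo.InvariantCulture));
            }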
        private MFNode[] ConstructLabeledNodesFrom(IChannel ch, DataViewRowCursor cursor, ValueGetter <float> labGetter,
                                                   ValueGetter <uint> rowGetter, ValueGetter <uint> colGetter,
                                                   int rowCount, int colCount)
        {
            long  numSkipped = 0;
            uint  row        = 0;
            uint  col        = 0;
            float label      = 0;

            List <MFNode> nodes = new List <MFNode>();
            long          i     = 0;

            using (var pch = _host.StartProgressChannel("Create matrix"))
            {
                pch.SetHeader(new ProgressHeader(new[] { "processed rows", "created nodes" }),
                              e => { e.SetProgress(0, i); e.SetProgress(1, nodes.Count); });
                while (cursor.MoveNext())
                {
                    i++;
                    labGetter(ref label);
                    if (!FloatUtils.IsFinite(label))
                    {
                        numSkipped++;
                        continue;
                    }
                    rowGetter(ref row);
                    // REVIEW: Instead of ignoring, should I throw in the row > rowCount case?
                    // The index system in LIBMF (the underlying library that trains the model) is 0-based, so we need
                    // to subtract one from the 1-based indexes returned by ML.NET's key-valued getters. We also skip 0
                    // returned by the key-valued getter because a missing value is not meaningful to the trained model.
                    if (row == 0 || row > (uint)rowCount)
                    {
                        numSkipped++;
                        continue;
                    }
                    colGetter(ref col);
                    if (col == 0 || col > (uint)colCount)
                    {
                        numSkipped++;
                        continue;
                    }

                    MFNode node;
                    node.U = (int)(row - 1);
                    node.V = (int)(col - 1);
                    node.R = label;
                    nodes.Add(node);
                }
                pch.Checkpoint(i, nodes.Count);
            }
            if (numSkipped > 0)
            {
                ch.Warning("Skipped {0} instances with missing/negative features during data loading", numSkipped);
            }
            ch.Check(nodes.Count > 0, "No valid instances encountered during data loading");

            return(nodes.ToArray());
        }
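        // A minimal sketch of the key-to-index rule used above: ML.NET key getters
        // return 1-based values with 0 meaning "missing", while LIBMF indexes are
        // 0-based, so 0 and out-of-range keys are skipped and valid keys shift
        // down by one. The helper name is hypothetical.
        private static bool TryMapKeyToIndex(uint key, int count, out int index)
        {
            if (key == 0 || key > (uint)count)
            {
                index = -1;   // Missing or out of range: the caller should skip this row.
                return false;
            }
            index = (int)(key - 1); // 1-based key -> 0-based matrix index.
            return true;
        }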
            /// <summary>
            /// Backtracking line search with Armijo-like condition, from Andrew &amp; Gao
            /// </summary>
            internal override bool LineSearch(IChannel ch, bool force)
            {
                Float dirDeriv = -VectorUtils.DotProduct(ref _dir, ref _steepestDescDir);

                if (dirDeriv == 0)
                {
                    throw ch.Process(new PrematureConvergenceException(this, "Directional derivative is zero. You may be sitting on the optimum."));
                }

                // If a non-descent direction is chosen, the line search will break anyway, so throw here.
                // The most likely reason for this is a bug in your function's gradient computation.
                // It may also indicate that your function is not convex.
                ch.Check(dirDeriv < 0, "L-BFGS chose a non-descent direction.");

                Float alpha = (Iter == 1 ? (1 / VectorUtils.Norm(_dir)) : 1);

                GetNextPoint(alpha);
                Float unnormCos = VectorUtils.DotProduct(ref _steepestDescDir, ref _newX) - VectorUtils.DotProduct(ref _steepestDescDir, ref _x);

                if (unnormCos < 0)
                {
                    VBufferUtils.ApplyWith(ref _steepestDescDir, ref _dir,
                                           (int ind, Float sdVal, ref Float dirVal) =>
                    {
                        if (sdVal * dirVal < 0 && ind >= _biasCount)
                        {
                            dirVal = 0;
                        }
                    });

                    GetNextPoint(alpha);
                    unnormCos = VectorUtils.DotProduct(ref _steepestDescDir, ref _newX) - VectorUtils.DotProduct(ref _steepestDescDir, ref _x);
                }

                int i = 0;

                while (true)
                {
                    Value = Eval(ref _newX, ref _newGrad);
                    GradientCalculations++;

                    if (Value <= LastValue - Gamma * unnormCos)
                    {
                        return(true);
                    }

                    ++i;
                    if (!force && i == MaxLineSearch)
                    {
                        return(false);
                    }

                    alpha *= (Float)0.25;
                    GetNextPoint(alpha);
                    unnormCos = VectorUtils.DotProduct(ref _steepestDescDir, ref _newX) - VectorUtils.DotProduct(ref _steepestDescDir, ref _x);
                }
            }
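            // A self-contained sketch of the backtracking (Armijo) line search pattern
            // above, written against a plain delegate rather than the optimizer's state.
            // phi(alpha) is the objective value at step size alpha along the search
            // direction; this is the textbook sufficient-decrease test, slightly simpler
            // than the Andrew & Gao variant above. The 1e-4 default for gamma and the
            // helper name are assumptions, not taken from this code.
            private static double ArmijoBacktrack(Func<double, double> phi, double phi0,
                double dirDeriv, double alpha = 1.0, double gamma = 1e-4, int maxSteps = 50)
            {
                if (dirDeriv >= 0)
                    throw new InvalidOperationException("Line search requires a descent direction (dirDeriv < 0).");
                for (int i = 0; i < maxSteps; i++)
                {
                    // Accept the step once the objective has fallen at least in
                    // proportion to the directional derivative.
                    if (phi(alpha) <= phi0 + gamma * alpha * dirDeriv)
                        return alpha;
                    alpha *= 0.25; // Shrink the step, as the loop above does.
                }
                throw new InvalidOperationException("No step satisfying sufficient decrease was found.");
            }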
        internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx, IChannel ch, IDataView inputData, IDataView outputData,
                                                                   LinkedList <ITransformCanSaveOnnx> transforms, HashSet <string> inputColumnNamesToDrop = null, HashSet <string> outputColumnNamesToDrop = null)
        {
            inputColumnNamesToDrop  = inputColumnNamesToDrop ?? new HashSet <string>();
            outputColumnNamesToDrop = outputColumnNamesToDrop ?? new HashSet <string>();
            HashSet <string> inputColumns = new HashSet <string>();

            // Create graph inputs.
            for (int i = 0; i < inputData.Schema.Count; i++)
            {
                string colName = inputData.Schema[i].Name;
                if (inputColumnNamesToDrop.Contains(colName))
                {
                    continue;
                }

                ctx.AddInputVariable(inputData.Schema[i].Type, colName);
                inputColumns.Add(colName);
            }

            // Create graph nodes, outputs and intermediate values.
            foreach (var trans in transforms)
            {
                ch.Assert(trans.CanSaveOnnx(ctx));
                trans.SaveAsOnnx(ctx);
            }

            // Add graph outputs.
            for (int i = 0; i < outputData.Schema.Count; ++i)
            {
                if (outputData.Schema[i].IsHidden)
                {
                    continue;
                }

                var idataviewColumnName = outputData.Schema[i].Name;

                // Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
                // _inputToDrop should be removed too.
                if (inputColumnNamesToDrop.Contains(idataviewColumnName) || outputColumnNamesToDrop.Contains(idataviewColumnName))
                {
                    continue;
                }

                var variableName = ctx.TryGetVariableName(idataviewColumnName);
                // A null variable name occurs when an unsupported transform produces an output that a downstream
                // step consumes, or when the user accidentally removes a transform whose output is used by other transforms.
                ch.Check(variableName != null, "The targeted pipeline cannot be fully converted into a well-defined ONNX model. " +
                         "Please check that all steps in the pipeline are convertible to ONNX " +
                         "and that no necessary variables are dropped (via command line arguments).");
                var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true);
                ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
                ctx.AddOutputVariable(outputData.Schema[i].Type, trueVariableName);
            }

            return(ctx.MakeModel());
        }
        private MFNode[] ConstructLabeledNodesFrom(IChannel ch, RowCursor cursor, ValueGetter <float> labGetter,
                                                   ValueGetter <uint> rowGetter, ValueGetter <uint> colGetter,
                                                   int rowCount, int colCount)
        {
            long  numSkipped = 0;
            uint  row        = 0;
            uint  col        = 0;
            float label      = 0;

            List <MFNode> nodes = new List <MFNode>();
            long          i     = 0;

            using (var pch = _host.StartProgressChannel("Create matrix"))
            {
                pch.SetHeader(new ProgressHeader(new[] { "processed rows", "created nodes" }),
                              e => { e.SetProgress(0, i); e.SetProgress(1, nodes.Count); });
                while (cursor.MoveNext())
                {
                    i++;
                    labGetter(ref label);
                    if (!FloatUtils.IsFinite(label))
                    {
                        numSkipped++;
                        continue;
                    }
                    rowGetter(ref row);
                    // REVIEW: Instead of ignoring, should I throw in the row > rowCount case?
                    if (row == 0 || row > (uint)rowCount)
                    {
                        numSkipped++;
                        continue;
                    }
                    colGetter(ref col);
                    if (col == 0 || col > (uint)colCount)
                    {
                        numSkipped++;
                        continue;
                    }

                    MFNode node;
                    node.U = (int)(row - 1);
                    node.V = (int)(col - 1);
                    node.R = label;
                    nodes.Add(node);
                }
                pch.Checkpoint(i, nodes.Count);
            }
            if (numSkipped > 0)
            {
                ch.Warning("Skipped {0} instances with missing/negative features during data loading", numSkipped);
            }
            ch.Check(nodes.Count > 0, "No valid instances encountered during data loading");

            return(nodes.ToArray());
        }
        private static IDataView AppendFloatMapper <TInput>(IHostEnvironment env, IChannel ch, IDataView input,
                                                            string col, KeyDataViewType type, int seed)
        {
            // Any key is convertible to ulong, so rather than add special case handling for all possible
            // key-types we just upfront convert it to the most general type (ulong) and work from there.
            KeyDataViewType dstType = new KeyDataViewType(typeof(ulong), type.Count);
            bool            identity;
            var             converter = Conversions.Instance.GetStandardConversion <TInput, ulong>(type, dstType, out identity);
            var             isNa      = Conversions.Instance.GetIsNAPredicate <TInput>(type);

            ValueMapper <TInput, Single> mapper;

            if (seed == 0)
            {
                mapper =
                    (in TInput src, ref Single dst) =>
                {
                    //Attention: This method is called from multiple threads.
                    //Do not move the temp variable outside this method.
                    //If you do, the variable is shared between the threads and results in a race condition.
                    ulong temp = 0;
                    if (isNa(in src))
                    {
                        dst = Single.NaN;
                        return;
                    }
                    converter(in src, ref temp);
                    dst = (Single)temp - 1;
                };
            }
            else
            {
                ch.Check(type.Count > 0, "Label must be of known cardinality.");
                int[] permutation = Utils.GetRandomPermutation(RandomUtils.Create(seed), type.GetCountAsInt32(env));
                mapper =
                    (in TInput src, ref Single dst) =>
                {
                    //Attention: This method is called from multiple threads.
                    //Do not move the temp variable outside this method.
                    //If you do, the variable is shared between the threads and results in a race condition.
                    ulong temp = 0;
                    if (isNa(in src))
                    {
                        dst = Single.NaN;
                        return;
                    }
                    converter(in src, ref temp);
                    dst = (Single)permutation[(int)(temp - 1)];
                };
            }

            return(LambdaColumnMapper.Create(env, "Key to Float Mapper", input, col, col, type, NumberDataViewType.Single, mapper));
        }
        private static IDataView AppendFloatMapper <TInput>(IHostEnvironment env, IChannel ch, IDataView input,
                                                            string col, KeyType type, int seed)
        {
            // Any key is convertible to ulong, so rather than add special case handling for all possible
            // key-types we just upfront convert it to the most general type (ulong) and work from there.
            KeyType dstType = new KeyType(DataKind.U8, type.Min, type.Count, type.Contiguous);
            bool    identity;
            var     converter = Conversions.Instance.GetStandardConversion <TInput, ulong>(type, dstType, out identity);
            var     isNa      = Conversions.Instance.GetIsNAPredicate <TInput>(type);

            ValueMapper <TInput, Single> mapper;

            if (seed == 0)
            {
                mapper =
                    (in TInput src, ref Single dst) =>
                {
                    // This mapper may be called from multiple threads; temp must stay
                    // local to the lambda so each call gets its own storage (see the
                    // overload above), otherwise the threads race on a shared variable.
                    ulong temp = 0;
                    if (isNa(in src))
                    {
                        dst = Single.NaN;
                        return;
                    }
                    converter(in src, ref temp);
                    dst = (Single)(temp - 1);
                };
            }
            else
            {
                ch.Check(type.Count > 0, "Label must be of known cardinality.");
                int[] permutation = Utils.GetRandomPermutation(RandomUtils.Create(seed), type.Count);
                mapper =
                    (in TInput src, ref Single dst) =>
                {
                    // As above: keep temp local to the lambda to avoid a race
                    // condition between concurrent callers.
                    ulong temp = 0;
                    if (isNa(in src))
                    {
                        dst = Single.NaN;
                        return;
                    }
                    converter(in src, ref temp);
                    dst = (Single)permutation[(int)(temp - 1)];
                };
            }

            return(LambdaColumnMapper.Create(env, "Key to Float Mapper", input, col, col, type, NumberType.Float, mapper));
        }
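        // A standalone illustration of the threading caveat noted in the first
        // overload above: a variable captured from the enclosing scope is shared
        // by every thread running the lambda, while a variable declared inside
        // the lambda is private to each call. Assumes System.Threading and
        // System.Threading.Tasks; the names are hypothetical.
        private static void DemonstrateCaptureRace()
        {
            ulong shared = 0; // Captured: one storage location for all threads.
            int racyMismatches = 0, localMismatches = 0;
            Parallel.For(0, 1000000, i =>
            {
                shared = (ulong)i;            // Races with other iterations.
                if (shared != (ulong)i)
                    Interlocked.Increment(ref racyMismatches);
                ulong local = (ulong)i;       // Lambda-local: no sharing.
                if (local != (ulong)i)
                    Interlocked.Increment(ref localMismatches);
            });
            // Typically prints a nonzero first count; the second is always zero.
            Console.WriteLine($"captured: {racyMismatches} mismatches, lambda-local: {localMismatches}");
        }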
        protected override void ConvertNaNLabels(IChannel ch, RoleMappedData data, float[] labels)
        {
            // Only initialize one time.
            if (_numClass < 0)
            {
                float minLabel = float.MaxValue;
                float maxLabel = float.MinValue;
                bool hasNaNLabel = false;
                foreach (var labelValue in labels)
                {
                    if (float.IsNaN(labelValue))
                        hasNaNLabel = true;
                    else
                    {
                        minLabel = Math.Min(minLabel, labelValue);
                        maxLabel = Math.Max(maxLabel, labelValue);
                    }
                }
                ch.CheckParam(minLabel >= 0, nameof(data), "Minimum label cannot be negative");
                if (maxLabel >= _maxNumClass)
                    throw ch.ExceptParam(nameof(data), $"Maximum label cannot exceed {_maxNumClass}");

                if (data.Schema.Label.Type.IsKey)
                {
                    ch.Check(data.Schema.Label.Type.AsKey.Contiguous, "Label values must be contiguous");
                    if (hasNaNLabel)
                        _numClass = data.Schema.Label.Type.AsKey.Count + 1;
                    else
                        _numClass = data.Schema.Label.Type.AsKey.Count;
                    _tlcNumClass = data.Schema.Label.Type.AsKey.Count;
                }
                else
                {
                    if (hasNaNLabel)
                        _numClass = (int)maxLabel + 2;
                    else
                        _numClass = (int)maxLabel + 1;
                    _tlcNumClass = (int)maxLabel + 1;
                }
            }
            float defaultLabel = _numClass - 1;
            for (int i = 0; i < labels.Length; ++i)
                if (float.IsNaN(labels[i]))
                    labels[i] = defaultLabel;
        }
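        // A minimal sketch of the class-count rule above for non-key (floating
        // point) labels, assuming at least one finite label and System.Linq:
        // classes run 0..maxLabel, and if any NaN labels exist one extra class is
        // appended and the NaN rows are remapped onto it. The name is hypothetical.
        private static int RemapNaNLabels(float[] labels)
        {
            bool hasNaN = labels.Any(float.IsNaN);
            float maxLabel = labels.Where(v => !float.IsNaN(v)).Max();
            int numClass = (int)maxLabel + (hasNaN ? 2 : 1);
            float defaultLabel = numClass - 1; // The extra class reserved for NaN rows.
            for (int i = 0; i < labels.Length; i++)
            {
                if (float.IsNaN(labels[i]))
                    labels[i] = defaultLabel;
            }
            return numClass;
        }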
        /// <summary>
        /// This method ensures that the data meets the requirements of this trainer and its
        /// subclasses, injects necessary transforms, and throws if the requirements cannot be met.
        /// </summary>
        /// <param name="ch">The channel</param>
        /// <param name="examples">The training examples</param>
        /// <param name="weightSetCount">Gets the length of weights and bias array. For binary classification and regression,
        /// this is 1. For multi-class classification, this equals the number of classes on the label.</param>
        /// <returns>A potentially modified version of <paramref name="examples"/></returns>
        private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples, out int weightSetCount)
        {
            ch.AssertValue(examples);
            CheckLabel(examples, out weightSetCount);
            examples.CheckFeatureFloatVector();
            var       idvToShuffle = examples.Data;
            IDataView idvToFeedTrain;

            if (idvToShuffle.CanShuffle)
            {
                idvToFeedTrain = idvToShuffle;
            }
            else
            {
                var shuffleArgs = new ShuffleTransform.Arguments
                {
                    PoolOnly     = false,
                    ForceShuffle = _args.Shuffle
                };
                idvToFeedTrain = new ShuffleTransform(Host, shuffleArgs, idvToShuffle);
            }

            ch.Assert(idvToFeedTrain.CanShuffle);

            var roles = examples.Schema.GetColumnRoleNames();
            var examplesToFeedTrain = new RoleMappedData(idvToFeedTrain, roles);

            ch.AssertValue(examplesToFeedTrain.Schema.Label);
            ch.AssertValue(examplesToFeedTrain.Schema.Feature);
            if (examples.Schema.Weight != null)
            {
                ch.AssertValue(examplesToFeedTrain.Schema.Weight);
            }

            int numFeatures = examplesToFeedTrain.Schema.Feature.Type.VectorSize;

            ch.Check(numFeatures > 0, "Training set has no features, aborting training.");
            return(examplesToFeedTrain);
        }
        /// <summary>
        /// Train and returns a booster.
        /// </summary>
        /// <param name="ch">IChannel</param>
        /// <param name="pch">IProgressChannel</param>
        /// <param name="numberOfTrees">Number of trained trees</param>
        /// <param name="parameters">Parameters see <see cref="XGBoostArguments"/></param>
        /// <param name="dtrain">Training set</param>
        /// <param name="numBoostRound">Number of trees to train</param>
        /// <param name="obj">Custom objective</param>
        /// <param name="maximize">Whether to maximize feval.</param>
        /// <param name="verboseEval">Requires at least one item in evals.
        ///     If "verbose_eval" is True then the evaluation metric on the validation set is
        ///     printed at each boosting stage.</param>
        /// <param name="xgbModel">For continuous training.</param>
        /// <param name="saveBinaryDMatrix">Save DMatrix in binary format (for debugging purpose).</param>
        public static Booster Train(IChannel ch, IProgressChannel pch, out int numberOfTrees,
                                    Dictionary <string, string> parameters, DMatrix dtrain, int numBoostRound = 10,
                                    Booster.FObjType obj     = null, bool maximize    = false,
                                    bool verboseEval         = true, Booster xgbModel = null,
                                    string saveBinaryDMatrix = null)
        {
#if (!XGBOOST_RABIT)
            if (WrappedXGBoostInterface.RabitIsDistributed() == 1)
            {
                var pname = WrappedXGBoostInterface.RabitGetProcessorName();
                ch.Info("[WrappedXGBoostTraining.Train] start {0}:{1}", pname, WrappedXGBoostInterface.RabitGetRank());
            }
#endif

            if (!string.IsNullOrEmpty(saveBinaryDMatrix))
            {
                dtrain.SaveBinary(saveBinaryDMatrix);
            }

            Booster bst             = new Booster(parameters, dtrain, xgbModel);
            int     numParallelTree = 1;
            int     nboost          = 0;

            if (parameters != null && parameters.ContainsKey("num_parallel_tree"))
            {
                numParallelTree = Convert.ToInt32(parameters["num_parallel_tree"]);
                nboost         /= numParallelTree;
            }
            if (parameters.ContainsKey("num_class"))
            {
                int numClass = Convert.ToInt32(parameters["num_class"]);
                nboost /= numClass;
            }

            var prediction = new VBuffer <Float>();
            var grad       = new VBuffer <Float>();
            var hess       = new VBuffer <Float>();
            var start      = DateTime.Now;

#if (!XGBOOST_RABIT)
            int version = bst.LoadRabitCheckpoint();
            ch.Check(WrappedXGBoostInterface.RabitGetWorldSize() != 1 || version == 0);
#else
            int version = 0;
#endif
            int startIteration = version / 2;
            nboost += startIteration;
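            // The next two loops derive a logging stride: logten becomes a power of
            // ten near numBoostRound / 10 (1 for short runs), so the detailed
            // evaluation further down fires roughly ten times per run rather than
            // at every iteration.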
            int logten = 0;
            int temp   = numBoostRound * 5;
            while (temp > 0)
            {
                logten += 1;
                temp   /= 10;
            }
            temp   = Math.Max(logten - 2, 0);
            logten = 1;
            while (temp-- > 0)
            {
                logten *= 10;
            }

            var metrics = new List <string>()
            {
                "Iteration", "Training Time"
            };
            var units = new List <string>()
            {
                "iterations", "seconds"
            };
            if (verboseEval)
            {
                metrics.Add("Training Error");
                metrics.Add(parameters["objective"]);
            }
            var header = new ProgressHeader(metrics.ToArray(), units.ToArray());

            int    iter       = 0;
            double trainTime  = 0;
            double trainError = double.NaN;

            pch.SetHeader(header, e =>
            {
                e.SetProgress(0, iter, numBoostRound - startIteration);
                e.SetProgress(1, trainTime);
                if (verboseEval)
                {
                    e.SetProgress(2, trainError);
                }
            });
            for (iter = startIteration; iter < numBoostRound; ++iter)
            {
                if (version % 2 == 0)
                {
                    bst.Update(dtrain, iter, ref grad, ref hess, ref prediction, obj);
#if (!XGBOOST_RABIT)
                    bst.SaveRabitCheckpoint();
#endif
                    version += 1;
                }

#if (!XGBOOST_RABIT)
                ch.Check(WrappedXGBoostInterface.RabitGetWorldSize() == 1 ||
                         version == WrappedXGBoostInterface.RabitVersionNumber());
#endif
                nboost += 1;

                trainTime = (DateTime.Now - start).TotalMilliseconds;

                if (verboseEval)
                {
                    pch.Checkpoint(new double?[] { iter, trainTime, trainError });
                    if (iter == startIteration || iter == numBoostRound - 1 || iter % logten == 0 ||
                        (DateTime.Now - start) > TimeSpan.FromMinutes(2))
                    {
                        string strainError = bst.EvalSet(new[] { dtrain }, new[] { "Train" }, iter);
                        // Example: "[0]\tTrain-error:0.028612"
                        if (!string.IsNullOrEmpty(strainError) && strainError.Contains(":"))
                        {
                            double val;
                            if (double.TryParse(strainError.Split(':').Last(), out val))
                            {
                                trainError = val;
                            }
                        }
                    }
                }
                else
                {
                    pch.Checkpoint(new double?[] { iter, trainTime });
                }

                version += 1;
            }
            numberOfTrees = numBoostRound * numParallelTree;
            if (WrappedXGBoostInterface.RabitIsDistributed() == 1)
            {
                var pname = WrappedXGBoostInterface.RabitGetProcessorName();
                ch.Info("[WrappedXGBoostTraining.Train] end {0}:{1}", pname, WrappedXGBoostInterface.RabitGetRank());
            }
            return(bst);
        }
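        // A standalone sketch of the eval-string parsing in the loop above:
        // XGBoost reports metrics in the form "[0]\tTrain-error:0.028612", and the
        // code keeps whatever follows the last ':' if it parses as a double. This
        // variant adds invariant-culture parsing (an assumption, for locale safety);
        // the name is hypothetical. Assumes System.Globalization and System.Linq.
        private static double? TryParseEvalMetric(string evalLine)
        {
            if (string.IsNullOrEmpty(evalLine) || !evalLine.Contains(":"))
                return null;
            string tail = evalLine.Split(':').Last();
            return double.TryParse(tail, NumberStyles.Float, CultureInfo.InvariantCulture, out double val)
                ? val
                : (double?)null;
        }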
        protected virtual void TrainCore(IChannel ch, RoleMappedData data)
        {
            Host.AssertValue(ch);
            ch.AssertValue(data);

            // Compute the number of threads to use. The ctor should have verified that this will
            // produce a positive value.
            int numThreads = !UseThreads ? 1 : (NumThreads ?? Environment.ProcessorCount);

            if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
            {
                numThreads = Host.ConcurrencyFactor;
                ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor "
                           + "setting of the environment. Using {0} training threads instead.", numThreads);
            }

            ch.Assert(numThreads > 0);

            NumGoodRows = 0;
            WeightSum   = 0;

            _features = null;
            _labels   = null;
            _weights  = null;
            if (numThreads > 1)
            {
                ch.Info("LBFGS multi-threading will attempt to load dataset into memory. In case of out-of-memory " +
                        "issues, add 'numThreads=1' to the trainer arguments and 'cache=-' to the command line " +
                        "arguments to turn off multi-threading.");
                _features = new VBuffer <float> [1000];
                _labels   = new float[1000];
                if (data.Schema.Weight != null)
                {
                    _weights = new float[1000];
                }
            }

            var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Features | CursOpt.Label | CursOpt.Weight);

            long numBad;

            // REVIEW: This pass seems overly expensive for the benefit when multi-threading is off....
            using (var cursor = cursorFactory.Create())
                using (var pch = Host.StartProgressChannel("LBFGS data prep"))
                {
                    // REVIEW: maybe it makes sense for the factory to capture the good row count after
                    // the first successful cursoring?
                    Double totalCount = data.Data.GetRowCount(true) ?? Double.NaN;

                    long exCount = 0;
                    pch.SetHeader(new ProgressHeader(null, new[] { "examples" }),
                                  e => e.SetProgress(0, exCount, totalCount));
                    while (cursor.MoveNext())
                    {
                        WeightSum += cursor.Weight;
                        if (ShowTrainingStats)
                        {
                            ProcessPriorDistribution(cursor.Label, cursor.Weight);
                        }

                        PreTrainingProcessInstance(cursor.Label, ref cursor.Features, cursor.Weight);
                        exCount++;
                        if (_features != null)
                        {
                            ch.Assert(cursor.KeptRowCount <= int.MaxValue);
                            int index = (int)cursor.KeptRowCount - 1;
                            Utils.EnsureSize(ref _features, index + 1);
                            Utils.EnsureSize(ref _labels, index + 1);
                            if (_weights != null)
                            {
                                Utils.EnsureSize(ref _weights, index + 1);
                                _weights[index] = cursor.Weight;
                            }
                            Utils.Swap(ref _features[index], ref cursor.Features);
                            _labels[index] = cursor.Label;

                            if (cursor.KeptRowCount >= int.MaxValue)
                            {
                                ch.Warning("Limiting data size for multi-threading");
                                break;
                            }
                        }
                    }
                    NumGoodRows = cursor.KeptRowCount;
                    numBad      = cursor.SkippedRowCount;
                }
            ch.Check(NumGoodRows > 0, NoTrainingInstancesMessage);
            if (numBad > 0)
            {
                ch.Warning("Skipped {0} instances with missing features/label/weight during training", numBad);
            }

            if (_features != null)
            {
                ch.Assert(numThreads > 1);

                // If there are so many threads that each only gets a small number (less than 10) of instances, trim
                // the number of threads so each gets a more reasonable number (100 or so). These numbers are pretty arbitrary,
                // but avoid the possibility of having no instances on some threads.
                if (numThreads > 1 && NumGoodRows / numThreads < 10)
                {
                    int numNew = Math.Max(1, (int)NumGoodRows / 100);
                    ch.Warning("Too few instances to use {0} threads, decreasing to {1} thread(s)", numThreads, numNew);
                    numThreads = numNew;
                }
                ch.Assert(numThreads > 0);

                // Divide up the instances among the threads.
                _numChunks = numThreads;
                _ranges    = new int[_numChunks + 1];
                int cinstTot = (int)NumGoodRows;
                for (int ichk = 0, iinstMin = 0; ichk < numThreads; ichk++)
                {
                    int cchkLeft = numThreads - ichk;                                // Number of chunks left to fill.
                    ch.Assert(0 < cchkLeft && cchkLeft <= numThreads);
                    int cinstThis = (cinstTot - iinstMin + cchkLeft - 1) / cchkLeft; // Size of this chunk.
                    ch.Assert(0 < cinstThis && cinstThis <= cinstTot - iinstMin);
                    iinstMin         += cinstThis;
                    _ranges[ichk + 1] = iinstMin;
                }

                _localLosses    = new float[numThreads];
                _localGradients = new VBuffer <float> [numThreads - 1];
                int size = BiasCount + WeightCount;
                for (int i = 0; i < _localGradients.Length; i++)
                {
                    _localGradients[i] = VBufferUtils.CreateEmpty <float>(size);
                }

                ch.Assert(_numChunks > 0 && _data == null);
            }
            else
            {
                // Streaming, single-threaded case.
                _data          = data;
                _cursorFactory = cursorFactory;
                ch.Assert(_numChunks == 0 && _data != null);
            }

            VBuffer <float>       initWeights;
            ITerminationCriterion terminationCriterion;
            Optimizer             opt = InitializeOptimizer(ch, cursorFactory, out initWeights, out terminationCriterion);

            opt.Quiet = Quiet;

            float loss;

            try
            {
                opt.Minimize(DifferentiableFunction, ref initWeights, terminationCriterion, ref CurrentWeights, out loss);
            }
            catch (Optimizer.PrematureConvergenceException e)
            {
                if (!Quiet)
                {
                    ch.Warning("Premature convergence occurred. The OptimizationTolerance may be set too small. {0}", e.Message);
                }
                CurrentWeights = e.State.X;
                loss           = e.State.Value;
            }

            ch.Assert(CurrentWeights.Length == BiasCount + WeightCount);

            int numParams = BiasCount;

            if ((L1Weight > 0 && !Quiet) || ShowTrainingStats)
            {
                VBufferUtils.ForEachDefined(ref CurrentWeights, (index, value) => { if (index >= BiasCount && value != 0)
                                                                                    {
                                                                                        numParams++;
                                                                                    }
                                            });
                if (L1Weight > 0 && !Quiet)
                {
                    ch.Info("L1 regularization selected {0} of {1} weights.", numParams, BiasCount + WeightCount);
                }
            }

            if (ShowTrainingStats)
            {
                ComputeTrainingStatistics(ch, cursorFactory, loss, numParams);
            }
        }
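        // A self-contained sketch of the chunk partitioning above: total instances
        // are split across the given number of threads with the same ceiling-division
        // update, so chunk sizes differ by at most one. The name is hypothetical.
        private static int[] PartitionRanges(int total, int threads)
        {
            // E.g. PartitionRanges(10, 3) yields [0, 4, 7, 10]: chunks of 4, 3 and 3.
            var ranges = new int[threads + 1];
            int start = 0;
            for (int chunk = 0; chunk < threads; chunk++)
            {
                int chunksLeft = threads - chunk;
                // Ceiling division over the remaining instances keeps sizes balanced.
                int size = (total - start + chunksLeft - 1) / chunksLeft;
                start += size;
                ranges[chunk + 1] = start;
            }
            return ranges;
        }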
        private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
        {
            Host.AssertValue(ch);
            ch.AssertValue(cursorFactory);

            int m = featureCount + 1;

            // Check for memory conditions first.
            if ((long)m * (m + 1) / 2 > int.MaxValue)
            {
                throw ch.Except("Cannot hold covariance matrix in memory with {0} features", m - 1);
            }

            // Track the number of examples.
            long n = 0;
            // Since we are accumulating over many values, we use Double even for the single precision build.
            var xty = new Double[m];
            // The layout of this algorithm is a packed row-major lower triangular matrix.
            var xtx = new Double[m * (m + 1) / 2];

            // Build X'X (lower triangular) and X'y incrementally (X'X+=X'X_i; X'y+=X'y_i):
            using (var cursor = cursorFactory.Create())
            {
                while (cursor.MoveNext())
                {
                    var yi = cursor.Label;
                    // Increment first element of X'y
                    xty[0] += yi;
                    // Increment first element of lower triangular X'X
                    xtx[0] += 1;
                    var values = cursor.Features.GetValues();

                    if (cursor.Features.IsDense)
                    {
                        int ioff = 1;
                        ch.Assert(values.Length + 1 == m);
                        // Increment rest of first column of lower triangular X'X
                        for (int i = 1; i < m; i++)
                        {
                            ch.Assert(ioff == i * (i + 1) / 2);
                            var val = values[i - 1];
                            // Add the implicit first bias term to X'X
                            xtx[ioff++] += val;
                            // Add the remainder of X'X
                            for (int j = 0; j < i; j++)
                            {
                                xtx[ioff++] += val * values[j];
                            }
                            // X'y
                            xty[i] += val * yi;
                        }
                        ch.Assert(ioff == xtx.Length);
                    }
                    else
                    {
                        var fIndices = cursor.Features.GetIndices();
                        for (int ii = 0; ii < values.Length; ++ii)
                        {
                            int i    = fIndices[ii] + 1;
                            int ioff = i * (i + 1) / 2;
                            var val  = values[ii];
                            // Add the implicit first bias term to X'X
                            xtx[ioff++] += val;
                            // Add the remainder of X'X
                            for (int jj = 0; jj <= ii; jj++)
                            {
                                xtx[ioff + fIndices[jj]] += val * values[jj];
                            }
                            // X'y
                            xty[i] += val * yi;
                        }
                    }
                    n++;
                }
                ch.Check(n > 0, "No training examples in dataset.");
                if (cursor.BadFeaturesRowCount > 0)
                {
                    ch.Warning("Skipped {0} instances with missing features/label during training", cursor.SkippedRowCount);
                }

                if (_l2Weight > 0)
                {
                    // Skip the bias term for regularization, in the ridge regression case.
                    // So start at [1,1] instead of [0,0].

                    // REVIEW: There are two ways to view this. Firstly, it is more
                    // user friendly to make this scaling factor behave similarly regardless
                    // of data size, so that with the same parameters you get the same
                    // model whether you feed in your data once or duplicate it 10 times.
                    // This is what I have now. The alternate point of view is to view this
                    // L2 regularization parameter as providing some sort of prior, in which
                    // case duplicating 10 times should in fact be treated differently! (That
                    // is, we should not multiply by n below.) Both interpretations seem
                    // correct, in their way.
                    Double squared = _l2Weight * _l2Weight * n;
                    int    ioff    = 0;
                    for (int i = 1; i < m; ++i)
                    {
                        xtx[ioff += i + 1] += squared;
                    }
                    ch.Assert(ioff == xtx.Length - 1);
                }
            }

            if (!(_l2Weight > 0) && n < m)
            {
                throw ch.Except("Ordinary least squares requires more examples than parameters. There are {0} parameters, but {1} examples. To enable training, use a positive L2 weight so this behaves as ridge regression.", m, n);
            }

            Double yMean = n == 0 ? 0 : xty[0] / n;

            ch.Info("Trainer solving for {0} parameters across {1} examples", m, n);
            // Cholesky Decomposition of X'X into LL'
            try
            {
                Mkl.Pptrf(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, m, xtx);
            }
            catch (DllNotFoundException)
            {
                // REVIEW: Is there no better way?
                throw ch.ExceptNotSupp("The MKL library (libMklImports) or one of its dependencies is missing.");
            }
            // Solve for beta in (LL')beta = X'y:
            Mkl.Pptrs(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, m, 1, xtx, xty, 1);
            // Note that the solver overwrote xty so it contains the solution. To be more clear,
            // we effectively change its name (through reassignment) so we don't get confused that
            // this is somehow xty in the remaining calculation.
            var beta = xty;

            xty = null;
            // Check that the solution is valid.
            for (int i = 0; i < beta.Length; ++i)
            {
                ch.Check(FloatUtils.IsFinite(beta[i]), "Non-finite values detected in OLS solution");
            }

            var weights = VBufferUtils.CreateDense <float>(beta.Length - 1);

            for (int i = 1; i < beta.Length; ++i)
            {
                weights.Values[i - 1] = (float)beta[i];
            }
            var bias = (float)beta[0];

            if (!(_l2Weight > 0) && m == n)
            {
                // We would expect the solution to the problem to be exact in this case.
                ch.Info("Number of examples equals number of parameters, solution is exact but no statistics can be derived");
                return(new OlsLinearRegressionPredictor(Host, in weights, bias, null, null, null, 1, float.NaN));
            }

            Double rss = 0; // residual sum of squares
            Double tss = 0; // total sum of squares

            using (var cursor = cursorFactory.Create())
            {
                var   lrPredictor = new LinearRegressionPredictor(Host, in weights, bias);
                var   lrMap       = lrPredictor.GetMapper <VBuffer <float>, float>();
                float yh          = default;
                while (cursor.MoveNext())
                {
                    var features = cursor.Features;
                    lrMap(in features, ref yh);
                    var e = cursor.Label - yh;
                    rss += e * e;
                    var ydm = cursor.Label - yMean;
                    tss += ydm * ydm;
                }
            }
            var rSquared = ProbClamp(1 - (rss / tss));
            // R^2 adjusted differs from the normal formula on account of the bias term, by Said's reckoning.
            double rSquaredAdjusted;

            if (n > m)
            {
                rSquaredAdjusted = ProbClamp(1 - (1 - rSquared) * (n - 1) / (n - m));
                ch.Info("Coefficient of determination R2 = {0:g}, or {1:g} (adjusted)",
                        rSquared, rSquaredAdjusted);
            }
            else
            {
                rSquaredAdjusted = Double.NaN;
            }

            // The per-parameter significance is compute-intensive and may not be required by all practitioners.
            // Also, we cannot estimate it unless we can estimate the variance, which requires more examples
            // than parameters.
            if (!_perParameterSignificance || m >= n)
            {
                return(new OlsLinearRegressionPredictor(Host, in weights, bias, null, null, null, rSquared, rSquaredAdjusted));
            }

            ch.Assert(!Double.IsNaN(rSquaredAdjusted));
            var standardErrors = new Double[m];
            var tValues        = new Double[m];
            var pValues        = new Double[m];

            // Invert X'X:
            Mkl.Pptri(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, m, xtx);
            var s2 = rss / (n - m); // estimate of variance of y

            for (int i = 0; i < m; i++)
            {
                // Initialize with inverse Hessian.
                standardErrors[i] = (Single)xtx[i * (i + 1) / 2 + i];
            }

            if (_l2Weight > 0)
            {
                // Iterate through all entries of inverse Hessian to make adjustment to variance.
                int   ioffset = 1;
                float reg     = _l2Weight * _l2Weight * n;
                for (int iRow = 1; iRow < m; iRow++)
                {
                    for (int iCol = 0; iCol <= iRow; iCol++)
                    {
                        var entry      = (Single)xtx[ioffset];
                        var adjustment = -reg * entry * entry;
                        standardErrors[iRow] -= adjustment;
                        if (0 < iCol && iCol < iRow)
                        {
                            standardErrors[iCol] -= adjustment;
                        }
                        ioffset++;
                    }
                }

                Contracts.Assert(ioffset == xtx.Length);
            }

            for (int i = 0; i < m; i++)
            {
                // sqrt of diagonal entries of s2 * inverse(X'X + reg * I) * X'X * inverse(X'X + reg * I).
                standardErrors[i] = Math.Sqrt(s2 * standardErrors[i]);
                ch.Check(FloatUtils.IsFinite(standardErrors[i]), "Non-finite standard error detected from OLS solution");
                tValues[i] = beta[i] / standardErrors[i];
                pValues[i] = (float)MathUtils.TStatisticToPValue(tValues[i], n - m);
                ch.Check(0 <= pValues[i] && pValues[i] <= 1, "p-Value calculated outside expected [0,1] range");
            }

            return(new OlsLinearRegressionPredictor(Host, in weights, bias, standardErrors, tValues, pValues, rSquared, rSquaredAdjusted));
        }
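        // A minimal sketch of the packed row-major lower-triangular layout used
        // for X'X above: element (i, j) with j <= i lives at offset i*(i+1)/2 + j,
        // so an m-by-m symmetric matrix needs only m*(m+1)/2 doubles. The helper
        // name is hypothetical.
        private static int PackedLowerOffset(int i, int j)
        {
            if (j > i)
            {
                // Symmetric matrix: fold upper-triangle lookups onto the lower triangle.
                int t = i;
                i = j;
                j = t;
            }
            return i * (i + 1) / 2 + j;
        }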
        private void WriteDataCore(IChannel ch, TextWriter writer, IDataView data,
                                   out string argsLoader, out long count, out int min, out int max, params int[] cols)
        {
            _host.AssertValue(ch);
            ch.AssertValue(writer);
            ch.AssertValue(data);
            ch.AssertNonEmpty(cols);

            // Determine the active columns and whether there is header information.
            bool[] active = new bool[data.Schema.ColumnCount];
            for (int i = 0; i < cols.Length; i++)
            {
                ch.Check(0 <= cols[i] && cols[i] < active.Length);
                ch.Check(data.Schema.GetColumnType(cols[i]).ItemType.RawKind != 0);
                active[cols[i]] = true;
            }

            bool hasHeader = false;

            if (_outputHeader)
            {
                for (int i = 0; i < cols.Length; i++)
                {
                    if (hasHeader)
                    {
                        continue;
                    }
                    var type = data.Schema.GetColumnType(cols[i]);
                    if (!type.IsVector)
                    {
                        hasHeader = true;
                        continue;
                    }
                    if (!type.IsKnownSizeVector)
                    {
                        continue;
                    }
                    var typeNames = data.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, cols[i]);
                    if (typeNames != null && typeNames.VectorSize == type.VectorSize && typeNames.ItemType.IsText)
                    {
                        hasHeader = true;
                    }
                }
            }

            using (var cursor = data.GetRowCursor(i => active[i]))
            {
                var pipes = new ValueWriter[cols.Length];
                for (int i = 0; i < cols.Length; i++)
                {
                    pipes[i] = ValueWriter.Create(cursor, cols[i], _sepChar);
                }

                // REVIEW: This should be outside the cursor creation.
                string header = CreateLoaderArguments(data.Schema, pipes, hasHeader, ch);
                argsLoader = header;
                if (_outputSchema)
                {
                    WriteSchemaAsComment(writer, header);
                }

                double rowCount = data.GetRowCount(true) ?? double.NaN;
                using (var pch = !_silent ? _host.StartProgressChannel("TextSaver: saving data") : null)
                {
                    long stateCount = 0;
                    var  state      = new State(this, writer, pipes, hasHeader);
                    if (pch != null)
                    {
                        pch.SetHeader(new ProgressHeader(new[] { "rows" }), e => e.SetProgress(0, stateCount, rowCount));
                    }
                    state.Run(cursor, ref stateCount, out min, out max);
                    count = stateCount;
                    if (pch != null)
                    {
                        pch.Checkpoint(stateCount);
                    }
                }
            }
        }
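
WriteDataCore only treats the output as having a header when at least one selected column can supply names: a scalar column always can, while a vector column needs a known size and slot-name metadata covering every slot. A self-contained sketch of that decision, using a hypothetical ColumnInfo type in place of the ML.NET schema API:

using System;
using System.Linq;

// Hypothetical, simplified stand-in for the schema information used above.
internal sealed class ColumnInfo
{
    public bool IsVector;      // true for vector-valued columns
    public int VectorSize;     // known size, or 0 if unknown
    public int SlotNameCount;  // number of slot names attached, 0 if none
}

internal static class HeaderDemo
{
    // A column contributes a header if it is scalar, or a known-size vector
    // whose slot-name metadata covers every slot.
    private static bool HasHeader(ColumnInfo[] cols) =>
        cols.Any(c => !c.IsVector || (c.VectorSize > 0 && c.SlotNameCount == c.VectorSize));

    private static void Main()
    {
        var cols = new[]
        {
            new ColumnInfo { IsVector = true, VectorSize = 3, SlotNameCount = 0 },
            new ColumnInfo { IsVector = true, VectorSize = 0 },
        };
        Console.WriteLine(HasHeader(cols)); // False: no scalar column, no named slots
    }
}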
Exemple #17
        private FieldAwareFactorizationMachineModelParameters TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data,
                                                                        RoleMappedData validData = null, FieldAwareFactorizationMachineModelParameters predictor = null)
        {
            _host.AssertValue(ch);
            _host.AssertValue(pch);

            data.CheckBinaryLabel();
            var featureColumns    = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
            int fieldCount        = featureColumns.Count;
            int totalFeatureCount = 0;

            int[] fieldColumnIndexes = new int[fieldCount];
            for (int f = 0; f < fieldCount; f++)
            {
                var col = featureColumns[f];
                _host.Assert(!col.IsHidden);
                if (!(col.Type is VectorDataViewType vectorType) ||
                    !vectorType.IsKnownSize ||
                    vectorType.ItemType != NumberDataViewType.Single)
                {
                    throw ch.ExceptParam(nameof(data), "Training feature column '{0}' must be a known-size vector of Single, but has type: {1}.", col.Name, col.Type);
                }
                _host.Assert(vectorType.Size > 0);
                fieldColumnIndexes[f] = col.Index;
                totalFeatureCount    += vectorType.Size;
            }
            ch.Check(checked (totalFeatureCount * fieldCount * _latentDimAligned) <= Utils.ArrayMaxSize, "Latent dimension or the number of fields too large");
            if (predictor != null)
            {
                ch.Check(predictor.FeatureCount == totalFeatureCount, "Input model's feature count mismatches training feature count");
                ch.Check(predictor.LatentDimension == _latentDim, "Input model's latent dimension mismatches trainer's");
            }
            if (validData != null)
            {
                validData.CheckBinaryLabel();
                var validFeatureColumns = validData.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
                _host.Assert(fieldCount == validFeatureColumns.Count);
                for (int f = 0; f < fieldCount; f++)
                {
                    var featCol      = featureColumns[f];
                    var validFeatCol = validFeatureColumns[f];
                    _host.Assert(featCol.Name == validFeatCol.Name);
                    _host.Assert(featCol.Type == validFeatCol.Type);
                }
            }
            bool shuffle = _shuffle;

            if (shuffle && !data.Data.CanShuffle)
            {
                ch.Warning("Training data does not support shuffling, so ignoring request to shuffle");
                shuffle = false;
            }
            var rng                = shuffle ? _host.Rand : null;
            var featureGetters     = new ValueGetter <VBuffer <float> > [fieldCount];
            var featureBuffer      = new VBuffer <float>();
            var featureValueBuffer = new float[totalFeatureCount];
            var featureIndexBuffer = new int[totalFeatureCount];
            var featureFieldBuffer = new int[totalFeatureCount];
            var latentSum          = new AlignedArray(fieldCount * fieldCount * _latentDimAligned, 16);
            var metricNames        = new List <string>()
            {
                "Training-loss"
            };

            if (validData != null)
            {
                metricNames.Add("Validation-loss");
            }
            int    iter                 = 0;
            long   exampleCount         = 0;
            long   badExampleCount      = 0;
            long   validBadExampleCount = 0;
            double loss                 = 0;
            double validLoss            = 0;

            pch.SetHeader(new ProgressHeader(metricNames.ToArray(), new string[] { "iterations", "examples" }), entry =>
            {
                entry.SetProgress(0, iter, _numIterations);
                entry.SetProgress(1, exampleCount);
            });

            var columns = data.Schema.Schema.Where(x => fieldColumnIndexes.Contains(x.Index)).ToList();

            columns.Add(data.Schema.Label.Value);
            if (data.Schema.Weight != null)
            {
                columns.Add(data.Schema.Weight.Value);
            }

            InitializeTrainingState(fieldCount, totalFeatureCount, predictor, out float[] linearWeights,
                                    out AlignedArray latentWeightsAligned, out float[] linearAccSqGrads, out AlignedArray latentAccSqGradsAligned);

            // refer to Algorithm 3 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
            while (iter++ < _numIterations)
            {
                using (var cursor = data.Data.GetRowCursor(columns, rng))
                {
                    var labelGetter  = RowCursorUtils.GetLabelGetter(cursor, data.Schema.Label.Value.Index);
                    var weightGetter = data.Schema.Weight?.Index is int weightIdx ? RowCursorUtils.GetGetterAs <float>(NumberDataViewType.Single, cursor, weightIdx) : null;

                    for (int i = 0; i < fieldCount; i++)
                    {
                        featureGetters[i] = cursor.GetGetter <VBuffer <float> >(cursor.Schema[fieldColumnIndexes[i]]);
                    }
                    loss            = 0;
                    exampleCount    = 0;
                    badExampleCount = 0;
                    while (cursor.MoveNext())
                    {
                        float label         = 0;
                        float weight        = 1;
                        int   count         = 0;
                        float modelResponse = 0;
                        labelGetter(ref label);
                        weightGetter?.Invoke(ref weight);
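                        // NaN or infinity in either label or weight makes this sum non-finite,
                        // so the single finiteness check below rejects bad values from both getters.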
                        float annihilation = label - label + weight - weight;
                        if (!FloatUtils.IsFinite(annihilation))
                        {
                            badExampleCount++;
                            continue;
                        }
                        if (!FieldAwareFactorizationMachineUtils.LoadOneExampleIntoBuffer(featureGetters, featureBuffer, _norm, ref count,
                                                                                          featureFieldBuffer, featureIndexBuffer, featureValueBuffer))
                        {
                            badExampleCount++;
                            continue;
                        }

                        // refer to Algorithm 1 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
                        FieldAwareFactorizationMachineInterface.CalculateIntermediateVariables(fieldCount, _latentDimAligned, count,
                                                                                               featureFieldBuffer, featureIndexBuffer, featureValueBuffer, linearWeights, latentWeightsAligned, latentSum, ref modelResponse);
                        var slope = CalculateLossSlope(label, modelResponse);

                        // refer to Algorithm 2 in [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
                        FieldAwareFactorizationMachineInterface.CalculateGradientAndUpdate(_lambdaLinear, _lambdaLatent, _learningRate, fieldCount, _latentDimAligned, weight, count,
                                                                                           featureFieldBuffer, featureIndexBuffer, featureValueBuffer, latentSum, slope, linearWeights, latentWeightsAligned, linearAccSqGrads, latentAccSqGradsAligned);
                        loss += weight * CalculateLoss(label, modelResponse);
                        exampleCount++;
                    }
                    loss /= exampleCount;
                }

                if (_verbose)
                {
                    if (validData == null)
                    {
                        pch.Checkpoint(loss, iter, exampleCount);
                    }
                    else
                    {
                        validLoss = CalculateAvgLoss(ch, validData, _norm, linearWeights, latentWeightsAligned, _latentDimAligned, latentSum,
                                                     featureFieldBuffer, featureIndexBuffer, featureValueBuffer, featureBuffer, ref validBadExampleCount);
                        pch.Checkpoint(loss, validLoss, iter, exampleCount);
                    }
                }
            }
            if (badExampleCount != 0)
            {
                ch.Warning($"Skipped {badExampleCount} examples with bad label/weight/features in training set");
            }
            if (validBadExampleCount != 0)
            {
                ch.Warning($"Skipped {validBadExampleCount} examples with bad label/weight/features in validation set");
            }

            return(new FieldAwareFactorizationMachineModelParameters(_host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned));
        }
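
CalculateLoss and CalculateLossSlope are not shown in this listing. Assuming the usual field-aware factorization machine formulation, which minimizes logistic loss over labels mapped to +1/-1, a sketch of the pair could look like this (an assumption about the omitted helpers, not the verbatim trainer code):

using System;

internal static class FfmLossSketch
{
    // Logistic loss for a +1/-1 label: log(1 + exp(-y * z)).
    internal static double CalculateLoss(float label, float modelResponse)
    {
        double margin = (label > 0 ? 1.0 : -1.0) * modelResponse;
        return Math.Log(1.0 + Math.Exp(-margin));
    }

    // Derivative of the loss with respect to the model response: -y / (1 + exp(y * z)).
    internal static float CalculateLossSlope(float label, float modelResponse)
    {
        float sign = label > 0 ? 1f : -1f;
        double margin = sign * modelResponse;
        return (float)(-sign / (1.0 + Math.Exp(margin)));
    }

    private static void Main()
    {
        // A confident correct prediction has near-zero loss and slope.
        Console.WriteLine(CalculateLoss(1f, 5f));      // ~0.0067
        Console.WriteLine(CalculateLossSlope(1f, 5f)); // ~-0.0067
    }
}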
Exemple #18
        private TPredictor TrainCore(IChannel ch, RoleMappedData data)
        {
            Host.AssertValue(ch);
            ch.AssertValue(data);

            // 1. Subset Selection
            var stackingTrainer = Combiner as IStackingTrainer <TOutput>;

            //REVIEW: Implement stacking for Batch mode.
            ch.CheckUserArg(stackingTrainer == null || Args.BatchSize <= 0, nameof(Args.BatchSize), "Stacking works only with Non-batch mode");

            var validationDataSetProportion = SubModelSelector.ValidationDatasetProportion;

            if (stackingTrainer != null)
            {
                validationDataSetProportion = Math.Max(validationDataSetProportion, stackingTrainer.ValidationDatasetProportion);
            }

            var needMetrics = Args.ShowMetrics || Combiner is IWeightedAverager;
            var models      = new List <FeatureSubsetModel <IPredictorProducing <TOutput> > >();

            _subsetSelector.Initialize(data, NumModels, Args.BatchSize, validationDataSetProportion);
            int batchNumber = 1;

            foreach (var batch in _subsetSelector.GetBatches(Host.Rand))
            {
                // 2. Core train
                ch.Info("Training {0} learners for the batch {1}", Trainers.Length, batchNumber++);
                var batchModels = new FeatureSubsetModel <IPredictorProducing <TOutput> > [Trainers.Length];

                Parallel.ForEach(_subsetSelector.GetSubsets(batch, Host.Rand),
                                 new ParallelOptions()
                {
                    MaxDegreeOfParallelism = Args.TrainParallel ? -1 : 1
                },
                                 (subset, state, index) =>
                {
                    ch.Info("Beginning training model {0} of {1}", index + 1, Trainers.Length);
                    Stopwatch sw = Stopwatch.StartNew();
                    try
                    {
                        if (EnsureMinimumFeaturesSelected(subset))
                        {
                            var model = new FeatureSubsetModel <IPredictorProducing <TOutput> >(
                                Trainers[(int)index].Train(subset.Data),
                                subset.SelectedFeatures,
                                null);
                            SubModelSelector.CalculateMetrics(model, _subsetSelector, subset, batch, needMetrics);
                            batchModels[(int)index] = model;
                        }
                    }
                    catch (Exception ex)
                    {
                        ch.Assert(batchModels[(int)index] == null);
                        ch.Warning(ex.Sensitivity(), "Trainer {0} of {1} was not learned properly due to the exception '{2}' and will not be added to models.",
                                   index + 1, Trainers.Length, ex.Message);
                    }
                    ch.Info("Trainer {0} of {1} finished in {2}", index + 1, Trainers.Length, sw.Elapsed);
                });

                var modelsList = batchModels.Where(m => m != null).ToList();
                if (Args.ShowMetrics)
                {
                    PrintMetrics(ch, modelsList);
                }

                modelsList = SubModelSelector.Prune(modelsList).ToList();

                if (stackingTrainer != null)
                {
                    stackingTrainer.Train(modelsList, _subsetSelector.GetTestData(null, batch), Host);
                }

                models.AddRange(modelsList);
                int modelSize = Utils.Size(models);
                if (modelSize < Utils.Size(Trainers))
                {
                    ch.Warning("{0} of {1} trainings failed.", Utils.Size(Trainers) - modelSize, Utils.Size(Trainers));
                }
                ch.Check(modelSize > 0, "Ensemble training resulted in no valid models.");
            }
            return(CreatePredictor(models));
        }
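
The MaxDegreeOfParallelism = Args.TrainParallel ? -1 : 1 option above is what lets the same Parallel.ForEach body run either fully parallel or effectively sequentially. A small self-contained illustration of the idiom:

using System;
using System.Threading.Tasks;

internal static class ParallelToggleDemo
{
    private static void Main()
    {
        bool trainParallel = false; // flip to true to allow full parallelism
        var results = new int[8];

        // MaxDegreeOfParallelism = -1 means "no limit"; 1 forces sequential execution.
        Parallel.ForEach(
            new[] { 0, 1, 2, 3, 4, 5, 6, 7 },
            new ParallelOptions { MaxDegreeOfParallelism = trainParallel ? -1 : 1 },
            (item, state, index) => results[index] = item * item);

        Console.WriteLine(string.Join(",", results)); // 0,1,4,9,16,25,36,49
    }
}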
Exemple #19
        private void RunCore(IChannel ch)
        {
            Host.AssertValue(ch);

            ch.Trace("Creating loader");

            LoadModelObjects(ch, true, out var predictor, true, out var trainSchema, out var loader);
            ch.AssertValue(predictor);
            ch.AssertValueOrNull(trainSchema);
            ch.AssertValue(loader);

            ch.Trace("Creating pipeline");
            var scorer = Args.Scorer;

            ch.Assert(scorer == null || scorer is ICommandLineComponentFactory, "ScoreCommand should only be used from the command line.");
            var bindable = ScoreUtils.GetSchemaBindableMapper(Host, predictor, scorerFactorySettings: scorer as ICommandLineComponentFactory);

            ch.AssertValue(bindable);

            // REVIEW: We probably ought to prefer role mappings from the training schema.
            string feat = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
                                                              nameof(Args.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema,
                                                               nameof(Args.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
            var schema     = new RoleMappedSchema(loader.Schema, label: null, feature: feat, group: group, custom: customCols, opt: true);
            var mapper     = bindable.Bind(Host, schema);

            if (scorer == null)
            {
                scorer = ScoreUtils.GetScorerComponent(Host, mapper);
            }

            loader = CompositeDataLoader.ApplyTransform(Host, loader, "Scorer", scorer.ToString(),
                                                        (env, view) => scorer.CreateComponent(env, view, mapper, trainSchema));

            loader = CompositeDataLoader.Create(Host, loader, Args.PostTransform);

            if (!string.IsNullOrWhiteSpace(Args.OutputModelFile))
            {
                ch.Trace("Saving the data pipe");
                SaveLoader(loader, Args.OutputModelFile);
            }

            ch.Trace("Creating saver");
            IDataSaver writer;

            if (Args.Saver == null)
            {
                var ext    = Path.GetExtension(Args.OutputDataFile);
                var isText = ext == ".txt" || ext == ".tlc";
                if (isText)
                {
                    writer = new TextSaver(Host, new TextSaver.Arguments());
                }
                else
                {
                    writer = new BinarySaver(Host, new BinarySaver.Arguments());
                }
            }
            else
            {
                writer = Args.Saver.CreateComponent(Host);
            }
            ch.Assert(writer != null);
            var outputIsBinary = writer is BinarySaver;

            bool outputAllColumns =
                Args.OutputAllColumns == true ||
                (Args.OutputAllColumns == null && Utils.Size(Args.OutputColumn) == 0 && outputIsBinary);

            bool outputNamesAndLabels =
                Args.OutputAllColumns == true || Utils.Size(Args.OutputColumn) == 0;

            if (Args.OutputAllColumns == true && Utils.Size(Args.OutputColumn) != 0)
            {
                ch.Warning(nameof(Args.OutputAllColumns) + "=+ always writes all columns irrespective of " + nameof(Args.OutputColumn) + " specified.");
            }

            if (!outputAllColumns && Utils.Size(Args.OutputColumn) != 0)
            {
                foreach (var outCol in Args.OutputColumn)
                {
                    if (!loader.Schema.TryGetColumnIndex(outCol, out int dummyColIndex))
                    {
                        throw ch.ExceptUserArg(nameof(Arguments.OutputColumn), "Column '{0}' not found.", outCol);
                    }
                }
            }

            uint maxScoreId = 0;

            if (!outputAllColumns)
            {
                maxScoreId = loader.Schema.GetMaxMetadataKind(out int colMax, MetadataUtils.Kinds.ScoreColumnSetId);
            }
            ch.Assert(outputAllColumns || maxScoreId > 0); // score set IDs are one-based
            var cols = new List <int>();

            for (int i = 0; i < loader.Schema.Count; i++)
            {
                if (!Args.KeepHidden && loader.Schema.IsHidden(i))
                {
                    continue;
                }
                if (!(outputAllColumns || ShouldAddColumn(loader.Schema, i, maxScoreId, outputNamesAndLabels)))
                {
                    continue;
                }
                var type = loader.Schema.GetColumnType(i);
                if (writer.IsColumnSavable(type))
                {
                    cols.Add(i);
                }
                else
                {
                    ch.Warning("The column '{0}' will not be written as it has unsavable column type.",
                               loader.Schema.GetColumnName(i));
                }
            }

            ch.Check(cols.Count > 0, "No valid columns to save");

            ch.Trace("Scoring and saving data");
            using (var file = Host.CreateOutputFile(Args.OutputDataFile))
                using (var stream = file.CreateWriteStream())
                    writer.SaveData(stream, loader, cols.ToArray());
        }
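
The saver selection above keys purely off the output file extension: .txt and .tlc get the text saver, everything else the binary one. Reduced to a standalone sketch (returning names rather than constructing the ML.NET savers):

using System;
using System.IO;

internal static class SaverChoiceDemo
{
    // Mirrors the extension test in RunCore: text for .txt/.tlc, binary otherwise.
    private static string ChooseSaver(string outputDataFile)
    {
        var ext = Path.GetExtension(outputDataFile);
        var isText = ext == ".txt" || ext == ".tlc";
        return isText ? "TextSaver" : "BinarySaver";
    }

    private static void Main()
    {
        Console.WriteLine(ChooseSaver("scores.txt")); // TextSaver
        Console.WriteLine(ChooseSaver("scores.idv")); // BinarySaver
    }
}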
Exemple #20
            /// <summary>
            /// An implementation of the line search for the Wolfe conditions, from Nocedal &amp; Wright
            /// </summary>
            internal virtual bool LineSearch(IChannel ch, bool force)
            {
                Contracts.AssertValue(ch);
                Float dirDeriv = VectorUtils.DotProduct(ref _dir, ref _grad);

                if (dirDeriv == 0)
                {
                    throw ch.Process(new PrematureConvergenceException(this, "Directional derivative is zero. You may be sitting on the optimum."));
                }

                // If a non-descent direction is chosen, the line search will break anyway, so throw here.
                // The most likely reason for this is a bug in the function's gradient computation.
                ch.Check(dirDeriv < 0, "L-BFGS chose a non-descent direction.");

                Float c1 = (Float)1e-4 * dirDeriv;
                Float c2 = (Float)0.9 * dirDeriv;

                Float alpha = (Iter == 1 ? (1 / VectorUtils.Norm(_dir)) : 1);

                PointValueDeriv last = new PointValueDeriv(0, LastValue, dirDeriv);
                PointValueDeriv aLo  = new PointValueDeriv();
                PointValueDeriv aHi  = new PointValueDeriv();

                // initial bracketing phase
                while (true)
                {
                    VectorUtils.AddMultInto(ref _x, alpha, ref _dir, ref _newX);
                    if (EnforceNonNegativity)
                    {
                        VBufferUtils.Apply(ref _newX, delegate(int ind, ref Float newXval)
                        {
                            if (newXval < 0.0)
                            {
                                newXval = 0;
                            }
                        });
                    }

                    Value = Eval(ref _newX, ref _newGrad);
                    GradientCalculations++;
                    if (Float.IsPositiveInfinity(Value))
                    {
                        alpha /= 2;
                        continue;
                    }

                    if (!FloatUtils.IsFinite(Value))
                    {
                        throw ch.Except("Optimizer unable to proceed with loss function yielding {0}", Value);
                    }

                    dirDeriv = VectorUtils.DotProduct(ref _dir, ref _newGrad);
                    PointValueDeriv curr = new PointValueDeriv(alpha, Value, dirDeriv);

                    if ((curr.V > LastValue + c1 * alpha) || (last.A > 0 && curr.V >= last.V))
                    {
                        aLo = last;
                        aHi = curr;
                        break;
                    }
                    else if (Math.Abs(curr.D) <= -c2)
                    {
                        return(true);
                    }
                    else if (curr.D >= 0)
                    {
                        aLo = curr;
                        aHi = last;
                        break;
                    }

                    last = curr;
                    if (alpha == 0)
                    {
                        alpha = Float.Epsilon; // Robust to divisional underflow.
                    }
                    else
                    {
                        alpha *= 2;
                    }
                }

                Float minChange = (Float)0.01;
                int   maxSteps  = 10;

                // this loop is the "zoom" procedure described in Nocedal & Wright
                for (int step = 0; ; ++step)
                {
                    if (step == maxSteps && !force)
                    {
                        return(false);
                    }

                    PointValueDeriv left  = aLo.A < aHi.A ? aLo : aHi;
                    PointValueDeriv right = aLo.A < aHi.A ? aHi : aLo;
                    if (left.D > 0 && right.D < 0)
                    {
                        // interpolating cubic would have max in range, not min (can this happen?)
                        // set a to the one with smaller value
                        alpha = aLo.V < aHi.V ? aLo.A : aHi.A;
                    }
                    else
                    {
                        alpha = CubicInterp(aLo, aHi);
                        if (Float.IsNaN(alpha) || Float.IsInfinity(alpha))
                        {
                            alpha = (aLo.A + aHi.A) / 2;
                        }
                    }

                    // this is to ensure that the new point is within bounds
                    // and that the change is reasonably sized
                    Float ub = (minChange * left.A + (1 - minChange) * right.A);
                    if (alpha > ub)
                    {
                        alpha = ub;
                    }
                    Float lb = (minChange * right.A + (1 - minChange) * left.A);
                    if (alpha < lb)
                    {
                        alpha = lb;
                    }

                    VectorUtils.AddMultInto(ref _x, alpha, ref _dir, ref _newX);
                    if (EnforceNonNegativity)
                    {
                        VBufferUtils.Apply(ref _newX, delegate(int ind, ref Float newXval)
                        {
                            if (newXval < 0.0)
                            {
                                newXval = 0;
                            }
                        });
                    }

                    Value = Eval(ref _newX, ref _newGrad);
                    GradientCalculations++;
                    if (!FloatUtils.IsFinite(Value))
                    {
                        throw ch.Except("Optimizer unable to proceed with loss function yielding {0}", Value);
                    }
                    dirDeriv = VectorUtils.DotProduct(ref _dir, ref _newGrad);

                    PointValueDeriv curr = new PointValueDeriv(alpha, Value, dirDeriv);

                    if ((curr.V > LastValue + c1 * alpha) || (curr.V >= aLo.V))
                    {
                        if (aHi.A == curr.A)
                        {
                            if (force)
                            {
                                throw ch.Process(new PrematureConvergenceException(this, "Step size interval numerically zero."));
                            }
                            else
                            {
                                return(false);
                            }
                        }
                        aHi = curr;
                    }
                    else if (Math.Abs(curr.D) <= -c2)
                    {
                        return(true);
                    }
                    else
                    {
                        if (curr.D * (aHi.A - aLo.A) >= 0)
                        {
                            aHi = aLo;
                        }
                        if (aLo.A == curr.A)
                        {
                            if (force)
                            {
                                throw ch.Process(new PrematureConvergenceException(this, "Step size interval numerically zero."));
                            }
                            else
                            {
                                return(false);
                            }
                        }
                        aLo = curr;
                    }
                }
            }
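
In the loop above, c1 and c2 are pre-multiplied by dirDeriv, so the sufficient-decrease test appears as curr.V > LastValue + c1 * alpha and the curvature test as Math.Abs(curr.D) <= -c2. In the more familiar textbook form of the Wolfe conditions from Nocedal & Wright, the two tests read as follows (standalone sketch for illustration):

using System;

internal static class WolfeDemo
{
    // Sufficient decrease (Armijo): f(x + a*d) <= f(x) + c1 * a * g0,
    // where g0 = grad f(x) . d < 0 for a descent direction.
    private static bool SufficientDecrease(double fNew, double f0, double g0, double a, double c1 = 1e-4)
        => fNew <= f0 + c1 * a * g0;

    // Strong curvature: |grad f(x + a*d) . d| <= c2 * |g0|.
    private static bool Curvature(double gNew, double g0, double c2 = 0.9)
        => Math.Abs(gNew) <= c2 * Math.Abs(g0);

    private static void Main()
    {
        // f(x) = (x - 2)^2 searched from x = 0 along d = +1, so g0 = -4.
        Func<double, double> f = x => (x - 2) * (x - 2);
        Func<double, double> g = x => 2 * (x - 2);
        double a = 1.5;
        Console.WriteLine(SufficientDecrease(f(a), f(0), g(0), a)); // True: 0.25 <= ~4
        Console.WriteLine(Curvature(g(a), g(0)));                   // True: |-1| <= 0.9 * 4
    }
}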
        // The multi-output regression evaluator prints only the per-label metrics for each fold.
        protected override void PrintFoldResultsCore(IChannel ch, Dictionary <string, IDataView> metrics)
        {
            IDataView fold;

            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out fold))
            {
                throw ch.Except("No overall metrics found");
            }

            int  isWeightedCol;
            bool needWeighted = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out isWeightedCol);

            int  stratCol;
            bool hasStrats = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out stratCol);
            int  stratVal;
            bool hasStratVals = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal);

            ch.Assert(hasStrats == hasStratVals);

            var colCount       = fold.Schema.ColumnCount;
            var vBufferGetters = new ValueGetter <VBuffer <double> > [colCount];

            using (var cursor = fold.GetRowCursor(col => true))
            {
                DvBool isWeighted = DvBool.False;
                ValueGetter <DvBool> isWeightedGetter;
                if (needWeighted)
                {
                    isWeightedGetter = cursor.GetGetter <DvBool>(isWeightedCol);
                }
                else
                {
                    isWeightedGetter = (ref DvBool dst) => dst = DvBool.False;
                }

                ValueGetter <uint> stratGetter;
                if (hasStrats)
                {
                    var type = cursor.Schema.GetColumnType(stratCol);
                    stratGetter = RowCursorUtils.GetGetterAs <uint>(type, cursor, stratCol);
                }
                else
                {
                    stratGetter = (ref uint dst) => dst = 0;
                }

                int labelCount = 0;
                for (int i = 0; i < fold.Schema.ColumnCount; i++)
                {
                    if (fold.Schema.IsHidden(i) || (needWeighted && i == isWeightedCol) ||
                        (hasStrats && (i == stratCol || i == stratVal)))
                    {
                        continue;
                    }

                    var type = fold.Schema.GetColumnType(i);
                    if (type.IsKnownSizeVector && type.ItemType == NumberType.R8)
                    {
                        vBufferGetters[i] = cursor.GetGetter <VBuffer <double> >(i);
                        if (labelCount == 0)
                        {
                            labelCount = type.VectorSize;
                        }
                        else
                        {
                            ch.Check(labelCount == type.VectorSize, "All vector metrics should contain the same number of slots");
                        }
                    }
                }
                var labelNames = new DvText[labelCount];
                for (int j = 0; j < labelCount; j++)
                {
                    labelNames[j] = new DvText(string.Format("Label_{0}", j));
                }

                var sb = new StringBuilder();
                sb.AppendLine("Per-label metrics:");
                sb.AppendFormat("{0,12} ", " ");
                for (int i = 0; i < labelCount; i++)
                {
                    sb.AppendFormat(" {0,20}", labelNames[i]);
                }
                sb.AppendLine();

                VBuffer <Double> metricVals      = default(VBuffer <Double>);
                bool             foundWeighted   = !needWeighted;
                bool             foundUnweighted = false;
                uint             strat           = 0;
                while (cursor.MoveNext())
                {
                    isWeightedGetter(ref isWeighted);
                    if (foundWeighted && isWeighted.IsTrue || foundUnweighted && isWeighted.IsFalse)
                    {
                        throw ch.Except("Multiple {0} rows found in overall metrics data view",
                                        isWeighted.IsTrue ? "weighted" : "unweighted");
                    }
                    if (isWeighted.IsTrue)
                    {
                        foundWeighted = true;
                    }
                    else
                    {
                        foundUnweighted = true;
                    }

                    stratGetter(ref strat);
                    if (strat > 0)
                    {
                        continue;
                    }

                    for (int i = 0; i < colCount; i++)
                    {
                        if (vBufferGetters[i] != null)
                        {
                            vBufferGetters[i](ref metricVals);
                            ch.Assert(metricVals.Length == labelCount);

                            sb.AppendFormat("{0}{1,12}:", isWeighted.IsTrue ? "Weighted " : "", fold.Schema.GetColumnName(i));
                            foreach (var metric in metricVals.Items(all: true))
                            {
                                sb.AppendFormat(" {0,20:G20}", metric.Value);
                            }
                            sb.AppendLine();
                        }
                    }
                    if (foundUnweighted && foundWeighted)
                    {
                        break;
                    }
                }
                ch.Assert(foundUnweighted && foundWeighted);
                ch.Info(sb.ToString());
            }
        }
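
The composite format strings above ({0,12}, {0,20:G20}) do the column alignment for the per-label table. A minimal standalone demonstration of the same StringBuilder layout, with made-up metric values:

using System;
using System.Text;

internal static class MetricTableDemo
{
    private static void Main()
    {
        var labelNames = new[] { "Label_0", "Label_1" };
        var metricVals = new[] { 0.125, 0.250 };

        var sb = new StringBuilder();
        sb.AppendLine("Per-label metrics:");
        sb.AppendFormat("{0,12} ", " ");
        foreach (var name in labelNames)
            sb.AppendFormat(" {0,20}", name);  // right-align headers in 20 chars
        sb.AppendLine();

        sb.AppendFormat("{0,12}:", "L1(avg)"); // metric name in a 12-char column
        foreach (var v in metricVals)
            sb.AppendFormat(" {0,20:G20}", v); // high-precision, right-aligned values
        sb.AppendLine();

        Console.Write(sb.ToString());
    }
}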
        protected ServerChannel.IServer InitServer(IChannel ch)
        {
            Host.CheckValue(ch, nameof(ch));
            ch.Check(Host != null, nameof(InitServer) + " called prematurely");
            return(_serverFactory?.CreateComponent(Host, ch));
        }
        private DMatrix FillDenseMatrix(IChannel ch, int nbDim, long nbRows,
                                        RoleMappedData data, out Float[] labels, out uint[] groupCount)
        {
            // Allocation.
            string errorMessageGroup = string.Format("Group is above {0}.", uint.MaxValue);

            if (nbDim * nbRows >= Utils.ArrayMaxSize)
            {
                throw _host.Except("The training dataset is too big to hold in memory. " +
                                   "Number of features ({0}) multiplied by the number of rows ({1}) must be less than {2}.", nbDim, nbRows, Utils.ArrayMaxSize);
            }
            var features = new Float[nbDim * nbRows];

            labels = new Float[nbRows];
            var hasWeights = data.Schema.Weight != null;
            var hasGroup   = data.Schema.Group != null;
            var weights    = hasWeights ? new Float[nbRows] : null;
            var groupsML   = hasGroup ? new uint[nbRows] : null;

            groupCount = hasGroup ? new uint[nbRows] : null;
            var groupId = hasGroup ? new HashSet <uint>() : null;

            int count     = 0;
            int lastGroup = -1;
            int fcount    = 0;
            var flags     = CursOpt.Features | CursOpt.Label | CursOpt.AllowBadEverything | CursOpt.Weight | CursOpt.Group;

            var featureVector = default(VBuffer <float>);
            var labelProxy    = float.NaN;
            var groupProxy    = ulong.MaxValue;

            using (var cursor = data.CreateRowCursor(flags, null))
            {
                var featureGetter = cursor.GetFeatureFloatVectorGetter(data);
                var labelGetter   = cursor.GetLabelFloatGetter(data);
                var weighGetter   = cursor.GetOptWeightFloatGetter(data);
                var groupGetter   = cursor.GetOptGroupGetter(data);

                while (cursor.MoveNext())
                {
                    featureGetter(ref featureVector);
                    labelGetter(ref labelProxy);

                    labels[count] = labelProxy;
                    if (Single.IsNaN(labels[count]))
                    {
                        continue;
                    }

                    featureVector.CopyTo(features, fcount, Single.NaN);
                    fcount += featureVector.Count;

                    if (hasWeights)
                    {
                        weighGetter(ref weights[count]);
                    }
                    if (hasGroup)
                    {
                        groupGetter(ref groupProxy);
                        _host.Check(groupProxy < uint.MaxValue, errorMessageGroup);
                        groupsML[count] = (uint)groupProxy;
                        if (count == 0 || groupsML[count - 1] != groupsML[count])
                        {
                            groupCount[++lastGroup] = 1;
                            ch.Check(!groupId.Contains(groupsML[count]), "Group Ids are not contiguous.");
                            groupId.Add(groupsML[count]);
                        }
                        else
                        {
                            ++groupCount[lastGroup];
                        }
                    }
                    ++count;
                }
            }

            PostProcessLabelsBeforeCreatingXGBoostContainer(ch, data, labels);

            // We create a DMatrix.
            DMatrix dtrain = new DMatrix(features, (uint)count, (uint)nbDim, labels: labels, weights: weights, groups: groupCount);

            return(dtrain);
        }
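
Both matrix fillers convert the per-row group ids into per-group row counts by watching for boundaries between consecutive ids, and use a HashSet to reject group ids that reappear after a gap. Isolated from the cursoring, the pattern is:

using System;
using System.Collections.Generic;

internal static class GroupCountDemo
{
    // Counts consecutive runs of equal group ids; ids must be contiguous,
    // i.e. a group id may not reappear after a different one (checked below).
    private static uint[] CountGroups(uint[] groupIds)
    {
        var counts = new List<uint>();
        var seen = new HashSet<uint>();
        for (int i = 0; i < groupIds.Length; i++)
        {
            if (i == 0 || groupIds[i] != groupIds[i - 1])
            {
                if (!seen.Add(groupIds[i]))
                    throw new InvalidOperationException("Group Ids are not contiguous.");
                counts.Add(1);
            }
            else
                counts[counts.Count - 1]++;
        }
        return counts.ToArray();
    }

    private static void Main()
    {
        Console.WriteLine(string.Join(",", CountGroups(new uint[] { 7, 7, 7, 3, 3, 9 }))); // 3,2,1
    }
}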
        /// <summary>
        /// Fill a sparse DMatrix using CSR compression.
        /// See http://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html.
        /// </summary>
        private DMatrix FillSparseMatrix(IChannel ch, int nbDim, long nbRows, RoleMappedData data,
                                         out Float[] labels, out uint[] groupCount)
        {
            // Allocation.
            if ((2 * nbRows) >= Utils.ArrayMaxSize)
            {
                throw _host.Except("The training dataset is too big to hold in memory. " +
                                   "2 features multiplied by the number of rows must be less than {0}.", Utils.ArrayMaxSize);
            }

            var  features = new Float[nbRows * 2];
            var  indices  = new uint[features.Length];
            var  indptr   = new ulong[nbRows + 1];
            long nelem    = 0;

            labels = new Float[nbRows];
            var hasWeights = data.Schema.Weight != null;
            var hasGroup   = data.Schema.Group != null;
            var weights    = hasWeights ? new Float[nbRows] : null;
            var groupsML   = hasGroup ? new uint[nbRows] : null;

            groupCount = hasGroup ? new uint[nbRows] : null;
            var groupId = hasGroup ? new HashSet <uint>() : null;

            int count     = 0;
            int lastGroup = -1;
            var flags     = CursOpt.Features | CursOpt.Label | CursOpt.AllowBadEverything | CursOpt.Weight | CursOpt.Group;

            var featureVector = default(VBuffer <float>);
            var labelProxy    = float.NaN;
            var groupProxy    = ulong.MaxValue;

            using (var cursor = data.CreateRowCursor(flags, null))
            {
                var featureGetter = cursor.GetFeatureFloatVectorGetter(data);
                var labelGetter   = cursor.GetLabelFloatGetter(data);
                var weighGetter   = cursor.GetOptWeightFloatGetter(data);
                var groupGetter   = cursor.GetOptGroupGetter(data);
                while (cursor.MoveNext())
                {
                    featureGetter(ref featureVector);
                    labelGetter(ref labelProxy);
                    labels[count] = labelProxy;
                    if (Single.IsNaN(labels[count]))
                    {
                        continue;
                    }

                    indptr[count] = (ulong)nelem;
                    int nbValues = featureVector.Count;
                    if (nbValues > 0)
                    {
                        if (nelem + nbValues > features.Length)
                        {
                            long newSize = Math.Max(nelem + nbValues, features.Length * 2);
                            if (newSize >= Utils.ArrayMaxSize)
                            {
                                throw _host.Except("The training dataset is too big to hold in memory. " +
                                                   "It should be half of {0}.", Utils.ArrayMaxSize);
                            }
                            Array.Resize(ref features, (int)newSize);
                            Array.Resize(ref indices, (int)newSize);
                        }

                        Array.Copy(featureVector.Values, 0, features, nelem, nbValues);
                        if (featureVector.IsDense)
                        {
                            for (int i = 0; i < nbValues; ++i)
                            {
                                indices[nelem++] = (uint)i;
                            }
                        }
                        else
                        {
                            for (int i = 0; i < nbValues; ++i)
                            {
                                indices[nelem++] = (uint)featureVector.Indices[i];
                            }
                        }
                    }

                    if (hasWeights)
                    {
                        weighGetter(ref weights[count]);
                    }
                    if (hasGroup)
                    {
                        groupGetter(ref groupProxy);
                        if (groupProxy >= uint.MaxValue)
                        {
                            throw _host.Except($"Group is above {uint.MaxValue}");
                        }
                        groupsML[count] = (uint)groupProxy;
                        if (count == 0 || groupsML[count - 1] != groupsML[count])
                        {
                            groupCount[++lastGroup] = 1;
                            ch.Check(!groupId.Contains(groupsML[count]), "Group Ids are not contiguous.");
                            groupId.Add(groupsML[count]);
                        }
                        else
                        {
                            ++groupCount[lastGroup];
                        }
                    }
                    ++count;
                }
            }
            indptr[count] = (ulong)nelem;

            if (nelem < features.Length * 3 / 4)
            {
                Array.Resize(ref features, (int)nelem);
                Array.Resize(ref indices, (int)nelem);
            }

            PostProcessLabelsBeforeCreatingXGBoostContainer(ch, data, labels);

            // We create a DMatrix.
            DMatrix dtrain = new DMatrix((uint)nbDim, indptr, indices, features, (uint)count, (uint)nelem, labels: labels, weights: weights, groups: groupCount);

            return(dtrain);
        }
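
As the summary comment notes, the sparse path builds a CSR matrix: one value and one column index per non-zero, plus a row-pointer array with one entry per row and a final entry holding the total non-zero count. A self-contained sketch of that layout:

using System;
using System.Collections.Generic;

internal static class CsrDemo
{
    // Builds the three CSR buffers used above: values, column indices, and
    // a row-pointer array (indptr) with rowCount + 1 entries.
    private static void Main()
    {
        var rows = new (uint Col, float Val)[][]
        {
            new[] { (0u, 1f), (2u, 3f) }, // row 0: two non-zeros
            new (uint, float)[0],         // row 1: empty
            new[] { (1u, 5f) },           // row 2: one non-zero
        };

        var values  = new List<float>();
        var indices = new List<uint>();
        var indptr  = new ulong[rows.Length + 1];

        for (int r = 0; r < rows.Length; r++)
        {
            indptr[r] = (ulong)values.Count; // where row r starts
            foreach (var (col, val) in rows[r])
            {
                indices.Add(col);
                values.Add(val);
            }
        }
        indptr[rows.Length] = (ulong)values.Count; // total non-zero count

        Console.WriteLine(string.Join(",", indptr));  // 0,2,2,3
        Console.WriteLine(string.Join(",", indices)); // 0,2,1
        Console.WriteLine(string.Join(",", values));  // 1,3,5
    }
}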
        protected override void PostProcessLabelsBeforeCreatingXGBoostContainer(IChannel ch, RoleMappedData data, Float[] labels)
        {
            Contracts.Assert(PredictionKind == PredictionKind.MultiClassClassification);

            int[] classMapping;

            // This builds the mapping from XGBoost classes to Microsoft.ML classes.
            // XGBoost classes must start at 0. The mapping removes empty classes as XGBoost
            // multiplies the number of tree by the number of classes, this reduces the complexity.
            classMapping = labels.Select(c => (int)c).Distinct().OrderBy(c => c).ToArray();
            ch.Check(classMapping[0] >= 0, "Negative labels are not allowed.");
            var map = classMapping.Select((c, i) => new { c = c, i = i })
                      .ToDictionary(item => item.c, item => item.i);

            for (int i = 0; i < labels.Length; ++i)
            {
                ch.Assert(!Single.IsNaN(labels[i]));
                labels[i] = (float)map[(int)labels[i]];
            }

            _nbClass = classMapping.Length;

            // The classMapping is used by the prediction when the label is R4,
            // or when the label is a Key with a different range than the one XGBoost uses.
            if (data.Schema.Label.Type.IsKey)
            {
                _isFloatLabel = false;
                var labelType = data.Schema.Label.Type.AsKey;
                if (labelType.Count == classMapping.Length)
                {
                    classMapping = null;
                }
                else
                {
                    // There are fewer classes in the training database than the label type provides.
                    // We keep the mapping to compute the final prediction
                    // returned as a sparse vector.
                    ulong max = (ulong)labelType.Count;
                    if (max >= int.MaxValue)
                    {
                        throw ch.Except("Labels must be < {0}.", int.MaxValue);
                    }
                    _nbClass = labelType.Count;
                    // Mapping starts at zero.
                    var mini = classMapping.Min();
                    for (int i = 0; i < classMapping.Length; ++i)
                    {
                        classMapping[i] -= mini;
                    }
                }
            }
            else if (data.Schema.Label.Type == NumberType.R4)
            {
                _isFloatLabel = true;
            }
            else
            {
                throw ch.ExceptParam(nameof(data), "Label type must be a key or a float R4.");
            }

            _classMapping = classMapping;
        }
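
The remapping above compacts whatever label values occur in training into the contiguous 0..k-1 range XGBoost expects, dropping empty classes. Its core, extracted into a runnable form:

using System;
using System.Collections.Generic;
using System.Linq;

internal static class ClassMappingDemo
{
    private static void Main()
    {
        // Labels 0, 2 and 5 occur; classes 1, 3 and 4 are empty.
        var labels = new float[] { 5f, 0f, 2f, 2f, 5f };

        // Distinct sorted classes, then class -> dense index.
        int[] classMapping = labels.Select(c => (int)c).Distinct().OrderBy(c => c).ToArray();
        Dictionary<int, int> map = classMapping.Select((c, i) => (c, i))
                                               .ToDictionary(p => p.c, p => p.i);

        for (int i = 0; i < labels.Length; i++)
            labels[i] = map[(int)labels[i]]; // 5 -> 2, 0 -> 0, 2 -> 1

        Console.WriteLine(string.Join(",", labels)); // 2,0,1,1,2
        Console.WriteLine(classMapping.Length);      // 3 classes seen by XGBoost
    }
}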
        private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, TPredictor predictor)
        {
            // Verifications.
            _host.AssertValue(ch);
            ch.CheckValue(data, nameof(data));

            ValidateTrainInput(ch, data);

            var featureColumns = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);

            ch.Check(featureColumns.Count == 1, "Only one vector of features is allowed.");

            // Data dimension.
            int fi      = data.Schema.Feature.Index;
            var colType = data.Schema.Schema.GetColumnType(fi);

            ch.Assert(colType.IsVector, "Feature must be a vector.");
            ch.Assert(colType.VectorSize > 0, "Feature dimension must be known.");
            int       nbDim  = colType.VectorSize;
            IDataView view   = data.Data;
            long      nbRows = DataViewUtils.ComputeRowCount(view);

            Float[] labels;
            uint[]  groupCount;
            DMatrix dtrain;
            // REVIEW xadupre: this can be avoided by using method XGDMatrixCreateFromDataIter from the XGBoost API.
            // XGBoost removes NaN values from a dense matrix and stores it in sparse format anyway.
            bool isDense = DetectDensity(data);
            var  dt      = DateTime.Now;

            if (isDense)
            {
                dtrain = FillDenseMatrix(ch, nbDim, nbRows, data, out labels, out groupCount);
                ch.Info("Dense matrix created with nbFeatures={0} and nbRows={1} in {2}.", nbDim, nbRows, DateTime.Now - dt);
            }
            else
            {
                dtrain = FillSparseMatrix(ch, nbDim, nbRows, data, out labels, out groupCount);
                ch.Info("Sparse matrix created with nbFeatures={0} and nbRows={1} in {2}.", nbDim, nbRows, DateTime.Now - dt);
            }

            // Some options are filled based on the data.
            var options = _args.ToDict(_host);

            UpdateXGBoostOptions(ch, options, labels, groupCount);

            // For multi class, the number of labels is required.
            ch.Assert(PredictionKind != PredictionKind.MultiClassClassification || options.ContainsKey("num_class"),
                      "XGBoost requires the number of classes to be specified in the parameters.");

            ch.Info("XGBoost objective={0}", options["objective"]);

            int     numTrees;
            Booster res = WrappedXGBoostTraining.Train(ch, pch, out numTrees, options, dtrain,
                                                       numBoostRound: _args.numBoostRound,
                                                       obj: null, verboseEval: _args.verboseEval,
                                                       xgbModel: predictor == null ? null : predictor.GetBooster(),
                                                       saveBinaryDMatrix: _args.saveXGBoostDMatrixAsBinary);

            int nbTrees = res.GetNumTrees();

            ch.Info("Training is complete. Number of added trees={0}, total={1}.", numTrees, nbTrees);

            _model             = res.SaveRaw();
            _nbFeaturesXGboost = (int)dtrain.GetNumCols();
            _nbFeaturesML      = nbDim;
        }
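
DetectDensity is referenced but not included in this listing. Purely as a sketch of one plausible heuristic (not the original implementation), the choice between the dense and sparse fillers could hinge on the fraction of non-zero entries:

using System;

internal static class DensitySketch
{
    // Hypothetical heuristic (not the original DetectDensity): treat the data
    // as dense when at least half of all sampled entries are non-zero.
    private static bool LooksDense(int[] nonZeroCountPerRow, int featureCount, double threshold = 0.5)
    {
        long nnz = 0;
        foreach (var c in nonZeroCountPerRow)
            nnz += c;
        long total = (long)nonZeroCountPerRow.Length * featureCount;
        return total > 0 && nnz >= threshold * total;
    }

    private static void Main()
    {
        Console.WriteLine(LooksDense(new[] { 9, 10, 8 }, 10)); // True: 27/30 non-zero
        Console.WriteLine(LooksDense(new[] { 1, 0, 2 }, 10));  // False: 3/30 non-zero
    }
}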