/// <summary>
        /// Streams every row produced by <paramref name="factory"/> into <paramref name="dataset"/>,
        /// pushing the rows in batches to reduce peak memory cost.
        /// </summary>
        /// <remarks>
        /// The value returned by <c>DetectDensity</c> selects the buffer layout: at 0.5 or above rows are
        /// copied into one flat value buffer via <c>CopyToArray</c>; below that they are copied into
        /// indptr / indices / values buffers via <c>CopyToCsr</c>.
        /// </remarks>
        /// <param name="ch">Channel used for assertions and for raising the array-size error.</param>
        /// <param name="factory">Creates the cursor whose rows are loaded.</param>
        /// <param name="dataset">Destination dataset; asserted to have <paramref name="numRow"/> rows and <c>catMetaData.NumCol</c> columns.</param>
        /// <param name="numRow">Number of rows the cursor is expected to yield.</param>
        /// <param name="batchSize">Requested number of buffered feature values per batch; raised to at least <c>catMetaData.NumCol</c>.</param>
        /// <param name="catMetaData">Categorical metadata; supplies the column count and is forwarded to the copy helpers.</param>
        private void LoadDataset(IChannel ch, FloatLabelCursor.Factory factory, Dataset dataset, int numRow, int batchSize, CategoricalMetaData catMetaData)
        {
            Host.AssertValue(ch);
            ch.AssertValue(factory);
            ch.AssertValue(dataset);
            ch.Assert(dataset.GetNumRows() == numRow);
            ch.Assert(dataset.GetNumCols() == catMetaData.NumCol);
            var rand = Host.Rand;

            // To avoid array resize, the batch size should be at least the size of one row.
            batchSize = Math.Max(batchSize, catMetaData.NumCol);
            double density       = DetectDensity(factory);
            int    numElem       = 0; // Values buffered for the current batch.
            int    totalRowCount = 0; // Rows read from the cursor so far.
            int    curRowCount   = 0; // Rows buffered for the current batch.

            if (density >= 0.5)
            {
                int batchRow = Math.Max(1, batchSize / catMetaData.NumCol);
                if (batchRow > numRow)
                {
                    batchRow = numRow;
                }

                // This can only happen if the size of ONE example (row) exceeds the max array size. This looks like a very unlikely case.
                if ((long)catMetaData.NumCol * batchRow > Utils.ArrayMaxSize)
                {
                    throw ch.Except("Size of array exceeded the " + nameof(Utils.ArrayMaxSize));
                }

                float[] features = new float[catMetaData.NumCol * batchRow];

                // Pushes the buffered rows to the dataset and resets the per-batch counters.
                void PushDenseBatch()
                {
                    ch.Assert(numElem == curRowCount * catMetaData.NumCol);
                    // PushRows is run by multi-threading inside, so lock here.
                    lock (LightGbmShared.LockForMultiThreadingInside)
                    {
                        dataset.PushRows(features, curRowCount, catMetaData.NumCol, totalRowCount - curRowCount);
                    }
                    curRowCount = 0;
                    numElem     = 0;
                }

                using (var cursor = factory.Create())
                {
                    while (cursor.MoveNext())
                    {
                        ch.Assert(totalRowCount < numRow);
                        CopyToArray(ch, cursor, features, catMetaData, rand, ref numElem);
                        ++totalRowCount;
                        ++curRowCount;
                        if (batchRow == curRowCount)
                        {
                            PushDenseBatch();
                        }
                    }
                    ch.Assert(totalRowCount == numRow);
                    // Flush the trailing, partially filled batch.
                    if (curRowCount > 0)
                    {
                        PushDenseBatch();
                    }
                }
            }
            else
            {
                // Initial capacity guess for indptr; the array is grown on demand by EnsureSize below.
                // The quotient is Infinity/NaN when density is 0 and can exceed the int range when density
                // is tiny, and casting such a double to int is unspecified. Compare in double and clamp to
                // [1, numRow] so the allocation below can neither overflow nor be needlessly large.
                double estimate = batchSize / (catMetaData.NumCol * density);
                int estimatedBatchRow = estimate < numRow ? (int)estimate : numRow;
                estimatedBatchRow = Math.Max(1, estimatedBatchRow);

                float[] features = new float[batchSize];
                int[]   indices  = new int[batchSize];
                int[]   indptr   = new int[estimatedBatchRow + 1];

                // Writes the closing indptr entry, pushes the buffered rows and resets the per-batch counters.
                void PushSparseBatch()
                {
                    Utils.EnsureSize(ref indptr, curRowCount + 1);
                    indptr[curRowCount] = numElem;
                    // PushRows is run by multi-threading inside, so lock here.
                    lock (LightGbmShared.LockForMultiThreadingInside)
                    {
                        dataset.PushRows(indptr, indices, features,
                                         curRowCount + 1, numElem, catMetaData.NumCol, totalRowCount - curRowCount);
                    }
                    curRowCount = 0;
                    numElem     = 0;
                }

                using (var cursor = factory.Create())
                {
                    while (cursor.MoveNext())
                    {
                        ch.Assert(totalRowCount < numRow);
                        // The next row would not fit into the value buffer: push what is buffered first.
                        if (numElem + cursor.Features.Count > features.Length)
                        {
                            // Mini batch size is greater than size of one row.
                            // So, at least we have the data of one row.
                            ch.Assert(curRowCount > 0);
                            PushSparseBatch();
                        }
                        // Record where this row's values start before copying them.
                        Utils.EnsureSize(ref indptr, curRowCount + 1);
                        indptr[curRowCount] = numElem;
                        CopyToCsr(ch, cursor, indices, features, catMetaData, rand, ref numElem);
                        ++totalRowCount;
                        ++curRowCount;
                    }
                    ch.Assert(totalRowCount == numRow);
                    // Flush the trailing, partially filled batch.
                    if (curRowCount > 0)
                    {
                        PushSparseBatch();
                    }
                }
            }
        }
        /// <summary>
        /// Trains a booster on <paramref name="dtrain"/> (passing <paramref name="dvalid"/> along when given)
        /// and stores the model extracted from it in <c>TrainedEnsemble</c>.
        /// </summary>
        private void TrainCore(IChannel ch, IProgressChannel pch, Dataset dtrain, CategoricalMetaData catMetaData, Dataset dvalid = null)
        {
            Host.AssertValue(ch);
            Host.AssertValue(pch);
            Host.AssertValue(dtrain);
            Host.AssertValueOrNull(dvalid);

            // Multi-class training must come with "num_class" among the options.
            ch.Assert(!(PredictionKind == PredictionKind.MultiClassClassification && !Options.ContainsKey("num_class")),
                      "LightGBM requires the number of classes to be specified in the parameters.");

            // Holding the shared lock allows only one trainer to run at a time.
            lock (LightGbmShared.LockForMultiThreadingInside)
            {
                ch.Info("LightGBM objective={0}", Options["objective"]);

                using (Booster booster = WrappedLightGbmTraining.Train(
                           ch,
                           pch,
                           Options,
                           dtrain,
                           dvalid: dvalid,
                           numIteration: Args.NumBoostRound,
                           verboseEval: Args.VerboseEval,
                           earlyStoppingRound: Args.EarlyStoppingRound))
                {
                    // Extract the model before the booster is disposed at the end of this scope.
                    TrainedEnsemble = booster.GetModel(catMetaData.CategoricalBoudaries);
                }
            }
        }
        /// <summary>
        /// Builds <paramref name="dataset"/> from a random sample of the rows produced by <paramref name="factory"/>.
        /// </summary>
        /// <remarks>
        /// The cursor is advanced with randomly drawn strides. For every sampled row the non-zero feature
        /// values are recorded per column together with the ordinal of the sample, and the collected
        /// per-column arrays are passed to the <c>Dataset</c> constructor.
        /// </remarks>
        private void CreateDatasetFromSamplingData(IChannel ch, FloatLabelCursor.Factory factory,
                                                   int numRow, string param, float[] labels, float[] weights, int[] groups, CategoricalMetaData catMetaData,
                                                   out Dataset dataset)
        {
            Host.AssertValue(ch);

            int sampleTarget = GetNumSampleRow(numRow, FeatureCount);

            var    rng          = Host.Rand;
            int    rowsConsumed = 0; // Rows the cursor has been advanced over.
            int    samplesTaken = 0; // Rows actually sampled so far.
            double density      = DetectDensity(factory);

            double[][] sampleValuePerColumn   = new double[catMetaData.NumCol][];
            int[][]    sampleIndicesPerColumn = new int[catMetaData.NumCol][];
            // Freshly allocated, so every per-column count starts at zero.
            int[]      nonZeroCntPerColumn    = new int[catMetaData.NumCol];

            // Initial per-column capacity; the arrays are grown on demand while recording.
            int initialCapacity = Math.Max(1, (int)(sampleTarget * density));
            for (int col = 0; col < catMetaData.NumCol; col++)
            {
                sampleValuePerColumn[col]   = new double[initialCapacity];
                sampleIndicesPerColumn[col] = new int[initialCapacity];
            }

            // Draws how far to advance the cursor: 1 unless the average stride exceeds 1, in which case
            // a random stride in [1, (int)(2 * averageStride - 1)] is used.
            int DrawStride(double averageStride)
            {
                return averageStride > 1 ? rng.Next((int)(2 * averageStride - 1)) + 1 : 1;
            }

            // Appends one value to the given column, tagged with the ordinal of the current sample.
            void Record(int column, float value)
            {
                int slot = nonZeroCntPerColumn[column];
                Utils.EnsureSize(ref sampleValuePerColumn[column], slot + 1);
                Utils.EnsureSize(ref sampleIndicesPerColumn[column], slot + 1);
                sampleValuePerColumn[column][slot]   = value;
                sampleIndicesPerColumn[column][slot] = samplesTaken;
                nonZeroCntPerColumn[column] = slot + 1;
            }

            using (var cursor = factory.Create())
            {
                int stride = DrawStride((double)numRow / sampleTarget);
                while (MoveMany(cursor, stride))
                {
                    if (cursor.Features.IsDense)
                    {
                        GetFeatureValueDense(ch, cursor, catMetaData, rng, out float[] denseValues);
                        for (int col = 0; col < catMetaData.NumCol; ++col)
                        {
                            float value = denseValues[col];
                            if (value != 0)
                            {
                                Record(col, value);
                            }
                        }
                    }
                    else
                    {
                        GetFeatureValueSparse(ch, cursor, catMetaData, rng, out int[] sparseColumns, out float[] sparseValues, out int entryCount);
                        for (int k = 0; k < entryCount; ++k)
                        {
                            float value = sparseValues[k];
                            if (value != 0)
                            {
                                Record(sparseColumns[k], value);
                            }
                        }
                    }

                    rowsConsumed += stride;
                    ++samplesTaken;
                    // Stop once enough samples were taken or the expected number of rows was consumed.
                    if (sampleTarget == samplesTaken || numRow == rowsConsumed)
                    {
                        break;
                    }

                    // Re-derive the average stride from what is left so the remaining samples spread over the remaining rows.
                    stride = DrawStride((double)(numRow - rowsConsumed) / (sampleTarget - samplesTaken));
                }
            }

            dataset = new Dataset(sampleValuePerColumn, sampleIndicesPerColumn, catMetaData.NumCol, nonZeroCntPerColumn, samplesTaken, numRow, param, labels, weights, groups);
        }
Пример #4
0
        /// <summary>
        /// Train and return a booster.
        /// </summary>
        /// <param name="ch">Channel used for the warning and info messages and for assertions.</param>
        /// <param name="pch">Progress channel that receives the header and the periodic checkpoints.</param>
        /// <param name="parameters">Parameters passed to the booster; the "metric" entry is read here and must be a string.</param>
        /// <param name="dtrain">Training dataset passed to the booster.</param>
        /// <param name="dvalid">Optional validation dataset; when null, early stopping is disabled with a warning.</param>
        /// <param name="numIteration">Maximum number of <c>Update</c> calls.</param>
        /// <param name="verboseEval">When true, the training metric (and the validation metric if <paramref name="dvalid"/> is given) is reported with each checkpoint.</param>
        /// <param name="earlyStoppingRound">Stop once the validation score has not improved for this many iterations; 0 disables early stopping.</param>
        /// <returns>
        /// The trained booster. Its <c>BestIteration</c> is set only when the loop ended before
        /// <paramref name="numIteration"/> iterations while early stopping was enabled.
        /// The booster is disposed, and the exception rethrown, if training fails.
        /// </returns>
        public static Booster Train(IChannel ch, IProgressChannel pch,
                                    Dictionary <string, object> parameters, Dataset dtrain, Dataset dvalid = null, int numIteration = 100,
                                    bool verboseEval = true, int earlyStoppingRound = 0)
        {
            // create Booster.
            Booster bst = new Booster(parameters, dtrain, dvalid);
            try
            {
                // Disable early stopping if we don't have validation data.
                if (dvalid == null && earlyStoppingRound > 0)
                {
                    earlyStoppingRound = 0;
                    ch.Warning("Validation dataset not present, early stopping will be disabled.");
                }

                int    bestIter              = 0;
                double bestScore             = double.MaxValue;
                double factorToSmallerBetter = 1.0;

                var metric = (string)parameters["metric"];

                // For these metrics larger is better, so negate the score to keep a "smaller is better" comparison.
                if (earlyStoppingRound > 0 && (metric == "auc" || metric == "ndcg" || metric == "map"))
                {
                    factorToSmallerBetter = -1.0;
                }

                // A checkpoint is emitted every evalFreq iterations.
                const int evalFreq = 50;

                var metrics = new List <string>()
                {
                    "Iteration"
                };
                var units = new List <string>()
                {
                    "iterations"
                };

                if (verboseEval)
                {
                    ch.Assert(parameters.ContainsKey("metric"));
                    metrics.Add("Training-" + metric);
                    if (dvalid != null)
                    {
                        metrics.Add("Validation-" + metric);
                    }
                }

                var header = new ProgressHeader(metrics.ToArray(), units.ToArray());

                int    iter       = 0;
                double trainError = double.NaN;
                double validError = double.NaN;

                // The callback reads the captured locals, so it reports whatever they hold when it is invoked.
                pch.SetHeader(header, e =>
                {
                    e.SetProgress(0, iter, numIteration);
                    if (verboseEval)
                    {
                        e.SetProgress(1, trainError);
                        if (dvalid != null)
                        {
                            e.SetProgress(2, validError);
                        }
                    }
                });

                for (iter = 0; iter < numIteration; ++iter)
                {
                    // A true result from Update ends training early.
                    if (bst.Update())
                    {
                        break;
                    }

                    if (earlyStoppingRound > 0)
                    {
                        validError = bst.EvalValid();
                        if (validError * factorToSmallerBetter < bestScore)
                        {
                            bestScore = validError * factorToSmallerBetter;
                            bestIter  = iter;
                        }
                        if (iter - bestIter >= earlyStoppingRound)
                        {
                            ch.Info($"Met early stopping, best iteration: {bestIter + 1}, best score: {bestScore / factorToSmallerBetter}");
                            break;
                        }
                    }

                    if ((iter + 1) % evalFreq == 0)
                    {
                        if (verboseEval)
                        {
                            trainError = bst.EvalTrain();
                            if (dvalid == null)
                            {
                                pch.Checkpoint(new double?[] { iter + 1, trainError });
                            }
                            else
                            {
                                // With early stopping enabled the validation error was already evaluated above.
                                if (earlyStoppingRound == 0)
                                {
                                    validError = bst.EvalValid();
                                }
                                pch.Checkpoint(new double?[] { iter + 1,
                                                               trainError, validError });
                            }
                        }
                        else
                        {
                            pch.Checkpoint(new double?[] { iter + 1 });
                        }
                    }
                }

                // Set the BestIteration.
                if (iter != numIteration && earlyStoppingRound > 0)
                {
                    bst.BestIteration = bestIter + 1;
                }
                return bst;
            }
            catch
            {
                // On failure the booster never reaches the caller, so release it here instead of leaking it.
                bst.Dispose();
                throw;
            }
        }