Exemple #1
0
        /// <summary>
        /// Serializes the booster's model (up to <c>BestIteration</c>) into a string through the LightGBM C API.
        /// </summary>
        /// <returns>The string that <c>LightGbmInterfaceUtils.GetString</c> reads from the native output buffer.</returns>
        /// <remarks>
        /// The native call reports the required length through <c>size</c>. If the initial buffer is too small,
        /// the buffer is reallocated to exactly <c>size</c> bytes and the call is repeated.
        /// </remarks>
        internal unsafe string GetModelString()
        {
            // Initial guess for the serialized model size (64 KiB); grown on demand below.
            const int InitialBufferLength = 2 << 15;

            int bufLen = InitialBufferLength;
            byte[] buffer = new byte[bufLen];
            int size = 0;

            fixed (byte* ptr = buffer)
            {
                LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterSaveModelToString(Handle, 0, BestIteration, bufLen, ref size, ptr));
            }

            // If buffer size is not enough, reallocate buffer and get again.
            if (size > bufLen)
            {
                bufLen = size;
                buffer = new byte[bufLen];

                fixed (byte* ptr = buffer)
                {
                    LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterSaveModelToString(Handle, 0, BestIteration, bufLen, ref size, ptr));
                }
            }

            // Read straight from the output buffer. Copying the first `size` bytes into another array
            // would only add a second allocation as large as the whole model. GetString receives nothing
            // but a pointer, so it presumably reads up to a NUL terminator. NOTE(review): confirm that
            // `size` as reported by the native call includes that terminator.
            fixed (byte* ptr = buffer)
            {
                return LightGbmInterfaceUtils.GetString((IntPtr)ptr);
            }
        }
Exemple #2
0
        /// <summary>
        /// Builds the LightGBM option dictionary and creates the parallel-training component.
        /// When that component reports a non-"serial" parallel type and more than one machine, it also
        /// copies its settings into the options and initializes the LightGBM network with its callbacks.
        /// </summary>
        private void InitParallelTraining()
        {
            Options          = LightGbmTrainerOptions.ToDictionary(Host);
            ParallelTraining = LightGbmTrainerOptions.ParallelTrainer != null ? LightGbmTrainerOptions.ParallelTrainer.CreateComponent(Host) : new SingleTrainer();

            // Query each property once. The checks keep the original short-circuit order:
            // NumMachines() is only consulted for non-serial trainers.
            string parallelType = ParallelTraining.ParallelType();
            if (parallelType == "serial")
            {
                return;
            }

            int numMachines = ParallelTraining.NumMachines();
            if (numMachines <= 1)
            {
                return;
            }

            Options["tree_learner"] = parallelType;
            var otherParams = ParallelTraining.AdditionalParams();
            if (otherParams != null)
            {
                foreach (var pair in otherParams)
                {
                    Options[pair.Key] = pair.Value;
                }
            }

            // Fetch each callback exactly once, so the delegate that is validated is the same
            // instance that is handed to native code.
            // NOTE(review): native code keeps calling these callbacks after this method returns;
            // confirm the parallel trainer keeps the delegates referenced so they are not collected.
            var reduceScatter = ParallelTraining.GetReduceScatterFunction();
            var allgather     = ParallelTraining.GetAllgatherFunction();
            Contracts.CheckValue(reduceScatter, nameof(ParallelTraining.GetReduceScatterFunction));
            Contracts.CheckValue(allgather, nameof(ParallelTraining.GetAllgatherFunction));

            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.NetworkInitWithFunctions(
                                             numMachines,
                                             ParallelTraining.Rank(),
                                             reduceScatter,
                                             allgather));
        }
Exemple #3
0
        /// <summary>
        /// Wraps the trained tree ensemble into ranking model parameters.
        /// </summary>
        /// <returns>A <see cref="LightGbmRankingModelParameters"/> built from the trained ensemble.</returns>
        private protected override LightGbmRankingModelParameters CreatePredictor()
        {
            Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");

            // The LightGBM options are joined into a single parameter string for the model parameters.
            string joinedParameters = LightGbmInterfaceUtils.JoinParameters(Options);
            var rankingParameters = new LightGbmRankingModelParameters(Host, TrainedEnsemble, FeatureCount, joinedParameters);
            return rankingParameters;
        }
Exemple #4
0
 /// <summary>
 /// Releases the LightGBM network used for multi-machine training.
 /// </summary>
 private void DisposeParallelTraining()
 {
     // Nothing to release unless more than one machine participates.
     // NOTE(review): initialization additionally requires a non-"serial" parallel type;
     // confirm NetworkFree is safe to call when that condition was not met.
     if (ParallelTraining.NumMachines() <= 1)
     {
         return;
     }

     LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.NetworkFree());
 }
Exemple #5
0
        /// <summary>
        /// Returns the number of features (columns) the native LightGBM dataset reports.
        /// </summary>
        public int GetNumCols()
        {
            int numFeatures = 0;
            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetGetNumFeature(_handle, ref numFeatures));
            return numFeatures;
        }
Exemple #6
0
        /// <summary>
        /// Builds the native LightGBM training <see cref="Dataset"/> from ML.NET training data.
        /// </summary>
        /// <param name="ch">Channel used for checks.</param>
        /// <param name="trainData">The role-mapped training data.</param>
        /// <param name="catMetaData">Receives the categorical metadata computed for <paramref name="trainData"/>.</param>
        /// <returns>The training dataset with all rows pushed in.</returns>
        private Dataset LoadTrainingData(IChannel ch, RoleMappedData trainData, out CategoricalMetaData catMetaData)
        {
            Host.AssertValue(ch);
            ch.CheckValue(trainData, nameof(trainData));
            CheckDataValid(ch, trainData);

            // Metadata (row count, labels, weights, groups) is gathered before any rows are pushed.
            var cursorFactory = CreateCursorFactory(trainData);
            GetMetainfo(ch, cursorFactory, out int rowCount, out float[] labels, out float[] weights, out int[] groups);

            catMetaData = GetCategoricalMetaData(ch, trainData, rowCount);
            bool hasCategorical = catMetaData.CategoricalBoudaries != null;
            GetDefaultParameters(ch, rowCount, hasCategorical, catMetaData.TotalCats);

            string joinedParams = LightGbmInterfaceUtils.JoinParameters(Options);
            Dataset trainSet;

            // Only one sampling task runs at a time (shared lock), which keeps peak memory usage down.
            lock (LightGbmShared.SampleLock)
            {
                CreateDatasetFromSamplingData(ch, cursorFactory, rowCount,
                                              joinedParams, labels, weights, groups, catMetaData, out trainSet);
            }

            // Stream the actual rows into the native container in batches.
            LoadDataset(ch, cursorFactory, trainSet, rowCount, LightGbmTrainerOptions.BatchSize, catMetaData);

            CheckAndUpdateParametersBeforeTraining(ch, trainData, labels, groups);
            return trainSet;
        }
Exemple #7
0
        /// <summary>
        /// Runs a single boosting iteration on the native booster.
        /// </summary>
        /// <returns><c>true</c> when the native call sets its "finished" flag to 1; otherwise <c>false</c>.</returns>
        public bool Update()
        {
            int finishedFlag = 0;
            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterUpdateOneIter(Handle, ref finishedFlag));

            bool finished = finishedFlag == 1;
            return finished;
        }
Exemple #8
0
        /// <summary>
        /// Builds a one-versus-all predictor made of one Platt-calibrated binary predictor per class.
        /// </summary>
        /// <returns>
        /// An <see cref="OvaModelParameters"/>. It uses the softmax output formula when the LightGBM objective is
        /// "multiclass", and the default <c>Create</c> overload otherwise.
        /// </returns>
        private protected override OvaModelParameters CreatePredictor()
        {
            Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete.");
            Host.Assert(_numClass > 1, "Must know the number of classes before creating a predictor.");
            Host.Assert(TrainedEnsemble.NumTrees % _numClass == 0, "Number of trees should be a multiple of number of classes.");

            string joinedParameters = LightGbmInterfaceUtils.JoinParameters(Options);

            var predictors = new IPredictorProducing<float>[_tlcNumClass];
            for (int classIndex = 0; classIndex < _tlcNumClass; classIndex++)
            {
                var binaryPredictor = CreateBinaryPredictor(classIndex, joinedParameters);
                // Fixed calibrator constants (-0.5, 0); nothing is fitted here.
                var calibrator = new PlattCalibrator(Host, -0.5, 0);
                predictors[classIndex] = new FeatureWeightsCalibratedPredictor(Host, binaryPredictor, calibrator);
            }

            string objective = (string)GetGbmParameters()["objective"];
            if (objective == "multiclass")
            {
                return OvaModelParameters.Create(Host, OvaModelParameters.OutputFormula.Softmax, predictors);
            }

            return OvaModelParameters.Create(Host, predictors);
        }
Exemple #9
0
 /// <summary>
 /// Frees the native LightGBM dataset, if one is held. Once this method has run, later calls do nothing.
 /// </summary>
 public void Dispose()
 {
     IntPtr handle = _handle;

     // Clear the field before calling into native code. If the error check on DatasetFree throws,
     // a later Dispose call then cannot pass the same handle to DatasetFree a second time.
     _handle = IntPtr.Zero;

     if (handle != IntPtr.Zero)
     {
         LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetFree(handle));
     }
 }
Exemple #10
0
 /// <summary>
 /// Frees the native LightGBM booster, if one is held. Once this method has run, later calls do nothing.
 /// </summary>
 public void Dispose()
 {
     IntPtr handle = Handle;

     // Clear the handle before calling into native code. If the error check on BoosterFree throws,
     // a later Dispose call then cannot pass the same pointer to BoosterFree a second time.
     Handle = IntPtr.Zero;

     if (handle != IntPtr.Zero)
     {
         LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterFree(handle));
     }
 }
        /// <summary>
        /// Wraps the trained ensemble in binary model parameters and pairs them with a Platt calibrator.
        /// </summary>
        /// <returns>The calibrated binary model parameters.</returns>
        private protected override CalibratedModelParametersBase <LightGbmBinaryModelParameters, PlattCalibrator> CreatePredictor()
        {
            Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");

            var binaryModel = new LightGbmBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, LightGbmInterfaceUtils.JoinParameters(Options));

            // Fixed calibrator constants (-0.5, 0), presumably slope and offset; no fitting happens here.
            var plattCalibrator = new PlattCalibrator(Host, -0.5, 0);

            var calibrated = new FeatureWeightsCalibratedModelParameters<LightGbmBinaryModelParameters, PlattCalibrator>(Host, binaryModel, plattCalibrator);
            return calibrated;
        }
Exemple #12
0
        /// <summary>
        /// Sets the native dataset's "label" field from <paramref name="labels"/> (32-bit floats).
        /// </summary>
        /// <param name="labels">One label per row; its length is asserted to equal <see cref="GetNumRows"/>.</param>
        public unsafe void SetLabel(float[] labels)
        {
            Contracts.AssertValue(labels);
            Contracts.Assert(labels.Length == GetNumRows());

            fixed (float* labelPtr = labels)
            {
                var status = WrappedLightGbmInterface.DatasetSetField(
                    _handle, "label", (IntPtr)labelPtr, labels.Length, WrappedLightGbmInterface.CApiDType.Float32);
                LightGbmInterfaceUtils.Check(status);
            }
        }
Exemple #13
0
 /// <summary>
 /// Sets the native dataset's "group" field from <paramref name="groups"/> (32-bit ints).
 /// A null array is ignored.
 /// </summary>
 public unsafe void SetGroup(int[] groups)
 {
     if (groups is null)
     {
         return;
     }

     fixed (int* groupPtr = groups)
     {
         LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetSetField(
             _handle, "group", (IntPtr)groupPtr, groups.Length, WrappedLightGbmInterface.CApiDType.Int32));
     }
 }
Exemple #14
0
        /// <summary>
        /// Create a <see cref="Dataset"/> for storing training and prediction data under the LightGBM framework. The main goal of this constructor
        /// is not to marshal an ML.NET data set into LightGBM format but to create an (unmanaged) container into which examples can be pushed by calling
        /// <see cref="PushRows(float[], int, int, int)"/>. It also pre-allocates memory, so the actual size (number of examples and number of features)
        /// of the data set is required. A sub-sampled version of the original data set is passed in to compute some statistics needed by the training
        /// procedure. Note that we use "original" to indicate a property of the unsampled data set.
        /// </summary>
        /// <param name="sampleValuePerColumn">A 2-D array which encodes the sub-sampled data matrix. sampleValuePerColumn[i] stores
        /// all the non-zero values of the i-th feature. sampleValuePerColumn[i][j] is the j-th non-zero value of the i-th feature encountered when scanning
        /// the values row-by-row (i.e., example-by-example) in the matrix and column-by-column (i.e., feature-by-feature) within one row. It is similar
        /// to the CSC format for storing sparse matrices.</param>
        /// <param name="sampleIndicesPerColumn">A 2-D array which encodes the sub-sampled example indices of the non-zero features stored in sampleValuePerColumn.
        /// The sampleIndicesPerColumn[i][j]-th example has a non-zero i-th feature whose value is sampleValuePerColumn[i][j].</param>
        /// <param name="numCol">Total number of features in the original data.</param>
        /// <param name="sampleNonZeroCntPerColumn">sampleNonZeroCntPerColumn[i] is the size of sampleValuePerColumn[i].</param>
        /// <param name="numSampleRow">The number of sampled examples in the sub-sampled data matrix.</param>
        /// <param name="numTotalRow">The number of original examples added using <see cref="PushRows(float[], int, int, int)"/>.</param>
        /// <param name="param">LightGBM parameter used in https://github.com/Microsoft/LightGBM/blob/c920e6345bcb41fc1ec6ac338f5437034b9f0d38/src/c_api.cpp#L421. </param>
        /// <param name="labels">Labels of the original data. labels[i] is the label of the i-th original example.</param>
        /// <param name="weights">Example weights of the original data. weights[i] is the weight of the i-th original example. May be null.</param>
        /// <param name="groups">Group identifiers of the original data. groups[i] is the group ID of the i-th original example. May be null.</param>
        public unsafe Dataset(double[][] sampleValuePerColumn,
                              int[][] sampleIndicesPerColumn,
                              int numCol,
                              int[] sampleNonZeroCntPerColumn,
                              int numSampleRow,
                              int numTotalRow,
                              string param, float[] labels, float[] weights = null, int[] groups = null)
        {
            _handle = IntPtr.Zero;

            // Pin every per-column array with a GCHandle so the garbage collector cannot move it while
            // native code reads through raw pointers into it. The finally block frees every handle that
            // was allocated, including after a failure part-way through the loop (IsAllocated check).
            GCHandle[] gcValues  = new GCHandle[numCol];
            GCHandle[] gcIndices = new GCHandle[numCol];
            try
            {
                double*[] ptrArrayValues  = new double*[numCol];
                int*[]    ptrArrayIndices = new int*[numCol];
                for (int i = 0; i < numCol; i++)
                {
                    gcValues[i]        = GCHandle.Alloc(sampleValuePerColumn[i], GCHandleType.Pinned);
                    ptrArrayValues[i]  = (double*)gcValues[i].AddrOfPinnedObject().ToPointer();
                    gcIndices[i]       = GCHandle.Alloc(sampleIndicesPerColumn[i], GCHandleType.Pinned);
                    ptrArrayIndices[i] = (int*)gcIndices[i].AddrOfPinnedObject().ToPointer();
                }

                // The outer pointer arrays are pinned with `fixed` for the duration of the native call.
                fixed (double** ptrValues = ptrArrayValues)
                fixed (int** ptrIndices = ptrArrayIndices)
                {
                    // Create the container. Examples will be pushed in later.
                    LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetCreateFromSampledColumn(
                                                     (IntPtr)ptrValues, (IntPtr)ptrIndices, numCol, sampleNonZeroCntPerColumn, numSampleRow, numTotalRow,
                                                     param, ref _handle));
                }
            }
            finally
            {
                for (int i = 0; i < numCol; i++)
                {
                    if (gcValues[i].IsAllocated)
                    {
                        gcValues[i].Free();
                    }
                    if (gcIndices[i].IsAllocated)
                    {
                        gcIndices[i].Free();
                    }
                }
            }

            // Before adding examples (i.e., feature vectors of the original data set), the original labels, weights, and groups are added.
            // NOTE(review): if any of these calls throws, the native handle created above is not freed here.
            SetLabel(labels);
            SetWeights(weights);
            SetGroup(groups);

            Contracts.Assert(GetNumCols() == numCol);
            Contracts.Assert(GetNumRows() == numTotalRow);
        }
Exemple #15
0
 /// <summary>
 /// Appends a block of dense examples to the LightGBM dataset.
 /// </summary>
 /// <param name="data">Dense (# of rows)-by-(# of columns) matrix flattened in row-major order, one row per example.
 /// The value at the i-th row and j-th column is stored in data[j + i * (# of columns)].</param>
 /// <param name="numRow">Number of rows in the data matrix; asserted to be positive.</param>
 /// <param name="numCol">Number of columns in the data matrix; asserted to match the dataset's feature count.</param>
 /// <param name="startRowIdx">The actual row index of the first row pushed in. If it is 36, the first row of data becomes the 37th row of the <see cref="Dataset"/>.</param>
 public void PushRows(float[] data, int numRow, int numCol, int startRowIdx)
 {
     // Blocks must be appended contiguously, right after the previously pushed block,
     // and must fit within the row count the dataset was created with.
     Contracts.Assert(startRowIdx == _lastPushedRowID);
     Contracts.Assert(numCol == GetNumCols());
     Contracts.Assert(numRow > 0);
     Contracts.Assert(startRowIdx <= GetNumRows() - numRow);

     var status = WrappedLightGbmInterface.DatasetPushRows(_handle, data, numRow, numCol, startRowIdx);
     LightGbmInterfaceUtils.Check(status);

     // The next block has to start immediately after this one.
     _lastPushedRowID = numRow + startRowIdx;
 }
Exemple #16
0
 /// <summary>
 /// Appends a block of examples stored in CSR (compressed sparse row) form to the LightGBM dataset.
 /// </summary>
 /// <param name="indPtr">Row pointer array. Its <paramref name="nIndptr"/> entries delimit <paramref name="nIndptr"/> - 1 rows.</param>
 /// <param name="indices">Column indices of the stored values.</param>
 /// <param name="data">The stored values.</param>
 /// <param name="nIndptr">Number of entries in <paramref name="indPtr"/>.</param>
 /// <param name="numElem">Number of stored elements.</param>
 /// <param name="numCol">Number of columns; asserted to match the dataset's feature count.</param>
 /// <param name="startRowIdx">Dataset row index of the first pushed row; asserted to follow the previously pushed block.</param>
 public void PushRows(int[] indPtr, int[] indices, float[] data, int nIndptr,
                      long numElem, int numCol, int startRowIdx)
 {
     Contracts.Assert(startRowIdx == _lastPushedRowID);
     Contracts.Assert(numCol == GetNumCols());
     Contracts.Assert(startRowIdx < GetNumRows());

     var status = WrappedLightGbmInterface.DatasetPushRowsByCsr(
         _handle, indPtr, indices, data, nIndptr, numElem, numCol, startRowIdx);
     LightGbmInterfaceUtils.Check(status);

     // nIndptr row-pointer entries describe nIndptr - 1 rows.
     int rowsPushed = nIndptr - 1;
     _lastPushedRowID = startRowIdx + rowsPushed;
 }
Exemple #17
0
        /// <summary>
        /// Sets the native dataset's "init_score" field (64-bit floats). Not used at the moment;
        /// it could be used to continue training.
        /// </summary>
        /// <param name="initScores">Initial scores. A null array is ignored. Otherwise its length is asserted
        /// to be a multiple of the dataset's row count.</param>
        public unsafe void SetInitScore(double[] initScores)
        {
            if (initScores is null)
            {
                return;
            }

            Contracts.Assert(initScores.Length % GetNumRows() == 0);

            fixed (double* scorePtr = initScores)
            {
                LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetSetField(
                    _handle, "init_score", (IntPtr)scorePtr, initScores.Length, WrappedLightGbmInterface.CApiDType.Float64));
            }
        }
Exemple #18
0
        /// <summary>
        /// Reads the booster's evaluation result for the given data index.
        /// </summary>
        /// <param name="dataIdx">Data index passed through to the native evaluation call.</param>
        /// <returns>The first evaluation value, or <see cref="double.NaN"/> when no metric is configured.</returns>
        private unsafe double Eval(int dataIdx)
        {
            if (_hasMetric)
            {
                // Single-slot result buffer: the booster is expected to expose at most one metric.
                double[] evalResult = new double[1];
                int resultCount = 0;

                fixed (double* resultPtr = evalResult)
                {
                    LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterGetEval(Handle, dataIdx, ref resultCount, resultPtr));
                }

                return evalResult[0];
            }

            return double.NaN;
        }
Exemple #19
0
        /// <summary>
        /// Creates a native dataset through LightGBM's create-by-reference call, then attaches labels,
        /// weights, and groups.
        /// </summary>
        /// <param name="reference">Dataset whose handle is passed as the reference; a null handle is used when this is null.</param>
        /// <param name="numTotalRow">Number of rows the new dataset will hold.</param>
        /// <param name="labels">Per-row labels.</param>
        /// <param name="weights">Optional per-row weights.</param>
        /// <param name="groups">Optional group information.</param>
        public Dataset(Dataset reference, int numTotalRow, float[] labels, float[] weights = null, int[] groups = null)
        {
            IntPtr refHandle = reference != null ? reference.Handle : IntPtr.Zero;
            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetCreateByReference(refHandle, numTotalRow, ref _handle));

            // Attach per-row metadata; weights and groups may be null.
            SetLabel(labels);
            SetWeights(weights);
            SetGroup(groups);
        }
Exemple #20
0
 /// <summary>
 /// Sets the native dataset's "weight" field (32-bit floats).
 /// </summary>
 /// <remarks>
 /// A null array is ignored. The field is also left unset when every entry compares equal to
 /// the first one, which includes arrays with zero or one element.
 /// </remarks>
 public unsafe void SetWeights(float[] weights)
 {
     if (weights is null)
     {
         return;
     }

     Contracts.Assert(weights.Length == GetNumRows());

     // Scan for the first entry that differs from weights[0]. The scan starts at index 1, so
     // weights[0] is never compared with itself.
     int firstDifferent = 1;
     while (firstDifferent < weights.Length && weights[firstDifferent] == weights[0])
     {
         firstDifferent++;
     }

     bool allSame = firstDifferent >= weights.Length;
     if (allSame)
     {
         return;
     }

     fixed (float* weightPtr = weights)
     {
         LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetSetField(
             _handle, "weight", (IntPtr)weightPtr, weights.Length, WrappedLightGbmInterface.CApiDType.Float32));
     }
 }
Exemple #21
0
        /// <summary>
        /// Creates a native LightGBM booster for <paramref name="trainset"/> and optionally registers
        /// <paramref name="validset"/> as validation data. It also records whether the booster reports an evaluation metric.
        /// </summary>
        /// <param name="parameters">LightGBM parameters, joined into a single parameter string.</param>
        /// <param name="trainset">Training dataset.</param>
        /// <param name="validset">Optional validation dataset.</param>
        public Booster(Dictionary <string, object> parameters, Dataset trainset, Dataset validset = null)
        {
            string paramString = LightGbmInterfaceUtils.JoinParameters(parameters);
            IntPtr boosterHandle = IntPtr.Zero;

            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterCreate(trainset.Handle, paramString, ref boosterHandle));
            Handle = boosterHandle;

            if (validset != null)
            {
                LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterAddValidData(Handle, validset.Handle));
                _hasValid = true;
            }

            BestIteration = -1;

            int evalCount = 0;
            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterGetEvalCounts(Handle, ref evalCount));

            // At most one metric in ML.NET.
            Contracts.Assert(evalCount <= 1);
            if (evalCount == 1)
            {
                _hasMetric = true;
            }
        }