/// <summary>
/// Returns the number of feature columns LightGBM reports for this dataset.
/// </summary>
/// <returns>The feature count obtained from <c>DatasetGetNumFeature</c>.</returns>
public int GetNumCols()
{
    int featureCount = 0;
    var status = WrappedLightGbmInterface.DatasetGetNumFeature(_handle, ref featureCount);
    LightGbmInterfaceUtils.Check(status);
    return featureCount;
}
/// <summary>
/// Serializes the booster to LightGBM's model string, passing <c>BestIteration</c> to the native call.
/// </summary>
/// <remarks>
/// A 64 KiB (2 &lt;&lt; 15 bytes) buffer is tried first. The native call reports the required size
/// through <c>size</c>. If that exceeds the buffer, the call is repeated with a buffer of exactly that size.
/// </remarks>
/// <returns>The text produced by <c>LightGbmInterfaceUtils.GetString</c> over the returned bytes.</returns>
internal unsafe string GetModelString()
{
    int capacity = 2 << 15;
    byte[] scratch = new byte[capacity];
    int size = 0;
    fixed (byte* firstAttempt = scratch)
    {
        LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterSaveModelToString(Handle, 0, BestIteration, capacity, ref size, firstAttempt));
    }

    // If buffer size is not enough, reallocate buffer and get again.
    if (size > capacity)
    {
        capacity = size;
        scratch = new byte[capacity];
        fixed (byte* secondAttempt = scratch)
        {
            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterSaveModelToString(Handle, 0, BestIteration, capacity, ref size, secondAttempt));
        }
    }

    // Keep only the bytes the native side actually wrote.
    byte[] trimmed = new byte[size];
    Array.Copy(scratch, trimmed, size);
    fixed (byte* text = trimmed)
    {
        return LightGbmInterfaceUtils.GetString((IntPtr)text);
    }
}
/// <summary>
/// Performs a single boosting iteration via <c>BoosterUpdateOneIter</c>.
/// </summary>
/// <returns><c>true</c> when the native call sets its finished flag to 1; otherwise <c>false</c>.</returns>
public bool Update()
{
    int finishedFlag = 0;
    LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterUpdateOneIter(Handle, ref finishedFlag));
    bool finished = finishedFlag == 1;
    return finished;
}
/// <summary>
/// Builds a one-versus-all model with one Platt-calibrated binary predictor per class.
/// </summary>
/// <remarks>
/// If the "objective" GBM parameter is exactly "multiclass", the OVA model uses the softmax output
/// formula. Otherwise the default <c>OneVersusAllModelParameters.Create</c> overload is used.
/// </remarks>
private protected override OneVersusAllModelParameters CreatePredictor()
{
    Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete.");
    Host.Assert(_numClass > 1, "Must know the number of classes before creating a predictor.");
    Host.Assert(TrainedEnsemble.NumTrees % _numClass == 0, "Number of trees should be a multiple of number of classes.");

    var joinedOptions = LightGbmInterfaceUtils.JoinParameters(GbmOptions);
    var predictors = new IPredictorProducing<float>[_tlcNumClass];
    for (int classIndex = 0; classIndex < _tlcNumClass; classIndex++)
    {
        var binaryModel = CreateBinaryPredictor(classIndex, joinedOptions);
        var calibrator = new PlattCalibrator(Host, -0.5, 0);
        predictors[classIndex] = new FeatureWeightsCalibratedModelParameters<LightGbmBinaryModelParameters, PlattCalibrator>(Host, binaryModel, calibrator);
    }

    var objective = (string)GetGbmParameters()["objective"];
    if (objective != "multiclass")
    {
        return OneVersusAllModelParameters.Create(Host, predictors);
    }

    return OneVersusAllModelParameters.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors);
}
/// <summary>
/// Wraps the trained ensemble as a ranking model, recording the joined training options.
/// </summary>
private protected override LightGbmRankingModelParameters CreatePredictor()
{
    Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");

    var joinedOptions = LightGbmInterfaceUtils.JoinParameters(Options);
    var model = new LightGbmRankingModelParameters(Host, TrainedEnsemble, FeatureCount, joinedOptions);
    return model;
}
/// <summary>
/// Frees the native booster when one is held, then clears <c>Handle</c>.
/// </summary>
/// <remarks>
/// If the native free reports an error, <c>Check</c> throws before <c>Handle</c> is cleared.
/// </remarks>
public void Dispose()
{
    bool holdsBooster = Handle != IntPtr.Zero;
    if (holdsBooster)
    {
        LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterFree(Handle));
    }

    Handle = IntPtr.Zero;
}
/// <summary>
/// Frees the native dataset when one is held, then clears <c>_handle</c>.
/// </summary>
/// <remarks>
/// If the native free reports an error, <c>Check</c> throws before <c>_handle</c> is cleared.
/// </remarks>
public void Dispose()
{
    bool holdsDataset = _handle != IntPtr.Zero;
    if (holdsDataset)
    {
        LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetFree(_handle));
    }

    _handle = IntPtr.Zero;
}
/// <summary>
/// Writes per-example labels into the dataset's native "label" field as 32-bit floats.
/// </summary>
/// <param name="labels">Non-null label array. Its length is asserted to equal the dataset's row count.</param>
public unsafe void SetLabel(float[] labels)
{
    Contracts.AssertValue(labels);
    Contracts.Assert(labels.Length == GetNumRows());

    fixed (float* labelPtr = labels)
    {
        var status = WrappedLightGbmInterface.DatasetSetField(_handle, "label", (IntPtr)labelPtr, labels.Length, WrappedLightGbmInterface.CApiDType.Float32);
        LightGbmInterfaceUtils.Check(status);
    }
}
/// <summary>
/// Appends a dense block of examples to the LightGBM dataset.
/// </summary>
/// <param name="data">Row-major flattening of a (# of rows)-by-(# of columns) matrix, one row per example.
/// The value at row i, column j is stored in data[j + i * (# of columns)].</param>
/// <param name="numRow">Number of rows in <paramref name="data"/>. Asserted to be positive.</param>
/// <param name="numCol">Number of columns in <paramref name="data"/>. Asserted to equal the dataset's feature count.</param>
/// <param name="startRowIdx">Dataset row index that the first row of <paramref name="data"/> is written to.
/// For example, 36 means the first pushed row becomes the 37th row of the <see cref="Dataset"/>.
/// Asserted to continue exactly where the previous push ended.</param>
public void PushRows(float[] data, int numRow, int numCol, int startRowIdx)
{
    Contracts.Assert(startRowIdx == _lastPushedRowID);
    Contracts.Assert(numCol == GetNumCols());
    Contracts.Assert(numRow > 0);
    Contracts.Assert(startRowIdx <= GetNumRows() - numRow);

    var status = WrappedLightGbmInterface.DatasetPushRows(_handle, data, numRow, numCol, startRowIdx);
    LightGbmInterfaceUtils.Check(status);

    // Record where the next batch must begin so sequential pushes can be asserted.
    _lastPushedRowID = startRowIdx + numRow;
}
/// <summary>
/// Create a <see cref="Dataset"/> for storing training and prediction data under LightGBM framework. The main goal of this function
/// is not marshaling ML.NET data set into LightGBM format but just creating an (unmanaged) container where examples can be pushed into by calling
/// <see cref="PushRows(float[], int, int, int)"/>. It also pre-allocates memory, so the actual size (number of examples and number of features)
/// of the data set is required. A sub-sampled version of the original data set is passed in to compute some statistics needed by the training
/// procedure. Note that we use "original" to indicate a property from the unsampled data set.
/// </summary>
/// <param name="sampleValuePerColumn">A 2-D array which encodes the sub-sampled data matrix. sampleValuePerColumn[i] stores
/// all the non-zero values of the i-th feature. sampleValuePerColumn[i][j] is the j-th non-zero value of i-th feature encountered when scanning
/// the values row-by-row (i.e., example-by-example) in the matrix and column-by-column (i.e., feature-by-feature) within one row. It is similar
/// to CSC format for storing sparse matrix. The first numCol entries are read.</param>
/// <param name="sampleIndicesPerColumn">A 2-D array which encodes sub-sampled example indexes of non-zero features stored in sampleValuePerColumn.
/// The sampleIndicesPerColumn[i][j]-th example has a non-zero i-th feature whose value is sampleValuePerColumn[i][j]. The first numCol entries are read.</param>
/// <param name="numCol">Total number of features in the original data.</param>
/// <param name="sampleNonZeroCntPerColumn">sampleNonZeroCntPerColumn[i] is the size of sampleValuePerColumn[i].</param>
/// <param name="numSampleRow">The number of sampled examples in the sub-sampled data matrix.</param>
/// <param name="numTotalRow">The number of original examples added using <see cref="PushRows(float[], int, int, int)"/>.</param>
/// <param name="param">LightGBM parameter used in https://github.com/Microsoft/LightGBM/blob/c920e6345bcb41fc1ec6ac338f5437034b9f0d38/src/c_api.cpp#L421. </param>
/// <param name="labels">Labels of the original data. labels[i] is the label of the i-th original example.</param>
/// <param name="weights">Example weights of the original data. weights[i] is the weight of the i-th original example.</param>
/// <param name="groups">Group identifiers of the original data. groups[i] is the group ID of the i-th original example.</param>
/// <remarks>
/// Each per-column managed array is pinned with a <see cref="GCHandle"/> only for the duration of the native
/// create call and is unpinned in a finally block.
/// </remarks>
public unsafe Dataset(double[][] sampleValuePerColumn, int[][] sampleIndicesPerColumn, int numCol, int[] sampleNonZeroCntPerColumn, int numSampleRow, int numTotalRow, string param, float[] labels, float[] weights = null, int[] groups = null)
{
    _handle = null;
    // Use GCHandle to pin the memory, avoid the memory relocation.
    GCHandle[] gcValues = new GCHandle[numCol];
    GCHandle[] gcIndices = new GCHandle[numCol];
    try
    {
        // The native API takes an array of per-column pointers, so collect the address of each pinned column array.
        double *[] ptrArrayValues = new double *[numCol];
        int *[] ptrArrayIndices = new int *[numCol];
        for (int i = 0; i < numCol; i++)
        {
            gcValues[i] = GCHandle.Alloc(sampleValuePerColumn[i], GCHandleType.Pinned);
            ptrArrayValues[i] = (double *)gcValues[i].AddrOfPinnedObject().ToPointer();
            gcIndices[i] = GCHandle.Alloc(sampleIndicesPerColumn[i], GCHandleType.Pinned);
            ptrArrayIndices[i] = (int *)gcIndices[i].AddrOfPinnedObject().ToPointer();
        }
        ;
        // The pointer arrays are themselves managed objects, so pin them while the native call reads them.
        fixed(double **ptrValues = ptrArrayValues)
        fixed(int **ptrIndices = ptrArrayIndices)
        {
            // Create container. Examples will be pushed in later.
            LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetCreateFromSampledColumn(
                (IntPtr)ptrValues, (IntPtr)ptrIndices, numCol, sampleNonZeroCntPerColumn, numSampleRow, numTotalRow, param, out _handle));
        }
    }
    finally
    {
        // Unpin every handle that was allocated. The IsAllocated checks also cover a loop that stopped part-way.
        for (int i = 0; i < numCol; i++)
        {
            if (gcValues[i].IsAllocated)
            {
                gcValues[i].Free();
            }
            if (gcIndices[i].IsAllocated)
            {
                gcIndices[i].Free();
            }
        }
        ;
    }
    // Before adding examples (i.e., feature vectors of the original data set), the original labels, weights, and groups are added.
    SetLabel(labels);
    SetWeights(weights);
    SetGroup(groups);
    Contracts.Assert(GetNumCols() == numCol);
    Contracts.Assert(GetNumRows() == numTotalRow);
}
/// <summary>
/// Wraps the trained ensemble as a binary model and attaches a Platt calibrator.
/// The calibrator is constructed with <c>-LightGbmTrainerOptions.Sigmoid</c> and 0.
/// </summary>
private protected override CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator> CreatePredictor()
{
    Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");

    var joinedOptions = LightGbmInterfaceUtils.JoinParameters(base.GbmOptions);
    var binaryModel = new LightGbmBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, joinedOptions);
    var calibrator = new PlattCalibrator(Host, -LightGbmTrainerOptions.Sigmoid, 0);
    var calibrated = new FeatureWeightsCalibratedModelParameters<LightGbmBinaryModelParameters, PlattCalibrator>(Host, binaryModel, calibrator);
    return calibrated;
}
/// <summary>
/// Writes group identifiers into the dataset's native "group" field as 32-bit integers.
/// </summary>
/// <param name="groups">Group data, or <c>null</c> to leave the field untouched.</param>
public unsafe void SetGroup(int[] groups)
{
    if (groups is null)
    {
        return;
    }

    fixed (int* groupPtr = groups)
    {
        var status = WrappedLightGbmInterface.DatasetSetField(_handle, "group", (IntPtr)groupPtr, groups.Length, WrappedLightGbmInterface.CApiDType.Int32);
        LightGbmInterfaceUtils.Check(status);
    }
}
/// <summary>
/// Creates a dataset of <paramref name="numTotalRow"/> rows through <c>DatasetCreateByReference</c>,
/// then sets its labels, weights and groups.
/// </summary>
/// <param name="reference">Dataset whose handle is passed to the native call. A <c>null</c> handle is passed when this is <c>null</c>.</param>
/// <param name="numTotalRow">Number of rows to allocate.</param>
/// <param name="labels">Per-row labels.</param>
/// <param name="weights">Optional per-row weights.</param>
/// <param name="groups">Optional group data.</param>
public Dataset(Dataset reference, int numTotalRow, float[] labels, float[] weights = null, int[] groups = null)
{
    WrappedLightGbmInterface.SafeDataSetHandle refHandle = null;
    if (reference != null)
    {
        refHandle = reference.Handle;
    }

    var status = WrappedLightGbmInterface.DatasetCreateByReference(refHandle, numTotalRow, out _handle);
    LightGbmInterfaceUtils.Check(status);

    SetLabel(labels);
    SetWeights(weights);
    SetGroup(groups);
}
/// <summary>
/// Appends a batch of examples stored in CSR (compressed sparse row) form to the LightGBM dataset.
/// </summary>
/// <param name="indPtr">CSR row-pointer array. It has <paramref name="nIndptr"/> entries, describing <paramref name="nIndptr"/> - 1 rows.</param>
/// <param name="indices">Column index of each stored value.</param>
/// <param name="data">Stored (non-zero) values.</param>
/// <param name="nIndptr">Number of entries in <paramref name="indPtr"/>.</param>
/// <param name="numElem">Number of stored values in <paramref name="indices"/>/<paramref name="data"/>.</param>
/// <param name="numCol">Number of columns. Asserted to equal the dataset's feature count.</param>
/// <param name="startRowIdx">Dataset row index that the first row is written to. Asserted to continue exactly where the previous push ended.</param>
public void PushRows(int[] indPtr, int[] indices, float[] data, int nIndptr, long numElem, int numCol, int startRowIdx)
{
    // A CSR row-pointer array has one more entry than the number of rows it describes.
    int numRow = nIndptr - 1;

    Contracts.Assert(startRowIdx == _lastPushedRowID);
    Contracts.Assert(numCol == GetNumCols());
    Contracts.Assert(nIndptr >= 1);
    Contracts.Assert(startRowIdx < GetNumRows());
    // Mirror the dense overload: the whole batch, not just its first row, must fit in the pre-allocated rows.
    Contracts.Assert(startRowIdx <= GetNumRows() - numRow);

    LightGbmInterfaceUtils.Check(
        WrappedLightGbmInterface.DatasetPushRowsByCsr(
            _handle, indPtr, indices, data, nIndptr, numElem, numCol, startRowIdx));

    // Record where the next batch must begin so sequential pushes can be asserted.
    _lastPushedRowID = startRowIdx + numRow;
}
/// <summary>
/// Writes initial scores into the dataset's native "init_score" field as 64-bit floats.
/// </summary>
/// <remarks>
/// Not used now. Can be used for continued training.
/// </remarks>
/// <param name="initScores">Scores whose length is asserted to be a multiple of the row count
/// (presumably one score per row per model output — confirm against callers), or <c>null</c> to skip.</param>
public unsafe void SetInitScore(double[] initScores)
{
    if (initScores is null)
    {
        return;
    }

    Contracts.Assert(initScores.Length % GetNumRows() == 0);
    fixed (double* scorePtr = initScores)
    {
        var status = WrappedLightGbmInterface.DatasetSetField(_handle, "init_score", (IntPtr)scorePtr, initScores.Length, WrappedLightGbmInterface.CApiDType.Float64);
        LightGbmInterfaceUtils.Check(status);
    }
}
/// <summary>
/// Reads the booster's evaluation metric for the dataset at <paramref name="dataIdx"/>.
/// </summary>
/// <param name="dataIdx">Index of the dataset to evaluate, passed straight to <c>BoosterGetEval</c>.</param>
/// <returns>The single metric value, or <see cref="double.NaN"/> when no metric is configured.</returns>
private unsafe double Eval(int dataIdx)
{
    if (!_hasMetric)
    {
        return double.NaN;
    }

    // Room for exactly one value. The Booster constructor only enables _hasMetric when the eval count is 1.
    int outLen = 0;
    double metric = 0;
    LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterGetEval(Handle, dataIdx, ref outLen, &metric));
    return metric;
}
/// <summary>
/// Creates a dataset of <paramref name="numTotalRow"/> rows through <c>DatasetCreateByReference</c>,
/// then sets its labels, weights and groups.
/// </summary>
/// <param name="reference">Dataset whose handle is passed to the native call. <see cref="IntPtr.Zero"/> is passed when this is <c>null</c>.</param>
/// <param name="numTotalRow">Number of rows to allocate.</param>
/// <param name="labels">Per-row labels.</param>
/// <param name="weights">Optional per-row weights.</param>
/// <param name="groups">Optional group data.</param>
public Dataset(Dataset reference, int numTotalRow, float[] labels, float[] weights = null, int[] groups = null)
{
    IntPtr refHandle = reference?.Handle ?? IntPtr.Zero;

    var status = WrappedLightGbmInterface.DatasetCreateByReference(refHandle, numTotalRow, ref _handle);
    LightGbmInterfaceUtils.Check(status);

    SetLabel(labels);
    SetWeights(weights);
    SetGroup(groups);
}
/// <summary>
/// Creates a native booster on <paramref name="trainset"/>, optionally registering a validation set,
/// and records whether an evaluation metric is configured.
/// </summary>
/// <param name="parameters">LightGBM parameters, joined into a single string for the native call.</param>
/// <param name="trainset">Training dataset.</param>
/// <param name="validset">Optional validation dataset.</param>
public Booster(Dictionary<string, object> parameters, Dataset trainset, Dataset validset = null)
{
    var joinedParameters = LightGbmInterfaceUtils.JoinParameters(parameters);
    LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterCreate(trainset.Handle, joinedParameters, out var createdHandle));
    Handle = createdHandle;

    if (validset != null)
    {
        LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterAddValidData(Handle, validset.Handle));
        _hasValid = true;
    }

    int evalCount = 0;
    BestIteration = -1;
    LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterGetEvalCounts(Handle, ref evalCount));

    // At most one metric in ML.NET.
    Contracts.Assert(evalCount <= 1);
    if (evalCount == 1)
    {
        _hasMetric = true;
    }
}
/// <summary>
/// Writes per-example weights into the dataset's native "weight" field as 32-bit floats.
/// </summary>
/// <remarks>
/// The native call is skipped when <paramref name="weights"/> is <c>null</c>, or when no element
/// after the first differs from the first (a uniform weighting).
/// </remarks>
/// <param name="weights">Per-row weights. Length is asserted to equal the row count.</param>
public unsafe void SetWeights(float[] weights)
{
    if (weights is null)
    {
        return;
    }

    Contracts.Assert(weights.Length == GetNumRows());

    // Compare every later element against the first. Starting at index 1 means an empty or single-element array counts as uniform.
    bool hasDistinctWeight = false;
    for (int idx = 1; idx < weights.Length && !hasDistinctWeight; idx++)
    {
        hasDistinctWeight = weights[idx] != weights[0];
    }

    if (!hasDistinctWeight)
    {
        return;
    }

    fixed (float* weightPtr = weights)
    {
        var status = WrappedLightGbmInterface.DatasetSetField(_handle, "weight", (IntPtr)weightPtr, weights.Length, WrappedLightGbmInterface.CApiDType.Float32);
        LightGbmInterfaceUtils.Check(status);
    }
}
/// <summary>
/// Frees the native LightGBM dataset owned by this safe handle.
/// </summary>
/// <returns><c>true</c> when <c>DatasetFree</c> returns the LightGBM success code 0; otherwise <c>false</c>.</returns>
/// <remarks>
/// <see cref="SafeHandle.ReleaseHandle"/> can run from the critical finalizer, where an exception would
/// terminate the process. Failure is therefore reported through the return value (surfaced by the
/// runtime's releaseHandleFailed diagnostic) rather than thrown.
/// </remarks>
protected override bool ReleaseHandle()
{
    return DatasetFree(handle) == 0;
}
/// <summary>
/// Copies every field of <c>BoosterOptions</c> that carries an <c>ArgumentAttribute</c> into <paramref name="res"/>.
/// </summary>
/// <remarks>
/// Public and non-public instance fields are inspected. Each value is stored under the name from
/// <c>NameMapping</c> when one exists, otherwise under <c>LightGbmInterfaceUtils.GetOptionName(field.Name)</c>.
/// Existing entries with the same key are overwritten.
/// </remarks>
/// <param name="res">Parameter dictionary to update in place.</param>
internal void UpdateParameters(Dictionary<string, object> res)
{
    FieldInfo[] fields = BoosterOptions.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
    foreach (var field in fields)
    {
        var attribute = field.GetCustomAttribute<ArgumentAttribute>(false);
        if (attribute == null)
        {
            continue;
        }

        // Single lookup via TryGetValue rather than ContainsKey followed by the indexer.
        var name = NameMapping.TryGetValue(field.Name, out var mappedName)
            ? mappedName
            : LightGbmInterfaceUtils.GetOptionName(field.Name);
        res[name] = field.GetValue(BoosterOptions);
    }
}