private Dataset LoadTrainingData(IChannel ch, RoleMappedData trainData, out CategoricalMetaData catMetaData)
{
    // Validate the channel and the incoming training data before reading from it.
    Host.AssertValue(ch);
    ch.CheckValue(trainData, nameof(trainData));
    CheckDataValid(ch, trainData);

    // Metadata comes first: row count, labels, weights and group ids, then categorical info.
    var cursorFactory = CreateCursorFactory(trainData);
    GetMetainfo(ch, cursorFactory, out int rowCount, out float[] labelValues, out float[] weightValues, out int[] groupIds);
    catMetaData = GetCategoricalMetaData(ch, trainData, rowCount);

    bool hasCategoricalBoundaries = catMetaData.CategoricalBoudaries != null;
    GetDefaultParameters(ch, rowCount, hasCategoricalBoundaries, catMetaData.TotalCats);

    // Options is serialized only after GetDefaultParameters has run; keep this ordering,
    // since that call presumably adjusts Options (NOTE(review): confirm).
    string serializedOptions = LightGbmInterfaceUtils.JoinParameters(Options);

    Dataset dataset;
    // Only one sampling task may run at a time (shared lock) to keep peak memory usage down.
    lock (LightGbmShared.SampleLock)
    {
        CreateDatasetFromSamplingData(ch, cursorFactory, rowCount, serializedOptions, labelValues, weightValues, groupIds, catMetaData, out dataset);
    }

    // Push the rows into the dataset in batches of Args.BatchSize.
    LoadDataset(ch, cursorFactory, dataset, rowCount, Args.BatchSize, catMetaData);

    // Final parameter checks/updates against the collected labels and groups.
    CheckAndUpdateParametersBeforeTraining(ch, trainData, labelValues, groupIds);
    return dataset;
}
private protected override LightGbmRankingPredictor CreatePredictor()
{
    // A ranking predictor can only be built once training has produced an ensemble.
    Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");

    // The serialized option string is handed to the predictor alongside the trained trees.
    string serializedOptions = LightGbmInterfaceUtils.JoinParameters(Options);
    var rankingPredictor = new LightGbmRankingPredictor(Host, TrainedEnsemble, FeatureCount, serializedOptions);
    return rankingPredictor;
}
private protected override OvaPredictor CreatePredictor()
{
    // Preconditions: training finished, class count known, and trees split evenly across classes.
    Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete.");
    Host.Assert(_numClass > 1, "Must know the number of classes before creating a predictor.");
    Host.Assert(TrainedEnsemble.NumTrees % _numClass == 0, "Number of trees should be a multiple of number of classes.");

    string serializedOptions = LightGbmInterfaceUtils.JoinParameters(Options);

    // Build one binary predictor per class index, each wrapped with a fixed Platt calibrator (-0.5, 0).
    var perClassPredictors = new IPredictorProducing<float>[_tlcNumClass];
    for (int classIndex = 0; classIndex < _tlcNumClass; classIndex++)
    {
        perClassPredictors[classIndex] = new FeatureWeightsCalibratedPredictor(
            Host,
            CreateBinaryPredictor(classIndex, serializedOptions),
            new PlattCalibrator(Host, -0.5, 0));
    }

    // The "objective" option selects how per-class outputs are combined:
    // "multiclass" uses the softmax output formula, anything else uses the default OVA combination.
    string objective = (string)GetGbmParameters()["objective"];
    if (objective != "multiclass")
    {
        return OvaPredictor.Create(Host, perClassPredictors);
    }

    return OvaPredictor.Create(Host, OvaPredictor.OutputFormula.Softmax, perClassPredictors);
}
private protected override IPredictorWithFeatureWeights<float> CreatePredictor()
{
    // The binary predictor requires a trained ensemble.
    Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");

    string serializedOptions = LightGbmInterfaceUtils.JoinParameters(Options);

    // Wrap the raw binary model with a fixed Platt calibrator (-0.5, 0) so it exposes calibrated outputs.
    return new FeatureWeightsCalibratedPredictor(
        Host,
        new LightGbmBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, serializedOptions),
        new PlattCalibrator(Host, -0.5, 0));
}
/// <summary>
/// Creates a LightGBM booster over <paramref name="trainset"/> and, when supplied, registers
/// <paramref name="validset"/> as validation data.
/// </summary>
/// <param name="parameters">Booster options; joined into a single parameter string for the native call.</param>
/// <param name="trainset">Training dataset whose handle backs the booster. Must not be null.</param>
/// <param name="validset">Optional validation dataset; when non-null it is added and marks the booster as having validation data.</param>
/// <exception cref="ArgumentNullException"><paramref name="trainset"/> is null.</exception>
public Booster(Dictionary<string, string> parameters, Dataset trainset, Dataset validset = null)
{
    // Fail fast with a descriptive exception instead of a NullReferenceException on trainset.Handle below.
    if (trainset == null)
    {
        throw new ArgumentNullException(nameof(trainset));
    }

    var param = LightGbmInterfaceUtils.JoinParameters(parameters);
    var handle = IntPtr.Zero;
    LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterCreate(trainset.Handle, param, ref handle));
    Handle = handle;

    // NOTE(review): if either Check below throws, the constructor fails after the booster handle was
    // created and nothing in this block releases it — confirm whether the native handle should be freed here.
    if (validset != null)
    {
        LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterAddValidData(Handle, validset.Handle));
        _hasValid = true;
    }

    int numEval = 0;
    BestIteration = -1;
    LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterGetEvalCounts(Handle, ref numEval));

    // At most one metric in ML.NET.
    Contracts.Assert(numEval <= 1);
    if (numEval == 1)
    {
        _hasMetric = true;
    }
}