/// <summary>
/// Trains a LightGBM booster on <paramref name="data"/> with this instance's
/// <c>Objective</c> and <c>Learning</c> settings, stores the resulting model on this
/// instance, and returns managed and native predictors for it.
/// </summary>
/// <param name="data">Training data, plus optional validation data.</param>
/// <param name="learningRateSchedule">
/// Optional: learning rate as a function of iteration (zero-based). Passed through to the
/// underlying training routine.
/// </param>
/// <returns>The managed and native predictors built from the trained model.</returns>
/// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="data"/> has no training set.</exception>
/// <exception cref="Exception">
/// Multi-class classification without a class count greater than one; ranking without group
/// data; or the parameters read back from the trained model differ from those requested.
/// </exception>
public Predictors<TOutput> Train(Datasets data,
                                 Func<int, double> learningRateSchedule = null // optional: learning rate as a function of iteration (zero-based)
                                )
{
    // Fail fast on bad arguments instead of surfacing a NullReferenceException later
    // (possibly only after native training has already run).
    if (data == null)
    {
        throw new ArgumentNullException(nameof(data));
    }
    if (data.Training == null)
    {
        throw new ArgumentException("Training data is required.", nameof(data));
    }

    // For multi class, the number of labels is required.
    if (PredictionKind == PredictionKind.MultiClassClassification && !(Objective.NumClass > 1))
    {
        throw new Exception("LightGBM requires the number of classes to be specified in the parameters for multi-class classification.");
    }

    // Ranking objectives need query/group information on every dataset supplied.
    if (PredictionKind == PredictionKind.Ranking)
    {
        if (data.Training.GetGroups() == null)
        {
            throw new Exception("Require Groups training data for ObjectiveType.LambdaRank");
        }
        if (data.Validation != null && data.Validation.GetGroups() == null)
        {
            throw new Exception("Require Groups validation data for ObjectiveType.LambdaRank");
        }
    }

    // Discard state left over from any previous training run.
    TrainMetrics.Clear();
    ValidMetrics.Clear();
    Booster?.Dispose();
    Booster = null;

    Datasets = data;
    var args = GetParameters(data);
    Booster = Train(args, data.Training, data.Validation, TrainMetrics, ValidMetrics, learningRateSchedule);

    (var model, var argsout) = Booster.GetModel();
    TrainedEnsemble = model;
    FeatureCount = data.Training.NumFeatures;

    // Sanity check: the parameters read back from the model must match those requested.
    // When a schedule is used, the learning rate reported by the model is not expected to
    // match the configured one, so copy the requested value across before comparing.
    if (learningRateSchedule != null)
    {
        argsout.Learning.LearningRate = args.Learning.LearningRate;
    }
    var strIn = args.ToString();
    var strOut = argsout.ToString();
    if (strIn != strOut)
    {
        throw new Exception($"Parameters differ:\n{strIn}\n{strOut}");
    }

    var managed = CreateManagedPredictor();
    var native = CreateNativePredictor();
    return new Predictors<TOutput>(managed, native);
}
/// <summary>
/// Builds the complete parameter set for a training run.
/// </summary>
/// <param name="data">Supplies the common and dataset-construction settings.</param>
/// <returns>
/// A new <c>Parameters</c> whose objective and learning sections are this instance's
/// <c>Objective</c> and <c>Learning</c> settings.
/// </returns>
private Parameters GetParameters(Datasets data) =>
    new Parameters
    {
        Common = data.Common,
        Dataset = data.Dataset,
        Objective = Objective,
        Learning = Learning,
    };
/// <summary>
/// Generates files that can be used to run training with lightgbm.exe.
/// - train.conf: contains training parameters
/// - train.bin: training data
/// - valid.bin: validation data (if provided)
/// train.conf also sets output_model to LightGBM_model.txt in the same directory.
/// Command line: lightgbm.exe config=train.conf (run from <paramref name="destinationDir"/>,
/// or pass the full path to train.conf).
/// </summary>
/// <param name="data">
/// Training data and optional validation data. Also supplies the common and dataset
/// parameters written to train.conf.
/// </param>
/// <param name="destinationDir">
/// Directory the files are written to. It is created if it does not exist. Any existing
/// train.bin / valid.bin / train.conf files in it are replaced.
/// </param>
/// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
public void ToCommandLineFiles(Datasets data, string destinationDir = @"c:\temp")
{
    if (data == null)
    {
        throw new ArgumentNullException(nameof(data));
    }

    // Both SaveBinary and StreamWriter fail if the target directory does not exist.
    Directory.CreateDirectory(destinationDir);

    var pms = GetParameters(data);
    var kvs = pms.ToDict();
    kvs.Add("output_model", Path.Combine(destinationDir, "LightGBM_model.txt"));

    kvs.Add("data", SaveFresh("train.bin", path => data.Training.SaveBinary(path)));
    if (data.Validation != null)
    {
        kvs.Add("valid", SaveFresh("valid.bin", path => data.Validation.SaveBinary(path)));
    }

    using (var file = new StreamWriter(Path.Combine(destinationDir, "train.conf")))
    {
        foreach (var kv in kvs)
        {
            file.WriteLine($"{kv.Key} = {kv.Value}");
        }
    }

    // Deletes any existing destinationDir/fileName, writes it with `save`, and returns its path.
    string SaveFresh(string fileName, Action<string> save)
    {
        var datafile = Path.Combine(destinationDir, fileName);
        File.Delete(datafile); // no-op when the file does not exist
        save(datafile);
        return datafile;
    }
}