示例#1
0
        /// <summary>
        /// Trains a new model on <paramref name="data"/> and returns managed and native predictors for it.
        /// Clears <c>TrainMetrics</c> / <c>ValidMetrics</c>, disposes any previously trained <c>Booster</c>,
        /// and remembers <paramref name="data"/> in <c>Datasets</c>.
        /// </summary>
        /// <param name="data">Training data and optional validation data. Must not be null and must contain training data.</param>
        /// <param name="learningRateSchedule">
        /// Optional learning rate as a function of the zero-based iteration. When supplied, the learning rate
        /// reported back by the booster is replaced by the configured one before the parameter round-trip check.
        /// </param>
        /// <returns>The managed and native predictors built from the trained model.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="data"/> has no training data, or ranking is requested and the training/validation data has no groups.
        /// </exception>
        /// <exception cref="InvalidOperationException">
        /// Multi-class classification is requested without a class count greater than one, or the parameters
        /// reported by the trained booster differ from the parameters passed in.
        /// </exception>
        public Predictors <TOutput> Train(Datasets data
                                          , Func <int, double> learningRateSchedule = null  // optional: learning rate as a function of iteration (zero-based)
                                          )
        {
            if (data == null)
            {
                throw new ArgumentNullException(nameof(data));
            }
            // The training set is dereferenced unconditionally below (NumFeatures), so fail early with a clear message.
            if (data.Training == null)
            {
                throw new ArgumentException("Training data is required.", nameof(data));
            }

            // For multi-class, the number of labels is required.
            // Written as !(> 1) rather than <= 1 so the check stays identical to the original if NumClass is nullable.
            if (PredictionKind == PredictionKind.MultiClassClassification && !(Objective.NumClass > 1))
            {
                throw new InvalidOperationException("LightGBM requires the number of classes to be specified in the parameters for multi-class classification.");
            }

            if (PredictionKind == PredictionKind.Ranking)
            {
                if (data.Training.GetGroups() == null)
                {
                    throw new ArgumentException("Require Groups training data for ObjectiveType.LambdaRank", nameof(data));
                }
                if (data.Validation != null && data.Validation.GetGroups() == null)
                {
                    throw new ArgumentException("Require Groups validation data for ObjectiveType.LambdaRank", nameof(data));
                }
            }

            // Reset state left over from any previous training run.
            TrainMetrics.Clear();
            ValidMetrics.Clear();
            Booster?.Dispose();
            Booster = null;

            Datasets = data;

            var args = GetParameters(data);

            Booster = Train(args, data.Training, data.Validation, TrainMetrics, ValidMetrics, learningRateSchedule);

            (var model, var argsout) = Booster.GetModel();
            TrainedEnsemble          = model;
            FeatureCount             = data.Training.NumFeatures;

            // Sanity check: the parameters echoed back by the booster must match what was passed in.
            // NOTE(review): with a schedule the reported learning rate presumably reflects the scheduled
            // value, so the configured one is restored before comparing — confirm against the booster.
            if (learningRateSchedule != null)
            {
                argsout.Learning.LearningRate = args.Learning.LearningRate;
            }
            var strIn  = args.ToString();
            var strOut = argsout.ToString();

            if (strIn != strOut)
            {
                throw new InvalidOperationException($"Parameters differ:\n{strIn}\n{strOut}");
            }

            var managed = CreateManagedPredictor();
            var native  = CreateNativePredictor();

            return new Predictors <TOutput>(managed, native);
        }
示例#2
0
        /// <summary>
        /// Builds the LightGBM parameter set used for training and for command-line export.
        /// The <c>Common</c> and <c>Dataset</c> groups are taken from <paramref name="data"/>;
        /// the <c>Objective</c> and <c>Learning</c> groups come from this instance.
        /// </summary>
        /// <param name="data">Datasets supplying the common and dataset parameter groups.</param>
        /// <returns>A newly constructed <see cref="Parameters"/> instance.</returns>
        private Parameters GetParameters(Datasets data) =>
            new Parameters
            {
                Common    = data.Common,
                Dataset   = data.Dataset,
                Objective = Objective,
                Learning  = Learning,
            };
示例#3
0
        /// <summary>
        /// Generates files that can be used to run training with lightgbm.exe.
        ///  - train.conf: training parameters plus the data, valid and output_model paths
        ///  - train.bin: training data
        ///  - valid.bin: validation data (only if <paramref name="data"/> has validation data)
        /// The config directs lightgbm.exe to write its model to LightGBM_model.txt in the same directory.
        /// Existing train.bin, valid.bin and train.conf files in the destination are replaced.
        /// Command line: lightgbm.exe config=train.conf
        /// </summary>
        /// <param name="data">Datasets to export; the training set is always written.</param>
        /// <param name="destinationDir">
        /// Directory that receives the files; it is created if it does not exist.
        /// An empty string writes into the current working directory.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
        public void ToCommandLineFiles(Datasets data, string destinationDir = @"c:\temp")
        {
            if (data == null)
            {
                throw new ArgumentNullException(nameof(data));
            }

            var pms = GetParameters(data);

            var kvs = pms.ToDict();

            // Make sure the target exists, otherwise SaveBinary / StreamWriter fail on a fresh path.
            // An empty string is skipped: Path.Combine then yields paths relative to the current directory,
            // whereas Directory.CreateDirectory("") would throw.
            if (!string.IsNullOrEmpty(destinationDir))
            {
                Directory.CreateDirectory(destinationDir);
            }

            kvs.Add("output_model", Path.Combine(destinationDir, "LightGBM_model.txt"));

            var datafile = Path.Combine(destinationDir, "train.bin");

            // Remove any stale file before saving the fresh binary.
            if (File.Exists(datafile))
            {
                File.Delete(datafile);
            }
            data.Training.SaveBinary(datafile);
            kvs.Add("data", datafile);

            if (data.Validation != null)
            {
                datafile = Path.Combine(destinationDir, "valid.bin");
                if (File.Exists(datafile))
                {
                    File.Delete(datafile);
                }
                data.Validation.SaveBinary(datafile);
                kvs.Add("valid", datafile);
            }

            using (var file = new StreamWriter(Path.Combine(destinationDir, "train.conf")))
            {
                foreach (var kv in kvs)
                {
                    // lightgbm.exe parses numbers culture-independently ('.' decimal separator), so format
                    // with the invariant culture; string values are written unchanged.
                    file.WriteLine(string.Format(System.Globalization.CultureInfo.InvariantCulture, "{0} = {1}", kv.Key, kv.Value));
                }
            }
        }