示例#1
0
        /// <summary>
        /// Builds the training <see cref="Dataset"/>: validates the input, gathers meta information
        /// (row count, labels, weights, groups) and categorical metadata, creates the dataset from
        /// sampled data, and then pushes the rows into it.
        /// </summary>
        /// <param name="ch">Channel used for the argument check and handed to every helper routine.</param>
        /// <param name="trainData">Training data; checked for null before use.</param>
        /// <param name="catMetaData">Receives the categorical metadata computed for <paramref name="trainData"/>.</param>
        /// <returns>The dataset produced by the sampling and loading steps.</returns>
        private Dataset LoadTrainingData(IChannel ch, RoleMappedData trainData, out CategoricalMetaData catMetaData)
        {
            // Argument and data validation.
            Host.AssertValue(ch);
            ch.CheckValue(trainData, nameof(trainData));
            CheckDataValid(ch, trainData);

            // Meta information is collected before any dataset is created.
            var cursorFactory = CreateCursorFactory(trainData);

            int rowCount;
            float[] labelValues;
            float[] weightValues;
            int[] groupValues;
            GetMetainfo(ch, cursorFactory, out rowCount, out labelValues, out weightValues, out groupValues);

            catMetaData = GetCategoricalMetaData(ch, trainData, rowCount);
            bool hasCategoricalBoundaries = catMetaData.CategoricalBoudaries != null;
            GetDefaultParameters(ch, rowCount, hasCategoricalBoundaries, catMetaData.TotalCats);

            string joinedParameters = LightGbmInterfaceUtils.JoinParameters(Options);
            Dataset trainingSet;

            // Only one sampling task may run at a time; this keeps peak memory usage down.
            lock (LightGbmShared.SampleLock)
            {
                CreateDatasetFromSamplingData(
                    ch,
                    cursorFactory,
                    rowCount,
                    joinedParameters,
                    labelValues,
                    weightValues,
                    groupValues,
                    catMetaData,
                    out trainingSet);
            }

            // Feed the actual rows into the freshly created dataset.
            LoadDataset(ch, cursorFactory, trainingSet, rowCount, Args.BatchSize, catMetaData);

            // Final parameter checks/updates that depend on the labels and groups.
            CheckAndUpdateParametersBeforeTraining(ch, trainData, labelValues, groupValues);

            return trainingSet;
        }
        /// <summary>
        /// Creates the ranking predictor from the trained ensemble.
        /// </summary>
        /// <returns>
        /// A <see cref="LightGbmRankingPredictor"/> built from the host, the trained ensemble,
        /// the feature count and the joined option string.
        /// </returns>
        /// <remarks>Fails the host check when <c>TrainedEnsemble</c> is still null.</remarks>
        private protected override LightGbmRankingPredictor CreatePredictor()
        {
            Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");

            string joinedParameters = LightGbmInterfaceUtils.JoinParameters(Options);
            var rankingPredictor = new LightGbmRankingPredictor(Host, TrainedEnsemble, FeatureCount, joinedParameters);
            return rankingPredictor;
        }
        /// <summary>
        /// Creates the multi-class predictor by building one calibrated binary predictor per class
        /// and combining them into an <see cref="OvaPredictor"/>.
        /// </summary>
        /// <returns>
        /// An <see cref="OvaPredictor"/> over the per-class predictors. The softmax output formula
        /// is requested when the "objective" parameter equals "multiclass"; otherwise the overload
        /// without an output formula is used.
        /// </returns>
        /// <remarks>Fails the host check when <c>TrainedEnsemble</c> is still null.</remarks>
        private protected override OvaPredictor CreatePredictor()
        {
            Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete.");
            Host.Assert(_numClass > 1, "Must know the number of classes before creating a predictor.");
            Host.Assert(TrainedEnsemble.NumTrees % _numClass == 0, "Number of trees should be a multiple of number of classes.");

            string joinedParameters = LightGbmInterfaceUtils.JoinParameters(Options);

            // One binary predictor per class, each wrapped with its own Platt calibrator.
            var perClassPredictors = new IPredictorProducing<float>[_tlcNumClass];
            for (int classIndex = 0; classIndex < _tlcNumClass; classIndex++)
            {
                var binaryPredictor = CreateBinaryPredictor(classIndex, joinedParameters);
                var calibrator = new PlattCalibrator(Host, -0.5, 0);
                perClassPredictors[classIndex] = new FeatureWeightsCalibratedPredictor(Host, binaryPredictor, calibrator);
            }

            string objective = (string)GetGbmParameters()["objective"];

            // Any objective other than "multiclass" uses the overload without an explicit output formula.
            if (objective != "multiclass")
            {
                return OvaPredictor.Create(Host, perClassPredictors);
            }

            return OvaPredictor.Create(Host, OvaPredictor.OutputFormula.Softmax, perClassPredictors);
        }
示例#4
0
        /// <summary>
        /// Creates the binary predictor: the LightGBM binary model parameters wrapped together
        /// with a Platt calibrator.
        /// </summary>
        /// <returns>
        /// A <see cref="FeatureWeightsCalibratedPredictor"/> combining the binary model and a
        /// <see cref="PlattCalibrator"/> constructed with the arguments (-0.5, 0).
        /// </returns>
        /// <remarks>Fails the host check when <c>TrainedEnsemble</c> is still null.</remarks>
        private protected override IPredictorWithFeatureWeights<float> CreatePredictor()
        {
            Host.Check(TrainedEnsemble != null, "The predictor cannot be created before training is complete");

            string joinedParameters = LightGbmInterfaceUtils.JoinParameters(Options);
            var binaryModel = new LightGbmBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, joinedParameters);
            var calibrator = new PlattCalibrator(Host, -0.5, 0);

            return new FeatureWeightsCalibratedPredictor(Host, binaryModel, calibrator);
        }
        /// <summary>
        /// Creates a booster for <paramref name="trainset"/> through the wrapped LightGBM interface,
        /// optionally attaching a validation dataset.
        /// </summary>
        /// <param name="parameters">Parameters that are joined into a single string and passed to booster creation.</param>
        /// <param name="trainset">Dataset whose handle the booster is created from.</param>
        /// <param name="validset">Optional validation dataset; when non-null it is added to the booster.</param>
        public Booster(Dictionary<string, string> parameters, Dataset trainset, Dataset validset = null)
        {
            string joinedParameters = LightGbmInterfaceUtils.JoinParameters(parameters);

            IntPtr boosterHandle = IntPtr.Zero;
            int createResult = WrappedLightGbmInterface.BoosterCreate(trainset.Handle, joinedParameters, ref boosterHandle);
            LightGbmInterfaceUtils.Check(createResult);
            Handle = boosterHandle;

            if (validset != null)
            {
                int addValidResult = WrappedLightGbmInterface.BoosterAddValidData(Handle, validset.Handle);
                LightGbmInterfaceUtils.Check(addValidResult);
                _hasValid = true;
            }

            int evalCount = 0;
            BestIteration = -1;

            int evalCountResult = WrappedLightGbmInterface.BoosterGetEvalCounts(Handle, ref evalCount);
            LightGbmInterfaceUtils.Check(evalCountResult);

            // At most one metric in ML.NET.
            Contracts.Assert(evalCount <= 1);
            if (evalCount == 1)
            {
                _hasMetric = true;
            }
        }