Example 1
        /// <summary>
        /// Initialize the Booster.
        /// </summary>
        /// <param name="parameters">Parameters for boosters. See <see cref="XGBoostArguments"/>.</param>
        /// <param name="data">training data<see cref="DMatrix"/></param>
        /// <param name="continuousTraining">Start from a trained model</param>
        public Booster(Dictionary <string, string> parameters, DMatrix data, Booster continuousTraining)
        {
            _featureNames = null;
            _featureTypes = null;
            _numFeatures  = (int)data.GetNumCols();
            Contracts.Assert(_numFeatures > 0);

            _handle = IntPtr.Zero;
            WrappedXGBoostInterface.Check(WrappedXGBoostInterface.XGBoosterCreate(new[] { data.Handle }, 1, ref _handle));
            if (continuousTraining != null)
            {
                // There should be a better way than serializing and then reloading the model.
                var saved = continuousTraining.SaveRaw();
                unsafe
                {
                    fixed (byte* buf = saved)
                        WrappedXGBoostInterface.Check(WrappedXGBoostInterface.XGBoosterLoadModelFromBuffer(_handle, buf, (uint)saved.Length));
                }
            }

            if (parameters != null)
            {
                SetParam(parameters);
            }
            WrappedXGBoostInterface.Check(WrappedXGBoostInterface.XGBoosterLazyInit(_handle));
        }
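
The constructor only needs a parameter dictionary, a training DMatrix and, optionally, a previously trained Booster to continue from. Below is a minimal usage sketch, not taken from the source: denseFeatures, numRows, numCols and labels are placeholders, the DMatrix overload is hypothetical, and the parameter keys are standard XGBoost names.

        // Usage sketch (assumptions noted above); parameter keys are standard XGBoost names.
        var parameters = new Dictionary<string, string>
        {
            { "objective", "binary:logistic" },
            { "max_depth", "4" },
            { "eta", "0.1" }
        };
        var dtrain = new DMatrix(denseFeatures, numRows, numCols, labels); // hypothetical dense overload
        var booster = new Booster(parameters, dtrain, continuousTraining: null);
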
        protected override void SaveCore(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();

            // *** Binary format ***
            // <base>
            // int (_booster.NumFeatures)
            // int (_numFeaturesML) (version >= 0x00010003)

            base.SaveCore(ctx);
            ctx.SaveBinaryStream("xgboost.model", bw => bw.Write(_booster.SaveRaw()));
            ctx.Writer.Write(_booster.NumFeatures);
            ctx.Writer.Write(_numFeaturesML);
        }
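
SaveCore stores the raw XGBoost model in a separate binary stream named "xgboost.model" and then writes the two feature counts to the main stream. The snippet below is a sketch of the matching load path, assuming ModelLoadContext exposes TryLoadBinaryStream and a Reader mirroring the ModelSaveContext calls above; the method and variable names are illustrative.

        // Load-path sketch (assumption: ModelLoadContext.TryLoadBinaryStream and ctx.Reader
        // mirror the save calls above). Requires System.IO for MemoryStream.
        private void LoadCore(ModelLoadContext ctx)
        {
            byte[] rawModel = null;
            ctx.TryLoadBinaryStream("xgboost.model", br =>
            {
                using (var ms = new MemoryStream())
                {
                    br.BaseStream.CopyTo(ms);
                    rawModel = ms.ToArray();
                }
            });
            int numFeaturesXGBoost = ctx.Reader.ReadInt32();
            int numFeaturesML = ctx.Reader.ReadInt32(); // only written when version >= 0x00010003
        }
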
        private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, TPredictor predictor)
        {
            // Verifications.
            _host.AssertValue(ch);
            ch.CheckValue(data, nameof(data));

            ValidateTrainInput(ch, data);

            var featureColumns = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);

            ch.Check(featureColumns.Count == 1, "Only one vector of features is allowed.");

            // Data dimension.
            int fi      = data.Schema.Feature.Index;
            var colType = data.Schema.Schema.GetColumnType(fi);

            ch.Assert(colType.IsVector, "Feature must be a vector.");
            ch.Assert(colType.VectorSize > 0, "Feature dimension must be known.");
            int       nbDim  = colType.VectorSize;
            IDataView view   = data.Data;
            long      nbRows = DataViewUtils.ComputeRowCount(view);

            Float[] labels;
            uint[]  groupCount;
            DMatrix dtrain;
            // REVIEW xadupre: this could be avoided by using the XGDMatrixCreateFromDataIter method from the XGBoost API.
            // XGBoost removes NaN values from a dense matrix and stores it in sparse format anyway.
            bool isDense = DetectDensity(data);
            var  dt      = DateTime.Now;

            if (isDense)
            {
                dtrain = FillDenseMatrix(ch, nbDim, nbRows, data, out labels, out groupCount);
                ch.Info("Dense matrix created with nbFeatures={0} and nbRows={1} in {2}.", nbDim, nbRows, DateTime.Now - dt);
            }
            else
            {
                dtrain = FillSparseMatrix(ch, nbDim, nbRows, data, out labels, out groupCount);
                ch.Info("Sparse matrix created with nbFeatures={0} and nbRows={1} in {2}.", nbDim, nbRows, DateTime.Now - dt);
            }

            // Some options are filled based on the data.
            var options = _args.ToDict(_host);

            UpdateXGBoostOptions(ch, options, labels, groupCount);

            // For multi class, the number of labels is required.
            ch.Assert(PredictionKind != PredictionKind.MultiClassClassification || options.ContainsKey("num_class"),
                      "XGBoost requires the number of classes to be specified in the parameters.");

            ch.Info("XGBoost objective={0}", options["objective"]);

            int     numTrees;
            Booster res = WrappedXGBoostTraining.Train(ch, pch, out numTrees, options, dtrain,
                                                       numBoostRound: _args.numBoostRound,
                                                       obj: null, verboseEval: _args.verboseEval,
                                                       xgbModel: predictor == null ? null : predictor.GetBooster(),
                                                       saveBinaryDMatrix: _args.saveXGBoostDMatrixAsBinary);

            int nbTrees = res.GetNumTrees();

            ch.Info("Training is complete. Number of added trees={0}, total={1}.", numTrees, nbTrees);

            _model             = res.SaveRaw();
            _nbFeaturesXGboost = (int)dtrain.GetNumCols();
            _nbFeaturesML      = nbDim;
        }
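
TrainCore asserts that multi-class training specifies num_class before calling WrappedXGBoostTraining.Train, because XGBoost itself rejects multi-class objectives without it. An illustrative option set for a three-class problem (standard XGBoost parameter names, example values only):

        // Example only: an option dictionary that satisfies the multi-class assert in TrainCore.
        var options = new Dictionary<string, string>
        {
            { "objective", "multi:softprob" }, // multi-class objective with per-class probabilities
            { "num_class", "3" },              // required by XGBoost for multi-class objectives
            { "max_depth", "6" },
            { "eta", "0.3" }
        };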