/// <summary>
/// Builds an XGBoost <c>DMatrix</c> from the training data, trains a booster on it and
/// stores the resulting raw model together with the feature counts in this instance.
/// </summary>
/// <param name="ch">Channel used for the validation checks and the informational messages.</param>
/// <param name="pch">Progress channel forwarded to <c>WrappedXGBoostTraining.Train</c>.</param>
/// <param name="data">Training data; it must expose exactly one feature column.</param>
/// <param name="predictor">
/// Optional existing predictor. When it is not null its booster is handed to the training
/// routine as <c>xgbModel</c>; when it is null, null is passed instead.
/// </param>
private void TrainCore(IChannel ch, IProgressChannel pch, RoleMappedData data, TPredictor predictor)
{
    // Verifications.
    _host.AssertValue(ch);
    ch.CheckValue(data, nameof(data));
    ValidateTrainInput(ch, data);

    var featureColumns = data.Schema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
    ch.Check(featureColumns.Count == 1, "Only one vector of features is allowed.");

    // Data dimension.
    int fi = data.Schema.Feature.Index;
    var colType = data.Schema.Schema.GetColumnType(fi);
    ch.Assert(colType.IsVector, "Feature must be a vector.");
    ch.Assert(colType.VectorSize > 0, "Feature dimension must be known.");
    int nbDim = colType.VectorSize;

    long nbRows = DataViewUtils.ComputeRowCount(data.Data);

    Float[] labels;
    uint[] groupCount;
    DMatrix dtrain;

    // REVIEW xadupre: this can be avoided by using method XGDMatrixCreateFromDataIter from the XGBoost API.
    // XGBoost removes NaN values from a dense matrix and stores it in sparse format anyway.
    bool isDense = DetectDensity(data);

    // A Stopwatch is monotonic, so the reported duration cannot be distorted by
    // system clock adjustments the way a difference of DateTime.Now values can.
    var timer = System.Diagnostics.Stopwatch.StartNew();
    if (isDense)
    {
        dtrain = FillDenseMatrix(ch, nbDim, nbRows, data, out labels, out groupCount);
        ch.Info("Dense matrix created with nbFeatures={0} and nbRows={1} in {2}.", nbDim, nbRows, timer.Elapsed);
    }
    else
    {
        dtrain = FillSparseMatrix(ch, nbDim, nbRows, data, out labels, out groupCount);
        ch.Info("Sparse matrix created with nbFeatures={0} and nbRows={1} in {2}.", nbDim, nbRows, timer.Elapsed);
    }

    // Some options are filled based on the data.
    var options = _args.ToDict(_host);
    UpdateXGBoostOptions(ch, options, labels, groupCount);

    // For multi class, the number of labels is required.
    ch.Assert(PredictionKind != PredictionKind.MultiClassClassification || options.ContainsKey("num_class"),
        "XGBoost requires the number of classes to be specified in the parameters.");

    ch.Info("XGBoost objective={0}", options["objective"]);

    // numTrees receives the number of trees added by this call; the booster itself
    // reports the total, which is what gets logged as "total" below.
    int numTrees;
    Booster res = WrappedXGBoostTraining.Train(ch, pch, out numTrees, options, dtrain,
        numBoostRound: _args.numBoostRound,
        obj: null,
        verboseEval: _args.verboseEval,
        xgbModel: predictor == null ? null : predictor.GetBooster(),
        saveBinaryDMatrix: _args.saveXGBoostDMatrixAsBinary);

    int nbTrees = res.GetNumTrees();
    ch.Info("Training is complete. Number of added trees={0}, total={1}.", numTrees, nbTrees);

    // Keep the serialized model and both feature counts: the one seen by XGBoost
    // (columns of the DMatrix) and the one declared by the feature column type.
    _model = res.SaveRaw();
    _nbFeaturesXGboost = (int)dtrain.GetNumCols();
    _nbFeaturesML = nbDim;
}
/// <summary>
/// Gets the number of trees reported by the wrapped booster.
/// </summary>
/// <returns>The value returned by <c>_booster.GetNumTrees()</c>.</returns>
public int GetNumTrees() => _booster.GetNumTrees();