/// <summary>
/// Fits <paramref name="trainer"/> on the view returned by <c>MapLabels(data, cls)</c> and wraps
/// the fitted model in a <c>BinaryPredictionTransformer</c>. When <c>_args.UseProbabilities</c>
/// is set, the wrapped model is a <c>TDistPredictor</c>: either the fitted model itself, or the
/// result of <c>CalibratorUtils.GetCalibratedPredictor</c> when the fitted model is not one.
/// </summary>
/// <param name="ch">Channel handed to the calibrator.</param>
/// <param name="trainer">Binary trainer to fit.</param>
/// <param name="data">Role-mapped training data; its label column name is reused for calibration.</param>
/// <param name="cls">Class index forwarded to <c>MapLabels</c>.</param>
/// <returns>A prediction transformer built over the fitted (and possibly calibrated) model.</returns>
private ISingleFeaturePredictionTransformer<TScalarPredictor> TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls)
{
    var mappedView = MapLabels(data, cls);
    string labelColumn = data.Schema.Label.Value.Name;

    // REVIEW: In principle we could support validation sets and the like via the train context, but
    // this is currently unsupported.
    var fitted = trainer.Fit(mappedView);

    // Without probabilities there is nothing to calibrate: wrap the raw model directly.
    if (!_args.UseProbabilities)
    {
        return new BinaryPredictionTransformer<TScalarPredictor>(Host, fitted.Model, mappedView.Schema, fitted.FeatureColumn);
    }

    var distPredictor = fitted.Model as TDistPredictor;

    // REVIEW: restoring the RoleMappedData, as much as we can.
    // not having the weight column on the data passed to the TrainCalibrator should be addressed.
    var calibrationData = new RoleMappedData(mappedView, label: labelColumn, feature: fitted.FeatureColumn);

    if (distPredictor == null)
    {
        // The fitted model is not already a TDistPredictor, so ask the calibrator utilities for one.
        distPredictor = CalibratorUtils.GetCalibratedPredictor(
            Host, ch, Calibrator, fitted.Model, calibrationData, Args.MaxCalibrationExamples) as TDistPredictor;
    }

    Host.Check(distPredictor != null, "Calibrated predictor does not implement the expected interface");
    return new BinaryPredictionTransformer<TScalarPredictor>(Host, distPredictor, calibrationData.Data.Schema, fitted.FeatureColumn);
}
/// <summary>
/// Trains the single embedded predictor on a view in which the label column has been
/// concatenated with the original feature column, then builds the final predictor.
/// When <c>_args.reclassicationPredictor</c> is set, an intermediate final predictor is
/// created and passed to <c>TrainReclassificationPredictor</c> before the final one is built.
/// </summary>
/// <param name="ch">Channel used for informational messages (and errors in DEBUG builds).</param>
/// <param name="trainer">Trainer of the embedded predictor.</param>
/// <param name="data">Role-mapped training data.</param>
/// <param name="count">
/// Forwarded to the label mapping and to <c>CreateFinalPredictor</c>. For non-key labels in
/// the multi-column case, a positive value is used as the size of the label vector;
/// otherwise the size is obtained from <c>MinMaxLabelOverDataSet</c>.
/// </param>
/// <returns>The predictor returned by <c>CreateFinalPredictor</c>.</returns>
protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
{
    string dstName;
    string labName;
    var labeledView = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, true, _args);
    var mergedFeatures = labeledView.Schema.GetTempColumnName() + "NF";

    var describeArgs = new DescribeTransform.Arguments
    {
        columns = new string[] { labName, dstName },
        oneRowPerColumn = true
    };
    var described = new DescribeTransform(Host, describeArgs, labeledView);

    // Pick the view handed to the concatenation, re-typing the label column where needed.
    var labelType = data.Schema.Label.Value.Type;
    IDataView retypedView;
    if (_args.singleColumn)
    {
        if (labelType.RawKind() == DataKind.Single)
        {
            // The label is already a Single: no type replacement required.
            retypedView = described;
        }
        else
        {
            var scalarSchema = new TypeReplacementSchema(described.Schema, new[] { labName }, new[] { NumberDataViewType.Single });
            retypedView = new TypeReplacementDataView(described, scalarSchema);
            #region debug
#if (DEBUG)
            DebugChecking0(retypedView, labName, false);
#endif
            #endregion
        }
    }
    else if (labelType.IsKey())
    {
        // Key label: the vector size comes from the key count.
        ulong keyCount = labelType.AsKey().GetKeyCount();
        var keySchema = new TypeReplacementSchema(described.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, (int)keyCount) });
        retypedView = new TypeReplacementDataView(described, keySchema);
        #region debug
#if (DEBUG)
        int nb_;
        MinMaxLabelOverDataSet(labeledView, labName, out nb_);
        int count3;
        data.CheckMulticlassLabel(out count3);
        if ((ulong)count3 != keyCount)
        {
            throw ch.Except("Count mismatch (KeyCount){0} != {1}", keyCount, count3);
        }
        DebugChecking0(retypedView, labName, true);
        DebugChecking0Vfloat(retypedView, labName, keyCount);
#endif
        #endregion
    }
    else
    {
        // Non-key label: trust a positive count, otherwise obtain the size from the data.
        int vectorSize = count;
        if (count <= 0)
        {
            MinMaxLabelOverDataSet(labeledView, labName, out vectorSize);
        }
        var vectorSchema = new TypeReplacementSchema(described.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, vectorSize) });
        retypedView = new TypeReplacementDataView(described, vectorSchema);
        #region debug
#if (DEBUG)
        DebugChecking0(retypedView, labName, true);
#endif
        #endregion
    }

    ch.Info("Merging column label '{0}' with features '{1}'", labName, data.Schema.Feature.Value.Name);
    var concatSpec = string.Format("Concat{{col={0}:{1},{2}}}", mergedFeatures, data.Schema.Feature.Value.Name, labName);
    IDataView concatenated = ComponentCreation.CreateTransform(Host, concatSpec, retypedView);

    // Keep every role except Label/Feature, which are rebound to the new columns.
    var remappedRoles = data.Schema.GetColumnRoleNames()
        .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value
                      && kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
        .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(mergedFeatures))
        .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName));
    var trainerInput = new RoleMappedData(concatenated, remappedRoles);

    var inputFeature = trainerInput.Schema.Feature.Value;
    ch.Info("New Features: {0}:{1}", inputFeature.Name, inputFeature.Type);
    var inputLabel = trainerInput.Schema.Label.Value;
    ch.Info("New Label: {0}:{1}", inputLabel.Name, inputLabel.Type);

    // We train the unique binary classifier.
    var trainedPredictor = trainer.Train(trainerInput);
    var predictors = new TScalarPredictor[] { trainedPredictor };

    // We train the reclassification classifier.
    if (_args.reclassicationPredictor != null)
    {
        var intermediate = CreateFinalPredictor(ch, data, labeledView, count, _args, predictors, null);
        var reclassComponent = ScikitSubComponent<ITrainer, SignatureTrainer>.AsSubComponent(_args.reclassicationPredictor);
        // NOTE(review): presumably this call populates _reclassPredictor used below; confirm.
        TrainReclassificationPredictor(data, intermediate, reclassComponent);
    }

    return CreateFinalPredictor(ch, data, labeledView, count, _args, predictors, _reclassPredictor);
}
/// <summary>
/// Fits <paramref name="trainer"/> on the view returned by <c>MapLabels(data, cls1, cls2)</c>
/// and wraps a <c>TDistPredictor</c> in a <c>BinaryPredictionTransformer</c>. The fitted model is
/// used directly when it already is a <c>TDistPredictor</c>; otherwise it is passed through
/// <c>CalibratorUtils.GetCalibratedPredictor</c>.
/// </summary>
/// <param name="ch">Channel handed to the calibrator.</param>
/// <param name="trainer">Binary trainer to fit.</param>
/// <param name="data">Role-mapped training data; its label column name is reused for calibration.</param>
/// <param name="cls1">First class index forwarded to <c>MapLabels</c>.</param>
/// <param name="cls2">Second class index forwarded to <c>MapLabels</c>.</param>
/// <returns>A prediction transformer built over the (possibly calibrated) model.</returns>
/// <remarks>
/// Fails through <c>Host.Check</c> when neither the fitted model nor the calibrated predictor
/// is a <c>TDistPredictor</c>.
/// </remarks>
private ISingleFeaturePredictionTransformer<TDistPredictor> TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2)
{
    // this should not be necessary when the legacy constructor doesn't exist, and the label column is not an optional parameter on the
    // MetaMulticlassTrainer constructor.
    string trainerLabel = data.Schema.Label.Value.Name;

    var view = MapLabels(data, cls1, cls2);
    var transformer = trainer.Fit(view);

    // the validations in the calibrator check for the feature column, in the RoleMappedData
    var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumn);

    var calibratedModel = transformer.Model as TDistPredictor;
    if (calibratedModel == null)
    {
        calibratedModel = CalibratorUtils.GetCalibratedPredictor(
            Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor;
    }

    // The 'as' cast above yields null when the calibrated predictor is not a TDistPredictor;
    // report that explicitly instead of handing a null model to the transformer.
    Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface");

    return new BinaryPredictionTransformer<TDistPredictor>(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumn);
}
/// <summary>
/// Trains the embedded predictor.
/// </summary>
/// <param name="ch">Channel made available to the implementation. NOTE(review): presumably used for logging and error reporting; confirm against implementations.</param>
/// <param name="trainer">Trainer of the embedded predictor.</param>
/// <param name="data">Role-mapped data to train on.</param>
/// <param name="count">NOTE(review): appears to be the number of classes, with a non-positive value meaning it must be obtained from the data; confirm against callers.</param>
/// <returns>The trained predictor, typed as <c>TVectorPredictor</c>.</returns>
protected abstract TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count);