Beispiel #1
0
        /// <summary>
        /// Fits <paramref name="trainer"/> on the view returned by <c>MapLabels(data, cls)</c> and wraps
        /// the fitted model in a <c>BinaryPredictionTransformer</c>.
        /// </summary>
        /// <param name="ch">Channel forwarded to the calibrator training routine.</param>
        /// <param name="trainer">The scalar trainer to fit on the remapped view.</param>
        /// <param name="data">The role-mapped training data; its label column name is reused for calibration.</param>
        /// <param name="cls">The class index handed to <c>MapLabels</c>.</param>
        /// <returns>
        /// A transformer around the raw fitted model, or, when <c>_args.UseProbabilities</c> is set,
        /// around a model that is a <c>TDistPredictor</c> (the fitted model itself if it already is one,
        /// otherwise the result of <c>CalibratorUtils.TrainCalibrator</c>).
        /// </returns>
        private ISingleFeaturePredictionTransformer <TScalarPredictor> TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls)
        {
            var mappedView = MapLabels(data, cls);
            string labelName = data.Schema.Label.Name;

            // REVIEW: In principle we could support validation sets and the like via the train context, but
            // this is currently unsupported.
            var fitted = trainer.Fit(mappedView);

            // Probabilities not requested: wrap the fitted model as-is.
            if (!_args.UseProbabilities)
            {
                return new BinaryPredictionTransformer<TScalarPredictor>(Host, fitted.Model, mappedView.Schema, fitted.FeatureColumn);
            }

            // REVIEW: restoring the RoleMappedData, as much as we can.
            // not having the weight column on the data passed to the TrainCalibrator should be addressed.
            var calibrationData = new RoleMappedData(mappedView, label: labelName, feature: fitted.FeatureColumn);

            // Use the fitted model directly when it already is a TDistPredictor; otherwise train a calibrator for it.
            var distModel = (fitted.Model as TDistPredictor)
                ?? (CalibratorUtils.TrainCalibrator(Host, ch, Calibrator, Args.MaxCalibrationExamples, fitted.Model, calibrationData) as TDistPredictor);

            Host.Check(distModel != null, "Calibrated predictor does not implement the expected interface");
            return new BinaryPredictionTransformer<TScalarPredictor>(Host, distModel, calibrationData.Data.Schema, fitted.FeatureColumn);
        }
        /// <summary>
        /// Fits <paramref name="trainer"/> on the view returned by <c>MapLabels(data, cls1, cls2)</c> and wraps
        /// a <c>TDistPredictor</c> model in a <c>BinaryPredictionTransformer</c>.
        /// </summary>
        /// <param name="ch">Channel forwarded to the calibration routine.</param>
        /// <param name="trainer">The scalar trainer to fit on the remapped view.</param>
        /// <param name="data">The role-mapped training data; its label column name is reused for calibration.</param>
        /// <param name="cls1">First class index handed to <c>MapLabels</c>.</param>
        /// <param name="cls2">Second class index handed to <c>MapLabels</c>.</param>
        /// <returns>
        /// A transformer around the fitted model if it already is a <c>TDistPredictor</c>, otherwise around
        /// the result of <c>CalibratorUtils.GetCalibratedPredictor</c>.
        /// </returns>
        private ISingleFeaturePredictionTransformer<TDistPredictor> TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2)
        {
            // this should not be necessary when the legacy constructor doesn't exist, and the label column is not an optional parameter on the
            // MetaMulticlassTrainer constructor.
            string trainerLabel = data.Schema.Label.Value.Name;

            var view = MapLabels(data, cls1, cls2);
            var transformer = trainer.Fit(view);

            // the validations in the calibrator check for the feature column, in the RoleMappedData
            var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName);

            var calibratedModel = transformer.Model as TDistPredictor;
            if (calibratedModel == null)
            {
                calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor;
            }

            // The 'as' cast above yields null when the calibrated predictor is not a TDistPredictor;
            // fail here with an explicit message instead of handing a null model to the transformer.
            Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface");

            return new BinaryPredictionTransformer<TDistPredictor>(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName);
        }