public void TestEstimatorSymSgdInitPredictor()
        {
            // Featurize once; every trainer below consumes the same transformed view.
            var (pipeline, rawData) = GetBinaryClassificationPipeline();
            var featurized = pipeline.Fit(rawData).Transform(rawData);

            // Linear model whose predictor is used to seed the first SymSGD run.
            var linearArgs = new LinearClassificationTrainer.Arguments();
            var seedModel = new LinearClassificationTrainer(Env, linearArgs, "Features", "Label").Fit(featurized);
            var seedScored = seedModel.Transform(featurized);

            // SymSGD started from the seed predictor.
            var seededSymSgd = new SymSgdClassificationTrainer(Env, "Features", "Label")
                .Train(featurized, initialPredictor: seedModel.Model);
            var seededScored = seededSymSgd.Transform(featurized);

            // SymSGD trained without any initial predictor.
            var plainSymSgd = new SymSgdClassificationTrainer(Env, "Features", "Label").Train(featurized);
            var plainScored = plainSymSgd.Transform(featurized);

            // Only the leading scores of each output are compared.
            const int sampleSize = 10;
            var seedScores = seedScored.GetColumn<float>(Env, "Score").Take(sampleSize).ToArray();
            var seededScores = seededScored.GetColumn<float>(Env, "Score").Take(sampleSize).ToArray();
            var plainScores = plainScored.GetColumn<float>(Env, "Score").Take(sampleSize).ToArray();

            // True as soon as the two arrays disagree at some index below sampleSize.
            bool AnyScoreDiffers(float[] left, float[] right)
            {
                return Enumerable.Range(0, sampleSize).Any(i => left[i] != right[i]);
            }

            // Evaluate all three comparisons up front so none is skipped by short-circuiting.
            bool seedVsSeeded = AnyScoreDiffers(seedScores, seededScores);
            bool seededVsPlain = AnyScoreDiffers(seededScores, plainScores);
            bool seedVsPlain = AnyScoreDiffers(seedScores, plainScores);

            // Each of the three models is expected to score differently from the other two.
            // NOTE(review): if Contracts.Assert is compiled only in debug builds, this check
            // is a no-op in release test runs -- confirm.
            Contracts.Assert(seedVsSeeded && seededVsPlain && seedVsPlain);
            Done();
        }
        /// <summary>
        /// Predict a target using a linear binary classification model trained with the SDCA trainer and a
        /// caller-supplied loss. Because not every loss function is guaranteed to produce naturally calibrated
        /// outputs, this overload does not emit a calibrated probability column.
        /// </summary>
        /// <param name="ctx">The binary classification context trainer object.</param>
        /// <param name="label">The label, or dependent variable. Must not be null.</param>
        /// <param name="features">The features, or independent variables. Must not be null.</param>
        /// <param name="loss">The custom loss. Must not be null.</param>
        /// <param name="weights">The optional example weights.</param>
        /// <param name="l2Const">The L2 regularization hyperparameter. Must not be negative, if specified.</param>
        /// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.
        /// Must not be negative, if specified.</param>
        /// <param name="maxIterations">The maximum number of passes to perform over the data. Must be positive, if specified.</param>
        /// <param name="onFit">A delegate that is called every time the
        /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
        /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. The delegate receives
        /// the linear model that was trained; if the fitted model is a <see cref="ParameterMixingCalibratedPredictor"/>, its
        /// sub-predictor is what gets passed. Note that this action cannot change the result in any way; it is only a way for
        /// the caller to be informed about what was learnt.</param>
        /// <returns>The set of output columns including in order the predicted binary classification score (which will range
        /// from negative to positive infinity), and the predicted label.</returns>
        /// <seealso cref="Sdca(BinaryClassificationContext.BinaryClassificationTrainers, Scalar{bool}, Vector{float}, Scalar{float}, float?, float?, int?, Action{LinearBinaryPredictor, ParameterMixingCalibratedPredictor})"/>
        public static (Scalar<float> score, Scalar<bool> predictedLabel) Sdca(
            this BinaryClassificationContext.BinaryClassificationTrainers ctx,
            Scalar<bool> label, Vector<float> features,
            ISupportSdcaClassificationLoss loss,
            Scalar<float> weights = null,
            float? l2Const = null,
            float? l1Threshold = null,
            int? maxIterations = null,
            Action<LinearBinaryPredictor> onFit = null)
        {
            // Argument validation. The negated comparisons let unspecified (null) hyperparameters pass.
            Contracts.CheckValue(label, nameof(label));
            Contracts.CheckValue(features, nameof(features));
            Contracts.CheckValue(loss, nameof(loss));
            Contracts.CheckValueOrNull(weights);
            Contracts.CheckParam(!(l2Const < 0), nameof(l2Const), "Must not be negative, if specified.");
            Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified.");
            Contracts.CheckParam(!(maxIterations < 1), nameof(maxIterations), "Must be positive if specified");
            Contracts.CheckValueOrNull(onFit);

            var trainerArgs = new LinearClassificationTrainer.Arguments
            {
                L2Const = l2Const,
                L1Threshold = l1Threshold,
                MaxIterations = maxIterations,
                LossFunction = new TrivialSdcaClassificationLossFactory(loss)
            };

            // True only when the supplied loss is LogLoss; forwarded to the reconciler as its last argument.
            bool lossIsLogLoss = loss is LogLoss;

            var reconciler = new TrainerEstimatorReconciler.BinaryClassifierNoCalibration(
                (env, labelName, featuresName, weightsName) =>
                {
                    var sdca = new LinearClassificationTrainer(env, trainerArgs, featuresName, labelName, weightsName);
                    if (onFit == null)
                    {
                        return sdca;
                    }

                    return sdca.WithOnFitDelegate(fitted =>
                    {
                        // The fitted model may or may not be wrapped in a calibrator; either way
                        // the caller is handed the underlying linear predictor.
                        var model = fitted.Model;
                        var linear = (model is ParameterMixingCalibratedPredictor calibrated)
                            ? (LinearBinaryPredictor)calibrated.SubPredictor
                            : (LinearBinaryPredictor)model;
                        onFit(linear);
                    });
                },
                label, features, weights, lossIsLogLoss);

            return reconciler.Output;
        }
        /// <summary>
        /// Predict a target using a linear binary classification model trained with the SDCA trainer, and log-loss.
        /// </summary>
        /// <param name="ctx">The binary classification context trainer object.</param>
        /// <param name="label">The label, or dependent variable. Must not be null.</param>
        /// <param name="features">The features, or independent variables. Must not be null.</param>
        /// <param name="weights">The optional example weights.</param>
        /// <param name="l2Const">The L2 regularization hyperparameter. Must not be negative, if specified.</param>
        /// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.
        /// Must not be negative, if specified.</param>
        /// <param name="maxIterations">The maximum number of passes to perform over the data. Must be positive, if specified.</param>
        /// <param name="onFit">A delegate that is called every time the
        /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
        /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive
        /// the linear model that was trained, as well as the calibrator on top of that model. Note that this action cannot change the
        /// result in any way; it is only a way for the caller to be informed about what was learnt.</param>
        /// <returns>The set of output columns including in order the predicted binary classification score (which will range
        /// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.</returns>
        public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) Sdca(
            this BinaryClassificationContext.BinaryClassificationTrainers ctx,
            Scalar<bool> label, Vector<float> features, Scalar<float> weights = null,
            float? l2Const = null,
            float? l1Threshold = null,
            int? maxIterations = null,
            Action<LinearBinaryPredictor, ParameterMixingCalibratedPredictor> onFit = null)
        {
            // Argument validation. The negated comparisons let unspecified (null) hyperparameters pass.
            Contracts.CheckValue(label, nameof(label));
            Contracts.CheckValue(features, nameof(features));
            Contracts.CheckValueOrNull(weights);
            Contracts.CheckParam(!(l2Const < 0), nameof(l2Const), "Must not be negative, if specified.");
            Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified.");
            Contracts.CheckParam(!(maxIterations < 1), nameof(maxIterations), "Must be positive if specified");
            Contracts.CheckValueOrNull(onFit);

            // No LossFunction is set here, so the trainer arguments keep whatever loss they default to.
            var trainerArgs = new LinearClassificationTrainer.Arguments
            {
                L2Const = l2Const,
                L1Threshold = l1Threshold,
                MaxIterations = maxIterations,
            };

            var reconciler = new TrainerEstimatorReconciler.BinaryClassifier(
                (env, labelName, featuresName, weightsName) =>
                {
                    var sdca = new LinearClassificationTrainer(env, trainerArgs, featuresName, labelName, weightsName);
                    if (onFit == null)
                    {
                        return sdca;
                    }

                    return sdca.WithOnFitDelegate(fitted =>
                    {
                        // Under the default log-loss we assume a calibrated predictor, so the
                        // cast is unconditional and will throw if that assumption is violated.
                        var calibrated = (ParameterMixingCalibratedPredictor)fitted.Model;
                        onFit((LinearBinaryPredictor)calibrated.SubPredictor, calibrated);
                    });
                },
                label, features, weights);

            return reconciler.Output;
        }
Exemplo n.º 4
0
 /// <summary>
 /// Creates the trainer: forwards the environment, a freshly constructed <c>TrainerInfo</c>, and the
 /// feature and label column names to the base constructor, then keeps <paramref name="args"/> in <c>_args</c>.
 /// </summary>
 /// <param name="env">The host environment handed to the base class.</param>
 /// <param name="args">The linear classification trainer arguments stored on this instance.</param>
 /// <param name="featureCol">The feature column name handed to the base class.</param>
 /// <param name="labelCol">The label column name handed to the base class.</param>
 public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args, string featureCol, string labelCol)
     : base(env, new TrainerInfo(), featureCol, labelCol)
     => _args = args;