Ejemplo n.º 1
0
        /// <summary>
        /// Predict a target using a linear regression model trained with the <see cref="OnlineGradientDescentTrainer"/> trainer.
        /// </summary>
        /// <param name="catalog">The regression catalog trainer object.</param>
        /// <param name="label">The label, or dependent variable.</param>
        /// <param name="features">The features, or independent variables.</param>
        /// <param name="weights">The optional example weights. May be <see langword="null"/>.</param>
        /// <param name="options">Advanced arguments to the algorithm. Note that its <c>LabelColumnName</c> and
        /// <c>FeatureColumnName</c> members are overwritten with the reconciled column names when the trainer is created.</param>
        /// <param name="onFit">A delegate that is called every time the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
        /// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
        /// the linear regression model parameters that were trained. Note that this action cannot change the
        /// result in any way; it is only a way for the caller to be informed about what was learnt.</param>
        /// <returns>The predicted output.</returns>
        /// <seealso cref="OnlineGradientDescentTrainer"/>
        public static Scalar<float> OnlineGradientDescent(this RegressionCatalog.RegressionTrainers catalog,
                                                          Scalar<float> label,
                                                          Vector<float> features,
                                                          Scalar<float> weights,
                                                          OnlineGradientDescentTrainer.Options options,
                                                          Action<LinearRegressionModelParameters> onFit = null)
        {
            // Validate the catalog as well, matching the options-only overload of this extension.
            Contracts.CheckValue(catalog, nameof(catalog));
            Contracts.CheckValue(label, nameof(label));
            Contracts.CheckValue(features, nameof(features));
            Contracts.CheckValueOrNull(weights);
            Contracts.CheckValue(options, nameof(options));
            Contracts.CheckValueOrNull(onFit);

            var rec = new TrainerEstimatorReconciler.Regression(
                (env, labelName, featuresName, weightsName) =>
                {
                    // The reconciler decides the actual column names, so they replace
                    // whatever the caller may have put into the options object.
                    options.LabelColumnName = labelName;
                    options.FeatureColumnName = featuresName;

                    // NOTE(review): weightsName is not forwarded to the options here, so the
                    // weights column is not applied by this method - confirm this is intended.
                    var trainer = new OnlineGradientDescentTrainer(env, options);

                    if (onFit != null)
                    {
                        return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
                    }

                    return trainer;
                },
                label, features, weights);

            return rec.Score;
        }
Ejemplo n.º 2
0
        OnlineGradientDescent(
            this SweepableRegressionTrainers trainers,
            string labelColumnName = "Label",
            string featureColumnName = "Features",
            SweepableOption<OnlineGradientDescentTrainer.Options> optionSweeper = null,
            OnlineGradientDescentTrainer.Options defaultOption = null)
        {
            // Read the context first so a null 'trainers' fails before the sweeper is touched.
            var automlContext = trainers.Context;

            // Fall back to the predefined sweepable options when the caller supplies none.
            optionSweeper = optionSweeper ?? OnlineGradientDescentTrainerSweepableOptions.Default;
            optionSweeper.SetDefaultOption(defaultOption);

            // Column names passed to the sweepable estimator as its input and output columns.
            var inputColumns = new string[] { labelColumnName, featureColumnName };
            var outputColumns = new string[] { Score };

            return automlContext.AutoML().CreateSweepableEstimator(
                (ctx, sweptOption) =>
                {
                    // Every swept option set trains against the same label/feature columns.
                    sweptOption.LabelColumnName = labelColumnName;
                    sweptOption.FeatureColumnName = featureColumnName;

                    return ctx.Regression.Trainers.OnlineGradientDescent(sweptOption);
                },
                optionSweeper,
                inputColumns,
                outputColumns,
                nameof(OnlineGradientDescentTrainer));
        }
        /// <summary>
        /// Sample: trains an <see cref="OnlineGradientDescentTrainer"/> with custom options on generated data,
        /// prints predictions for a small generated test set, and prints the regression metrics.
        /// </summary>
        public static void Example()
        {
            // The MLContext is the entry point to ML.NET operations (catalog, logging, randomness).
            // A fixed seed is used so the outputs of this example are deterministic.
            var mlContext = new MLContext(seed: 0);

            // Generate 1000 training data points and expose them as an IDataView,
            // which is the data type consumed by the ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(1000));

            // Configure the trainer.
            var options = new OnlineGradientDescentTrainer.Options
            {
                // Use a different loss function.
                LossFunction = new TweedieLoss(),
                // Give an extra gain to more recent updates.
                RecencyGain = 0.1f,
                // Turn off lazy updates.
                LazyUpdate = false,
                // Scale of the initial weights.
                InitialWeightsDiameter = 0.2f,
                // Columns to read the label and the features from.
                LabelColumnName = nameof(DataPoint.Label),
                FeatureColumnName = nameof(DataPoint.Features)
            };

            // Define the trainer and fit it on the training data.
            var model = mlContext.Regression.Trainers.OnlineGradientDescent(options).Fit(trainingData);

            // Generate the test set with a different seed so it differs from the training data,
            // then run the trained model over it.
            var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(5, seed: 123));
            var transformedTestData = model.Transform(testData);

            // Materialize the scored rows as a list of Prediction objects.
            var predictions = mlContext.Data
                .CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false)
                .ToList();

            // Show each prediction side by side with the actual Label for comparison.
            predictions.ForEach(p => Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}"));

            // Evaluate and print the overall metrics.
            // This trainer is not numerically stable. Please see issue #2425.
            PrintMetrics(mlContext.Regression.Evaluate(transformedTestData));
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Creates an <see cref="OnlineGradientDescentTrainer"/>, which predicts a target using a linear regression model,
        /// configured with the supplied advanced options.
        /// </summary>
        /// <param name="catalog">The regression catalog trainer object.</param>
        /// <param name="options">Advanced arguments to the algorithm.</param>
        /// <returns>A new <see cref="OnlineGradientDescentTrainer"/> built from the catalog's environment and <paramref name="options"/>.</returns>
        public static OnlineGradientDescentTrainer OnlineGradientDescent(
            this RegressionCatalog.RegressionTrainers catalog,
            OnlineGradientDescentTrainer.Options options)
        {
            // Both arguments are required; validate before touching the catalog.
            Contracts.CheckValue(catalog, nameof(catalog));
            Contracts.CheckValue(options, nameof(options));

            return new OnlineGradientDescentTrainer(CatalogUtils.GetEnvironment(catalog), options);
        }
Ejemplo n.º 5
0
        /// <summary>
        /// Builds a regression model from <paramref name="trainDataset"/> with an online gradient descent trainer
        /// and returns a prediction engine for it.
        /// </summary>
        /// <typeparam name="TIn">The input example type.</typeparam>
        /// <typeparam name="TOut">The prediction output type.</typeparam>
        /// <param name="trainDataset">The examples handed to <c>RegressionTrainerTemplate</c> together with the trainer.</param>
        /// <param name="options">Options used to create the online gradient descent trainer.</param>
        /// <param name="additionModelAction">Optional callback invoked with the model after the prediction engine has been created.</param>
        /// <returns>A prediction engine created from the model.</returns>
        public static PredictionEngine<TIn, TOut> OnlineGradientDescent<TIn, TOut>(
            IEnumerable<TIn> trainDataset,
            OnlineGradientDescentTrainer.Options options,
            Action<ITransformer> additionModelAction = null)
            where TIn : class, new()
            where TOut : class, new()
        {
            var mlContext = new MLContext();

            // NOTE(review): RegressionTrainerTemplate is defined elsewhere; presumably it fits
            // the trainer on the dataset and returns the fitted model - confirm.
            var trainer = mlContext.Regression.Trainers.OnlineGradientDescent(options);
            var trainedModel = mlContext.RegressionTrainerTemplate(trainDataset, trainer);

            var engine = mlContext.Model.CreatePredictionEngine<TIn, TOut>(trainedModel);

            // The callback runs only after the engine exists, matching the original ordering.
            if (additionModelAction != null)
            {
                additionModelAction(trainedModel);
            }

            return engine;
        }