Example #1
            private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, AveragedLinearTrainer<TTransformer, TModel> parent)
                : base(ch, numFeatures, predictor, parent)
            {
                // Do the other initializations by setting the setters as if the user had set them
                // Initialize the averaged weights if needed (i.e., do what happens when Averaged is set)
                Averaged = parent.AveragedLinearTrainerOptions.Averaged;
                if (Averaged)
                {
                    if (parent.AveragedLinearTrainerOptions.AveragedTolerance > 0)
                    {
                        VBufferUtils.Densify(ref Weights);
                    }
                    Weights.CopyTo(ref TotalWeights);
                }
                else
                {
                    // It is definitely advantageous to keep weights dense if we aren't adding them
                    // to another vector with each update.
                    VBufferUtils.Densify(ref Weights);
                }
                _resetWeightsAfterXExamples = parent.AveragedLinearTrainerOptions.ResetWeightsAfterXExamples ?? 0;
                _args = parent.AveragedLinearTrainerOptions;
                _loss = parent.LossFunction;

                Gain = 1;
            }
Example #2

            protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, OnlineLinearTrainer<TTransformer, TModel> parent)
            {
                Contracts.CheckValue(ch, nameof(ch));
                ch.Check(numFeatures > 0, "Cannot train with zero features!");
                ch.AssertValueOrNull(predictor);
                ch.AssertValue(parent);
                ch.Assert(Iteration == 0);
                ch.Assert(Bias == 0);

                ParentHost = parent.Host;

                ch.Trace("{0} Initializing {1} on {2} features", DateTime.UtcNow, parent.Name, numFeatures);

                // We want a dense vector, to prevent memory creation during training
                // unless we have a lot of features.
                if (predictor != null)
                {
                    ((IHaveFeatureWeights)predictor).GetFeatureWeights(ref Weights);
                    VBufferUtils.Densify(ref Weights);
                    Bias = predictor.Bias;
                }
                else if (!string.IsNullOrWhiteSpace(parent.OnlineLinearTrainerOptions.InitialWeights))
                {
                    ch.Info("Initializing weights and bias to " + parent.OnlineLinearTrainerOptions.InitialWeights);
                    string[] weightStr = parent.OnlineLinearTrainerOptions.InitialWeights.Split(',');
                    if (weightStr.Length != numFeatures + 1)
                    {
                        throw ch.Except(
                                  "Could not initialize weights from 'initialWeights': expecting {0} values to initialize {1} weights and the intercept",
                                  numFeatures + 1, numFeatures);
                    }

                    var weightValues = new float[numFeatures];
                    for (int i = 0; i < numFeatures; i++)
                    {
                        weightValues[i] = float.Parse(weightStr[i], CultureInfo.InvariantCulture);
                    }
                    Weights = new VBuffer<float>(numFeatures, weightValues);
                    Bias    = float.Parse(weightStr[numFeatures], CultureInfo.InvariantCulture);
                }
                else if (parent.OnlineLinearTrainerOptions.InitialWeightsDiameter > 0)
                {
                    var weightValues = new float[numFeatures];
                    for (int i = 0; i < numFeatures; i++)
                    {
                        weightValues[i] = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
                    }
                    Weights = new VBuffer<float>(numFeatures, weightValues);
                    Bias    = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
                }
                else if (numFeatures <= 1000)
                {
                    Weights = VBufferUtils.CreateDense<float>(numFeatures);
                }
                else
                {
                    Weights = VBufferUtils.CreateEmpty<float>(numFeatures);
                }
                WeightsScale = 1;
            }
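
The initialization above accepts InitialWeights as a comma-separated string with one value per feature followed by the bias. A minimal sketch of supplying it (hypothetical column names and values; InitialWeights is assumed to be surfaced on AveragedPerceptronTrainer.Options via the shared online-linear options):

 using Microsoft.ML;
 using Microsoft.ML.Trainers;

 var mlContext = new MLContext();

 // Three feature columns => four values: w0, w1, w2, then the bias term.
 var trainer = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
     new AveragedPerceptronTrainer.Options
     {
         LabelColumnName   = "Label",
         FeatureColumnName = "Features",
         InitialWeights    = "0.5,-0.25,0.1,0.0"
     });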
Example #3
            public TrainState(IChannel ch, int numFeatures, LinearModelParameters predictor, LinearSvmTrainer parent)
                : base(ch, numFeatures, predictor, parent)
            {
                _batchSize         = parent.Opts.BatchSize;
                _noBias            = parent.Opts.NoBias;
                _performProjection = parent.Opts.PerformProjection;
                _lambda            = parent.Opts.Lambda;

                if (_noBias)
                {
                    Bias = 0;
                }

                if (predictor == null)
                {
                    VBufferUtils.Densify(ref Weights);
                }

                _weightsUpdate = VBufferUtils.CreateEmpty<float>(numFeatures);
            }
Example #4
 public TrainState(IChannel ch, int numFeatures, LinearModelParameters predictor, OnlineGradientDescentTrainer parent)
     : base(ch, numFeatures, predictor, parent)
 {
 }
Example #5
 private protected override TrainStateBase MakeState(IChannel ch, int numFeatures, LinearModelParameters predictor)
 {
     return new TrainState(ch, numFeatures, predictor, this);
 }
Example #6

 /// <summary>
 /// Continues the training of a <see cref="LbfgsPoissonRegressionTrainer"/> using an already trained <paramref name="linearModel"/> and returns
 /// a <see cref="RegressionPredictionTransformer{PoissonRegressionModelParameters}"/>.
 /// </summary>
 public RegressionPredictionTransformer<PoissonRegressionModelParameters> Fit(IDataView trainData, LinearModelParameters linearModel)
 => TrainTransformer(trainData, initPredictor: linearModel);
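
A sketch of warm-starting with this overload (trainDataPart1 and trainDataPart2 are assumed IDataView instances with compatible schemas):

 var mlContext = new MLContext();
 var trainer = mlContext.Regression.Trainers.LbfgsPoissonRegression();

 // Train once, then continue training from the learned linear parameters.
 var firstModel   = trainer.Fit(trainDataPart1);
 var refinedModel = trainer.Fit(trainDataPart2, firstModel.Model);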
Example #7

 public WeightsCollection(LinearModelParameters pred)
 {
     Contracts.AssertValue(pred);
     _pred = pred;
 }
Example #8
 public TrainState(IChannel ch, int numFeatures, LinearModelParameters predictor, AveragedPerceptronTrainer parent)
     : base(ch, numFeatures, predictor, parent)
 {
 }
Example #9
 /// <summary>
 /// Continues the training of a <see cref="LogisticRegression"/> using an already trained <paramref name="modelParameters"/> and returns
 /// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>.
 /// </summary>
 public BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, LinearModelParameters modelParameters)
 => TrainTransformer(trainData, initPredictor: modelParameters);
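
In the calibrated binary case the linear weights sit inside the calibrated wrapper, so a warm start plausibly reaches them through SubModel (a sketch; LbfgsLogisticRegression is the assumed catalog entry, and trainDataA/trainDataB are placeholders):

 var trainer = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression();

 var first = trainer.Fit(trainDataA);
 // first.Model is CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>;
 // its SubModel exposes the LinearModelParameters this overload expects.
 var second = trainer.Fit(trainDataB, first.Model.SubModel);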
Example #10
 private protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, LinearModelParameters predictor, int weightSetCount);
Example #11
        private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParameters predictor, int weightSetCount)
        {
            int numFeatures   = data.Schema.Feature.Value.Type.GetVectorSize();
            var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features);
            int numThreads    = 1;

            ch.CheckUserArg(numThreads > 0, nameof(_options.NumberOfThreads),
                            "The number of threads must be either null or a positive integer.");

            var             positiveInstanceWeight = _options.PositiveInstanceWeight;
            VBuffer<float>  weights = default;
            float           bias    = 0.0f;

            if (predictor != null)
            {
                ((IHaveFeatureWeights)predictor).GetFeatureWeights(ref weights);
                VBufferUtils.Densify(ref weights);
                bias = predictor.Bias;
            }
            else
            {
                weights = VBufferUtils.CreateDense<float>(numFeatures);
            }

            var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights);

            // Reference: Parasail. SymSGD.
            bool tuneLR = _options.LearningRate == null;
            var  lr     = _options.LearningRate ?? 1.0f;

            bool tuneNumLocIter = (_options.UpdateFrequency == null);
            var  numLocIter     = _options.UpdateFrequency ?? 1;

            var l2Const = _options.L2Regularization;
            var piw     = _options.PositiveInstanceWeight;

            // This is state of the learner that is shared with the native code.
            State    state         = new State();
            GCHandle stateGCHandle = default;

            try
            {
                stateGCHandle = GCHandle.Alloc(state, GCHandleType.Pinned);
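                // Pinning keeps the State object at a fixed address so the native
                // SymSGD code can safely hold a pointer to it across calls; the
                // finally block below releases the handle even if training throws.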

                state.TotalInstancesProcessed = 0;
                using (InputDataManager inputDataManager = new InputDataManager(this, cursorFactory, ch))
                {
                    bool shouldInitialize = true;
                    using (var pch = Host.StartProgressChannel("Preprocessing"))
                        inputDataManager.LoadAsMuchAsPossible();

                    int iter = 0;
                    if (inputDataManager.IsFullyLoaded)
                    {
                        ch.Info("Data fully loaded into memory.");
                    }
                    using (var pch = Host.StartProgressChannel("Training"))
                    {
                        if (inputDataManager.IsFullyLoaded)
                        {
                            pch.SetHeader(new ProgressHeader(new[] { "iterations" }),
                                          entry => entry.SetProgress(0, state.PassIteration, _options.NumberOfIterations));
                            // If fully loaded, call the SymSGDNative and do not come back until learned for all iterations.
                            Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
                                            _options.NumberOfIterations, numThreads, tuneNumLocIter, ref numLocIter, _options.Tolerance, _options.Shuffle, shouldInitialize,
                                            stateGCHandle, ch.Info);
                            shouldInitialize = false;
                        }
                        else
                        {
                            pch.SetHeader(new ProgressHeader(new[] { "iterations" }),
                                          entry => entry.SetProgress(0, iter, _options.NumberOfIterations));

                            // Since we loaded the data in batches, multiple passes over the loaded data are feasible.
                            int numPassesForABatch = inputDataManager.Count / 10000;
                            while (iter < _options.NumberOfIterations)
                            {
                                // We want to train on the final passes thoroughly (without learning on the same batch multiple times).
                                // This is for fine-tuning the AUC. Experimentally, we found that 1 or 2 passes is enough.
                                int numFinalPassesToTrainThoroughly = 2;
                                // We also do not want to learn for more passes than what the user asked
                                int numPassesForThisBatch = Math.Min(numPassesForABatch, _options.NumberOfIterations - iter - numFinalPassesToTrainThoroughly);
                                // If all of this leaves us with 0 passes, then set numPassesForThisBatch to 1
                                numPassesForThisBatch = Math.Max(1, numPassesForThisBatch);
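                                // Illustrative arithmetic: with 50,000 rows loaded, numPassesForABatch = 5;
                                // with NumberOfIterations = 6 and iter = 0 this gives min(5, 6 - 0 - 2) = 4
                                // passes now, reserving the last two iterations for single thorough passes.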
                                state.PassIteration   = iter;
                                Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
                                                numPassesForThisBatch, numThreads, tuneNumLocIter, ref numLocIter, _options.Tolerance, _options.Shuffle, shouldInitialize,
                                                stateGCHandle, ch.Info);
                                shouldInitialize = false;

                                // Check if we are done with going through the data
                                if (inputDataManager.FinishedTheLoad)
                                {
                                    iter += numPassesForThisBatch;
                                    // Check if more passes are left
                                    if (iter < _options.NumberOfIterations)
                                    {
                                        inputDataManager.RestartLoading(_options.Shuffle, Host);
                                    }
                                }

                                // If more passes are left, load as much as possible
                                if (iter < _options.NumberOfIterations)
                                {
                                    inputDataManager.LoadAsMuchAsPossible();
                                }
                            }
                        }

                        // Maps back the dense features that are mislocated
                        if (numThreads > 1)
                        {
                            Native.MapBackWeightVector(weightsEditor.Values, stateGCHandle);
                        }
                        Native.DeallocateSequentially(stateGCHandle);
                    }
                }
            }
            finally
            {
                if (stateGCHandle.IsAllocated)
                {
                    stateGCHandle.Free();
                }
            }
            return CreatePredictor(weights, bias);
        }
Example #12
 /// <summary>
 /// Continues the training of a <see cref="SymbolicStochasticGradientDescentClassificationTrainer"/> using an already trained <paramref name="modelParameters"/> and returns
 /// a <see cref="BinaryPredictionTransformer"/>.
 /// </summary>
 public BinaryPredictionTransformer<TPredictor> Fit(IDataView trainData, LinearModelParameters modelParameters)
 => TrainTransformer(trainData, initPredictor: modelParameters);