private protected AveragedTrainStateBase(IChannel ch, int numFeatures, LinearPredictor predictor, AveragedLinearTrainer<TTransformer, TModel> parent)
    : base(ch, numFeatures, predictor, parent)
{
    // Mirror what the property setters would have done had the user configured
    // these values directly, starting with the averaged-weights bookkeeping.
    Averaged = parent.Args.Averaged;
    if (!Averaged)
    {
        // Without averaging, the weights are never accumulated into a second
        // vector, so keeping them dense is a clear win on every update.
        VBufferUtils.Densify(ref Weights);
    }
    else
    {
        // A positive tolerance requires dense weights before averaging begins.
        if (parent.Args.AveragedTolerance > 0)
            VBufferUtils.Densify(ref Weights);
        // Seed the running total of weights from the initial weight vector.
        Weights.CopyTo(ref TotalWeights);
    }

    // Zero means "never reset"; otherwise reset after the configured count.
    _resetWeightsAfterXExamples = parent.Args.ResetWeightsAfterXExamples.GetValueOrDefault();
    _args = parent.Args;
    _loss = parent.LossFunction;
    Gain = 1;
}
protected AveragedLinearTrainer(AveragedLinearArguments args, IHostEnvironment env, string name, SchemaShape.Column label)
    : base(args, env, name, label)
{
    // Validate user-supplied hyperparameters up front; the check order below
    // determines which argument error surfaces first, so it is preserved.
    Contracts.CheckUserArg(args.LearningRate > 0, nameof(args.LearningRate), UserErrorPositive);
    Contracts.CheckUserArg(!args.ResetWeightsAfterXExamples.HasValue || args.ResetWeightsAfterXExamples > 0,
        nameof(args.ResetWeightsAfterXExamples), UserErrorPositive);

    // Each update scales weights down by 2 * L2 regularization, so a value of
    // 0.5 would zero out all weights — hence the exclusive upper bound.
    Contracts.CheckUserArg(args.L2RegularizerWeight >= 0 && args.L2RegularizerWeight < 0.5,
        nameof(args.L2RegularizerWeight), "must be in range [0, 0.5)");
    Contracts.CheckUserArg(args.RecencyGain >= 0, nameof(args.RecencyGain), UserErrorNonNegative);
    Contracts.CheckUserArg(args.AveragedTolerance >= 0, nameof(args.AveragedTolerance), UserErrorNonNegative);

    Args = args;
}