예제 #1
0
        protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
        {
            base.InitCore(ch, numFeatures, predictor);

            // Lazy updates and recency gain are mutually exclusive settings.
            bool usesRecencyGain = Args.RecencyGainMulti || Args.RecencyGain != 0;
            Contracts.Check(!(Args.DoLazyUpdates && usesRecencyGain),
                            "Cannot have both recency gain and lazy updates.");

            // Do the other initializations by setting the setters as if user had set them.
            if (!Args.Averaged)
            {
                // Keeping the weights dense is definitely advantageous when they
                // are not being accumulated into another vector on each update.
                VBufferUtils.Densify(ref Weights);
            }
            else
            {
                // Averaged mode: when a tolerance is in effect, densify first,
                // then seed the running total from the current weights.
                if (Args.AveragedTolerance > 0)
                {
                    VBufferUtils.Densify(ref Weights);
                }
                Weights.CopyTo(ref TotalWeights);
            }
            Gain = 1;
        }
        // Runs a full training pass: validates the data, initializes trainer
        // state (optionally warm-started from an existing predictor), trains,
        // and verifies the resulting weights are finite.
        private void TrainEx(RoleMappedData data, LinearPredictor predictor)
        {
            Contracts.AssertValue(data, nameof(data));
            Contracts.AssertValueOrNull(predictor);

            data.CheckFeatureFloatVector(out int numFeatures);
            CheckLabel(data);

            using (var ch = Host.Start("Training"))
            {
                InitCore(ch, numFeatures, predictor);
                // InitCore is responsible for setting the feature-count field.
                Contracts.Assert(NumFeatures > 0);

                TrainCore(ch, data);

                if (NumBad > 0)
                {
                    ch.Warning(
                        "Skipped {0} instances with missing features during training (over {1} iterations; {2} inst/iter)",
                        NumBad, Args.NumIterations, NumBad / Args.NumIterations);
                }

                Contracts.Assert(WeightsScale == 1);
                // Guard against divergence: every weight and the bias must be finite.
                Float largest = Math.Max(VectorUtils.MaxNorm(ref Weights), Math.Abs(Bias));
                Contracts.Check(FloatUtils.IsFinite(largest),
                                "The weights/bias contain invalid values (NaN or Infinite). Potential causes: high learning rates, no normalization, high initial weights, etc.");

                ch.Done();
            }
        }
        /// <summary>
        /// Initializes trainer state for <paramref name="numFeatures"/> features.
        /// Weights come from, in priority order: an existing
        /// <paramref name="predictor"/> (warm start), an explicit
        /// comma-separated <c>InitialWeights</c> string, random values within
        /// <c>InitWtsDiameter</c>, or zeros (dense for small feature counts,
        /// sparse otherwise).
        /// </summary>
        protected virtual void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
        {
            Contracts.Check(numFeatures > 0, "Can't train with zero features!");
            Contracts.Check(NumFeatures == 0, "Can't re-use trainer!");
            Contracts.Assert(Iteration == 0);
            Contracts.Assert(Bias == 0);

            ch.Trace("{0} Initializing {1} on {2} features", DateTime.UtcNow, Name, numFeatures);
            NumFeatures = numFeatures;

            if (predictor != null)
            {
                // Warm-start: copy weights and bias from the existing model.
                predictor.GetFeatureWeights(ref Weights);
                VBufferUtils.Densify(ref Weights);
                Bias = predictor.Bias;
            }
            else if (!string.IsNullOrWhiteSpace(Args.InitialWeights))
            {
                // Explicit initialization: NumFeatures weight values followed
                // by the intercept, comma-separated.
                ch.Info("Initializing weights and bias to " + Args.InitialWeights);
                string[] parts = Args.InitialWeights.Split(',');
                if (parts.Length != NumFeatures + 1)
                {
                    throw Contracts.Except(
                              "Could not initialize weights from 'initialWeights': expecting {0} values to initialize {1} weights and the intercept",
                              NumFeatures + 1, NumFeatures);
                }

                Weights = VBufferUtils.CreateDense<Float>(NumFeatures);
                for (int iv = 0; iv < NumFeatures; iv++)
                    Weights.Values[iv] = Float.Parse(parts[iv], CultureInfo.InvariantCulture);
                Bias = Float.Parse(parts[NumFeatures], CultureInfo.InvariantCulture);
            }
            else if (Args.InitWtsDiameter > 0)
            {
                // Random initialization: each value uniform in [-d/2, d/2).
                Weights = VBufferUtils.CreateDense<Float>(NumFeatures);
                for (int iv = 0; iv < NumFeatures; iv++)
                    Weights.Values[iv] = Args.InitWtsDiameter * (Host.Rand.NextSingle() - (Float)0.5);
                Bias = Args.InitWtsDiameter * (Host.Rand.NextSingle() - (Float)0.5);
            }
            else
            {
                // We want a dense vector, to prevent memory creation during
                // training, unless we have a lot of features.
                // REVIEW: make a setting
                Weights = NumFeatures <= 1000
                    ? VBufferUtils.CreateDense<Float>(NumFeatures)
                    : VBufferUtils.CreateEmpty<Float>(NumFeatures);
            }
            WeightsScale = 1;
        }
        /// <summary>
        /// Trains on <paramref name="data"/> starting from
        /// <paramref name="predictor"/>, which must be a linear predictor or a
        /// calibrated predictor wrapping one.
        /// </summary>
        public void Train(RoleMappedData data, IPredictor predictor)
        {
            Host.CheckValue(data, nameof(data));
            Host.CheckValue(predictor, nameof(predictor));

            // Unwrap a calibrated predictor when present; otherwise the
            // predictor itself must be linear.
            var pred = (predictor as CalibratedPredictorBase)?.SubPredictor as LinearPredictor
                ?? predictor as LinearPredictor;
            Host.CheckParam(pred != null, nameof(predictor), "Not a linear predictor.");
            TrainEx(data, pred);
        }
예제 #5
0
        protected override void InitCore(IChannel ch, int numFeatures, LinearPredictor predictor)
        {
            base.InitCore(ch, numFeatures, predictor);

            // The no-bias option overrides whatever bias the base initialized.
            if (Args.NoBias)
                Bias = 0;

            // Without a warm-start predictor, keep the weight vector dense.
            if (predictor == null)
                VBufferUtils.Densify(ref Weights);

            // Scratch buffer for accumulating per-round weight updates.
            _weightsUpdate = VBufferUtils.CreateEmpty<Float>(numFeatures);
        }
 /// <summary>
 /// Creates the collection over the given predictor's weights.
 /// NOTE(review): presumably this exposes <paramref name="pred"/>'s weight
 /// vector as a collection — confirm against the rest of the class, which is
 /// not visible here. The predictor is captured by reference, not copied.
 /// </summary>
 /// <param name="pred">The linear predictor to wrap; must not be null.</param>
 public WeightsCollection(LinearPredictor pred)
 {
     Contracts.AssertValue(pred);
     _pred = pred;
 }
 protected abstract TModel TrainCore(IChannel ch, RoleMappedData data, LinearPredictor predictor, int weightSetCount);