예제 #1
0
        /// <summary>
        /// Returns the learning rate scaled linearly with batch size (linear scale
        /// rule, relative to a reference batch size of 128) and adjusted by the
        /// epoch-based schedule multiplier.
        /// </summary>
        /// <param name="trainstate">Current training state (batch size and epoch are read).</param>
        /// <returns>The learning rate for the current epoch.</returns>
        internal override float GetLearningRate(TrainState trainstate)
        {
            // Linear scale rule: scale the base rate by batchSize / 128.
            float scaledBaseRate = BaseLearningRate * trainstate.BatchSize / 128;

            // Apply the decay schedule for the current epoch.
            return scaledBaseRate * GetLearningRateScheduleMultiplier(trainstate.CurrentEpoch);
        }
예제 #2
0
        /// <summary>
        /// Computes the exponentially decayed learning rate for the current training
        /// step. Also updates <see cref="DecaySteps"/> and <see cref="GlobalStep"/>
        /// as a side effect of the computation.
        /// </summary>
        /// <param name="trainstate">Current training state (batch size, batches per epoch, epoch and batch index are read).</param>
        /// <returns>The decayed learning rate: LearningRate * DecayRate^(GlobalStep / DecaySteps).</returns>
        internal override float GetLearningRate(TrainState trainstate)
        {
            int samplesPerEpoch = trainstate.BatchSize * trainstate.BatchesPerEpoch;

            // Steps between decay applications, and the absolute step index so far.
            DecaySteps = (int)(samplesPerEpoch * NumEpochsPerDecay / trainstate.BatchSize);
            GlobalStep = (trainstate.CurrentEpoch * trainstate.BatchesPerEpoch) + trainstate.CurrentBatchIndex;

            // Fraction of decay intervals completed; floored when Staircase is set
            // so the rate drops in discrete jumps instead of continuously.
            float exponent = (float)GlobalStep / DecaySteps;
            if (Staircase)
            {
                exponent = (float)Math.Floor(exponent);
            }

            return LearningRate * (float)Math.Pow(DecayRate, exponent);
        }
예제 #3
0
 /// <summary>
 /// Computes the learning rate for the given training state. Concrete schedules
 /// override this (e.g. the linear-scale-rule and exponential-decay variants in
 /// this file).
 /// </summary>
 /// <param name="options">The current training state used to derive the rate.</param>
 /// <returns>The learning rate to apply at this point in training.</returns>
 internal abstract float GetLearningRate(TrainState options);