예제 #1
0
        /// <summary>
        /// Returns the learning rate for the current epoch.
        /// The base rate is first scaled linearly with the batch size, relative to a reference
        /// batch size of 128. It is then multiplied by the schedule multiplier returned by
        /// <c>GetLearningRateScheduleMultiplier</c> for <c>trainstate.CurrentEpoch</c>.
        /// </summary>
        /// <param name="trainstate">Current training state; <c>BatchSize</c> and <c>CurrentEpoch</c> are read.</param>
        /// <returns>The scaled and scheduled learning rate.</returns>
        internal override float GetLearningRate(DnnTrainState trainstate)
        {
            // Linear scaling rule: BaseLearningRate is defined for a batch of 128 samples.
            float scaledBaseRate = BaseLearningRate * trainstate.BatchSize / 128;

            return scaledBaseRate * GetLearningRateScheduleMultiplier(trainstate.CurrentEpoch);
        }
예제 #2
0
        /// <summary>
        /// Computes an exponentially decayed learning rate:
        /// <c>LearningRate * DecayRate ^ (GlobalStep / DecaySteps)</c>.
        /// When <c>Staircase</c> is set, the exponent is floored, so the rate drops in discrete steps.
        /// </summary>
        /// <param name="trainstate">
        /// Current training state; <c>BatchSize</c>, <c>BatchesPerEpoch</c>, <c>CurrentEpoch</c>
        /// and <c>CurrentBatchIndex</c> are read.
        /// </param>
        /// <returns>The decayed learning rate for the current step.</returns>
        /// <remarks>
        /// Side effect: stores the computed decay period and global step in the
        /// <c>DecaySteps</c> and <c>GlobalStep</c> members.
        /// </remarks>
        internal override float GetLearningRate(DnnTrainState trainstate)
        {
            // Widen to long so BatchSize * BatchesPerEpoch cannot overflow int.
            long numSamplesPerEpoch = (long)trainstate.BatchSize * trainstate.BatchesPerEpoch;

            // Number of optimizer steps per decay period. Clamp it to at least 1: if
            // NumEpochsPerDecay * BatchesPerEpoch truncates to 0, the division below would
            // give NaN at step 0 and an infinite exponent afterwards, i.e. a NaN or
            // degenerate learning rate.
            DecaySteps = Math.Max(1, (int)(numSamplesPerEpoch * NumEpochsPerDecay / trainstate.BatchSize));
            GlobalStep = (trainstate.CurrentEpoch) * (trainstate.BatchesPerEpoch) + trainstate.CurrentBatchIndex;

            float decayPower = (float)GlobalStep / DecaySteps;

            // Staircase mode keeps the rate constant within each decay period.
            decayPower = Staircase ? (float)Math.Floor(decayPower) : decayPower;

            float decayedLearningRate = LearningRate * (float)Math.Pow(DecayRate, decayPower);
            return decayedLearningRate;
        }
예제 #3
0
        /// <summary>
        /// Computes a polynomially decayed learning rate.
        /// The rate moves from <c>LearningRate</c> toward <c>EndLearningRate</c> over a decay period of
        /// <c>decaySteps</c> optimizer steps, using
        /// <c>(LearningRate - EndLearningRate) * (1 - step / period) ^ Power + EndLearningRate</c>.
        /// Without <c>Cycle</c>, the step is capped at the decay period.
        /// With <c>Cycle</c> set and the global step past one period, the period is extended to the
        /// smallest multiple of <c>decaySteps</c> that is at or above the global step.
        /// </summary>
        /// <param name="trainstate">
        /// Current training state; <c>BatchSize</c>, <c>BatchesPerEpoch</c>, <c>CurrentEpoch</c>
        /// and <c>CurrentBatchIndex</c> are read.
        /// </param>
        /// <returns>The decayed learning rate for the current step.</returns>
        internal override float GetLearningRate(DnnTrainState trainstate)
        {
            // Widen to long so BatchSize * BatchesPerEpoch cannot overflow int.
            long numSamplesPerEpoch = (long)trainstate.BatchSize * trainstate.BatchesPerEpoch;

            // Steps per decay period, clamped to at least 1. If it truncated to 0, both
            // branches below would divide by zero and return NaN.
            int decaySteps = Math.Max(1, (int)(numSamplesPerEpoch * NumEpochsPerDecay / trainstate.BatchSize));
            int globalStep = (trainstate.CurrentEpoch) * (trainstate.BatchesPerEpoch) + trainstate.CurrentBatchIndex;

            float decayedLearningRate;

            if (Cycle && globalStep > decaySteps)
            {
                // Cycling: stretch the period to the next multiple of decaySteps so the
                // schedule continues instead of holding at EndLearningRate.
                float calculatedStep = (float)decaySteps * (float)Math.Ceiling((double)globalStep / (double)decaySteps);
                decayedLearningRate = (LearningRate - EndLearningRate) * ((float)Math.Pow((1 - (float)globalStep / calculatedStep), Power)) + EndLearningRate;
            }
            else
            {
                // Not cycling (or still inside the first period): cap the step at decaySteps.
                float calculatedStep = Math.Min(globalStep, decaySteps);
                decayedLearningRate = (LearningRate - EndLearningRate) * ((float)Math.Pow((1 - calculatedStep / (float)decaySteps), Power)) + EndLearningRate;
            }
            return decayedLearningRate;
        }
예제 #4
0
 /// <summary>
 /// Implemented by each learning-rate schedule to compute the learning rate for the
 /// current point in training.
 /// </summary>
 /// <param name="options">
 /// Current training state. NOTE(review): the visible overrides read BatchSize,
 /// BatchesPerEpoch, CurrentEpoch and CurrentBatchIndex from it.
 /// </param>
 /// <returns>The learning rate to use for the current step.</returns>
 internal abstract float GetLearningRate(DnnTrainState options);