/// <summary>
/// Returns the learning rate computed with the linear scale rule (base rate scaled by batch size / 128) and learning rate decay.
/// </summary>
internal override float GetLearningRate(TrainState trainState)
{
    // Linear scale rule: scale the base learning rate proportionally to the batch size.
    float initialLearningRate = BaseLearningRate * trainState.BatchSize / 128;

    // Apply the epoch-dependent decay multiplier.
    float learningRate = initialLearningRate * GetLearningRateScheduleMultiplier(trainState.CurrentEpoch);
    return learningRate;
}
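// A minimal standalone sketch (not part of the scheduler classes above) illustrating the
// linear scale rule. The class name and parameter values below are hypothetical and chosen
// only to make the scaling concrete.
internal static class LinearScaleRuleSketch
{
    // Scales a base learning rate proportionally to batchSize / 128, mirroring the
    // computation of initialLearningRate above.
    internal static float ScaleLearningRate(float baseLearningRate, int batchSize)
        => baseLearningRate * batchSize / 128f;

    internal static void Example()
    {
        // With a hypothetical base rate of 0.01 and a batch size of 256, the scaled initial
        // rate is 0.01 * 256 / 128 = 0.02, which the scheduler then multiplies by the
        // epoch-dependent schedule multiplier.
        System.Console.WriteLine(ScaleLearningRate(0.01f, 256));
    }
}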
/// <summary>
/// Computes the exponentially decayed learning rate.
/// </summary>
internal override float GetLearningRate(TrainState trainState)
{
    // Number of training steps after which the learning rate is decayed once.
    int numSamplesPerEpoch = trainState.BatchSize * trainState.BatchesPerEpoch;
    DecaySteps = (int)(numSamplesPerEpoch * NumEpochsPerDecay / trainState.BatchSize);

    // Global step index of the current batch across all epochs.
    GlobalStep = trainState.CurrentEpoch * trainState.BatchesPerEpoch + trainState.CurrentBatchIndex;

    // decayedLR = LearningRate * DecayRate^(GlobalStep / DecaySteps); with staircase decay,
    // the exponent is floored so the rate drops in discrete steps.
    float decayPower = (float)GlobalStep / DecaySteps;
    decayPower = Staircase ? (float)Math.Floor(decayPower) : decayPower;
    float decayedLearningRate = LearningRate * (float)Math.Pow(DecayRate, decayPower);
    return decayedLearningRate;
}
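// A minimal standalone sketch (hypothetical names, not part of the class above) restating the
// decay computed in GetLearningRate: decayedLR = learningRate * decayRate^(globalStep / decaySteps).
internal static class ExponentialDecaySketch
{
    internal static float Decay(float learningRate, float decayRate, long globalStep, int decaySteps, bool staircase)
    {
        float decayPower = (float)globalStep / decaySteps;
        if (staircase)
            decayPower = (float)System.Math.Floor(decayPower);
        return learningRate * (float)System.Math.Pow(decayRate, decayPower);
    }

    internal static void Example()
    {
        // Hypothetical values: LR = 0.1, decay rate 0.94, 1000 steps per decay period, global
        // step 2500 => 0.1 * 0.94^2.5 ≈ 0.086 (smooth decay); with staircase: true the
        // exponent is floored to 2, giving 0.1 * 0.94^2 ≈ 0.088.
        System.Console.WriteLine(Decay(0.1f, 0.94f, 2500, 1000, staircase: false));
    }
}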
/// <summary>
/// Computes the learning rate for the current training state.
/// </summary>
internal abstract float GetLearningRate(TrainState trainState);