/// <summary>Snapshots everything a single optimization run needs: the dataset, the differentiable function, a deep clone of the initial weights, fresh optimizer state, and the convergence settings.</summary>
public TrainingWorker(AbstractBatchOptimizer _enclosing, T[] dataset, AbstractDifferentiableFunction<T> fn, ConcatVector initialWeights, double l2regularization, double convergenceDerivativeNorm, bool quiet)
{
    this._enclosing = _enclosing;
    this.optimizationState = this._enclosing.GetFreshOptimizationState(initialWeights);
    this.weights = initialWeights.DeepClone();
    this.dataset = dataset;
    this.fn = fn;
    this.l2regularization = l2regularization;
    this.convergenceDerivativeNorm = convergenceDerivativeNorm;
    this.quiet = quiet;
}
/// <summary>This is the hook that batch optimizer subclasses override to implement their update rule.</summary>
/// <param name="weights">the current weights (update these in place)</param>
/// <param name="gradient">the gradient at these weights</param>
/// <param name="logLikelihood">the log likelihood at these weights</param>
/// <param name="state">any saved state the optimizer wants to keep and pass around during each optimization run</param>
/// <param name="quiet">whether or not to dump output about progress to the console</param>
/// <returns>whether or not we've converged</returns>
public abstract bool UpdateWeights(ConcatVector weights, ConcatVector gradient, double logLikelihood, AbstractBatchOptimizer.OptimizationState state, bool quiet);
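// Illustrative sketch (an assumption for illustration, not part of this library): a minimal subclass
// could satisfy this hook with plain gradient ascent at a fixed, hypothetical step size. The class
// name and StepSize constant below are invented; a complete subclass would also provide the
// GetFreshOptimizationState() that the TrainingWorker constructor above relies on.
public class FixedStepGradientAscentOptimizer : AbstractBatchOptimizer
{
    private const double StepSize = 0.01;  // hypothetical learning rate, chosen only for illustration

    public override bool UpdateWeights(ConcatVector weights, ConcatVector gradient, double logLikelihood, AbstractBatchOptimizer.OptimizationState state, bool quiet)
    {
        // Move the weights in place along the gradient of the log likelihood
        weights.AddVectorInPlace(gradient, StepSize);
        // Declare convergence once the squared gradient norm is negligible
        return gradient.DotProduct(gradient) < 1.0e-10;
    }
}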
// this magic number was arrived at with relation to the CoNLL benchmark, and tinkering
internal const double alpha = 0.1;

public override bool UpdateWeights(ConcatVector weights, ConcatVector gradient, double logLikelihood, AbstractBatchOptimizer.OptimizationState optimizationState, bool quiet)
{
    BacktrackingAdaGradOptimizer.AdaGradOptimizationState s = (BacktrackingAdaGradOptimizer.AdaGradOptimizationState)optimizationState;
    double logLikelihoodChange = logLikelihood - s.lastLogLikelihood;
    if (logLikelihoodChange == 0)
    {
        if (!quiet)
        {
            log.Info("\tlogLikelihood improvement = 0: quitting");
        }
        return true;
    }
    else
    {
        // Check if we should backtrack
        if (logLikelihoodChange < 0)
        {
            // If we should, move the weights back by half, and cut the lastDerivative by half
            s.lastDerivative.MapInPlace((d) => d / 2);
            weights.AddVectorInPlace(s.lastDerivative, -1.0);
            if (!quiet)
            {
                log.Info("\tBACKTRACK...");
            }
            // if the lastDerivative norm falls below a threshold, it means we've converged
            if (s.lastDerivative.DotProduct(s.lastDerivative) < 1.0e-10)
            {
                if (!quiet)
                {
                    log.Info("\tBacktracking derivative norm " + s.lastDerivative.DotProduct(s.lastDerivative) + " < 1.0e-10: quitting");
                }
                return true;
            }
        }
        else
        {
            // Apply AdaGrad: accumulate the squared gradient, then rescale each gradient component
            // by alpha / (1e-9 + sqrt(accumulated squared gradient)) before stepping
            ConcatVector squared = gradient.DeepClone();
            squared.MapInPlace((d) => d * d);
            s.adagradAccumulator.AddVectorInPlace(squared, 1.0);
            ConcatVector sqrt = s.adagradAccumulator.DeepClone();
            sqrt.MapInPlace((d) => d == 0 ? alpha : alpha / (1.0e-9 + System.Math.Sqrt(d)));
            gradient.ElementwiseProductInPlace(sqrt);
            weights.AddVectorInPlace(gradient, 1.0);
            // Setup for backtracking, in case necessary
            s.lastDerivative = gradient;
            s.lastLogLikelihood = logLikelihood;
            if (!quiet)
            {
                log.Info("\tLL: " + logLikelihood);
            }
        }
    }
    return false;
}
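// Worked form of the step above (a sketch, assuming the element-wise maps as written, with G_i the
// running sum of squared gradients and g_i the current gradient entry for coordinate i):
//
//     G_i  <-  G_i + g_i^2
//     w_i  <-  w_i + alpha * g_i / (1.0e-9 + sqrt(G_i))
//
// Each coordinate therefore gets its own effective step size, which shrinks as that coordinate's
// gradients accumulate. When the log likelihood drops instead of improving, the backtracking branch
// repeatedly undoes half of the previous step until the likelihood recovers or the remaining step
// becomes negligibly small, at which point the optimizer reports convergence.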