/*
 * @Theory
 * public void testOptimizeLogLikelihoodWithConstraints(AbstractBatchOptimizer optimizer,
 *         @ForAll(sampleSize = 5) @From(LogLikelihoodFunctionTest.GraphicalModelDatasetGenerator.class) GraphicalModel[] dataset,
 *         @ForAll(sampleSize = 2) @From(LogLikelihoodFunctionTest.WeightsGenerator.class) ConcatVector initialWeights,
 *         @ForAll(sampleSize = 2) @InRange(minDouble = 0.0, maxDouble = 5.0) double l2regularization) throws Exception {
 *     Random r = new Random(42);
 *
 *     int constraintComponent = r.nextInt(initialWeights.getNumberOfComponents());
 *     double constraintValue = r.nextDouble();
 *
 *     if (r.nextBoolean()) {
 *         optimizer.addSparseConstraint(constraintComponent, 0, constraintValue);
 *     }
 *     else {
 *         optimizer.addDenseConstraint(constraintComponent, new double[]{constraintValue});
 *     }
 *
 *     // Put in some constraints
 *
 *     AbstractDifferentiableFunction<GraphicalModel> ll = new LogLikelihoodDifferentiableFunction();
 *     ConcatVector finalWeights = optimizer.optimize(dataset, ll, initialWeights, l2regularization, 1.0e-9, false);
 *     System.err.println("Finished optimizing");
 *
 *     assertEquals(constraintValue, finalWeights.getValueAt(constraintComponent, 0), 1.0e-9);
 *
 *     double logLikelihood = getValueSum(dataset, finalWeights, ll, l2regularization);
 *
 *     // Check in a whole bunch of random directions really nearby that there is no nearby point with a higher log
 *     // likelihood
 *     for (int i = 0; i < 1000; i++) {
 *         int size = finalWeights.getNumberOfComponents();
 *         ConcatVector randomDirection = new ConcatVector(size);
 *         for (int j = 0; j < size; j++) {
 *             if (j == constraintComponent) continue;
 *             double[] dense = new double[finalWeights.isComponentSparse(j) ? finalWeights.getSparseIndex(j) + 1 : finalWeights.getDenseComponent(j).length];
 *             for (int k = 0; k < dense.length; k++) {
 *                 dense[k] = (r.nextDouble() - 0.5) * 1.0e-3;
 *             }
 *             randomDirection.setDenseComponent(j, dense);
 *         }
 *
 *         ConcatVector randomPerturbation = finalWeights.deepClone();
 *         randomPerturbation.addVectorInPlace(randomDirection, 1.0);
 *
 *         double randomPerturbedLogLikelihood = getValueSum(dataset, randomPerturbation, ll, l2regularization);
 *
 *         // Check that we're within a very small margin of error (around 3 decimal places) of the randomly
 *         // discovered value
 *
 *         if (logLikelihood < randomPerturbedLogLikelihood - (1.0e-3 * Math.max(1.0, Math.abs(logLikelihood)))) {
 *             System.err.println("Thought optimal point was: " + logLikelihood);
 *             System.err.println("Discovered better point: " + randomPerturbedLogLikelihood);
 *         }
 *
 *         assertTrue(logLikelihood >= randomPerturbedLogLikelihood - (1.0e-3 * Math.max(1.0, Math.abs(logLikelihood))));
 *     }
 * }
 */
private double GetValueSum<T>(T[] dataset, ConcatVector weights, AbstractDifferentiableFunction<T> fn, double l2regularization)
{
    double value = 0.0;
    foreach (T t in dataset)
    {
        value += fn.GetSummaryForInstance(t, weights, new ConcatVector(0));
    }
    return (value / dataset.Length) - (weights.DotProduct(weights) * l2regularization);
}
public TrainingWorker(AbstractBatchOptimizer _enclosing, T[] dataset, AbstractDifferentiableFunction<T> fn, ConcatVector initialWeights, double l2regularization, double convergenceDerivativeNorm, bool quiet)
{
    this._enclosing = _enclosing;
    this.optimizationState = this._enclosing.GetFreshOptimizationState(initialWeights);
    this.weights = initialWeights.DeepClone();
    this.dataset = dataset;
    this.fn = fn;
    this.l2regularization = l2regularization;
    this.convergenceDerivativeNorm = convergenceDerivativeNorm;
    this.quiet = quiet;
}
public GradientWorker(AbstractBatchOptimizer.TrainingWorker<T> mainWorker, int threadIdx, int numThreads, IList<T> queue, AbstractDifferentiableFunction<T> fn, ConcatVector weights)
{
    // This is to help the dynamic re-balancing of work queues
    this.mainWorker = mainWorker;
    this.threadIdx = threadIdx;
    this.numThreads = numThreads;
    this.queue = queue;
    this.fn = fn;
    this.weights = weights;
    localDerivative = weights.NewEmptyClone();
}
public virtual ConcatVector Optimize<T>(T[] dataset, AbstractDifferentiableFunction<T> fn, ConcatVector initialWeights, double l2regularization, double convergenceDerivativeNorm, bool quiet)
{
    if (!quiet)
    {
        log.Info("\n**************\nBeginning training\n");
    }
    else
    {
        log.Info("[Beginning quiet training]");
    }
    // Run the actual optimization on a background thread
    AbstractBatchOptimizer.TrainingWorker<T> mainWorker = new AbstractBatchOptimizer.TrainingWorker<T>(this, dataset, fn, initialWeights, l2regularization, convergenceDerivativeNorm, quiet);
    new Thread(mainWorker).Start();
    BufferedReader br = new BufferedReader(new InputStreamReader(Runtime.@in));
    if (!quiet)
    {
        log.Info("NOTE: you can press any key (and maybe ENTER afterwards to jog stdin) to terminate learning early.");
        log.Info("The convergence criteria are quite aggressive if left uninterrupted, and will run for a while");
        log.Info("if left to their own devices.\n");
        // Interactive mode: poll stdin so the user can cut training short, returning as soon as
        // either the worker converges or input arrives
        while (true)
        {
            if (mainWorker.isFinished)
            {
                log.Info("training completed without interruption");
                return mainWorker.weights;
            }
            try
            {
                if (br.Ready())
                {
                    log.Info("received quit command: quitting");
                    log.Info("training completed by interruption");
                    mainWorker.isFinished = true;
                    return mainWorker.weights;
                }
            }
            catch (IOException e)
            {
                Sharpen.Runtime.PrintStackTrace(e);
            }
        }
    }
    else
    {
        // Quiet mode: block on the termination barrier until the worker signals that it has converged
        while (!mainWorker.isFinished)
        {
            lock (mainWorker.naturalTerminationBarrier)
            {
                try
                {
                    Sharpen.Runtime.Wait(mainWorker.naturalTerminationBarrier);
                }
                catch (Exception e)
                {
                    throw new RuntimeInterruptedException(e);
                }
            }
        }
        log.Info("[Quiet training complete]");
        return mainWorker.weights;
    }
}
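/* Hypothetical usage sketch for the full overload above, shown in quiet mode so no stdin polling
 * happens. The names (optimizer, dataset, initialWeights) mirror the commented-out test earlier in
 * this file and are assumed to exist; the regularization strength and convergence threshold are
 * illustrative values, not prescribed ones.
 *
 * AbstractDifferentiableFunction<GraphicalModel> ll = new LogLikelihoodDifferentiableFunction();
 * ConcatVector finalWeights = optimizer.Optimize(dataset, ll, initialWeights, 0.1, 1.0e-9, true);
 */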
public virtual ConcatVector Optimize<T>(T[] dataset, AbstractDifferentiableFunction<T> fn)
{
    return Optimize(dataset, fn, new ConcatVector(0), 0.0, 1.0e-5, false);
}
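/* Hypothetical usage sketch for the convenience overload above, which starts from an empty
 * ConcatVector, applies no L2 regularization, and uses a 1.0e-5 convergence threshold.
 * BacktrackingAdaGradOptimizer is assumed to be a concrete subclass of AbstractBatchOptimizer;
 * GraphicalModel and LogLikelihoodDifferentiableFunction are the types referenced in the
 * commented-out test above, and "dataset" stands in for an array of training instances.
 *
 * AbstractBatchOptimizer optimizer = new BacktrackingAdaGradOptimizer();
 * AbstractDifferentiableFunction<GraphicalModel> ll = new LogLikelihoodDifferentiableFunction();
 * ConcatVector learnedWeights = optimizer.Optimize(dataset, ll);
 */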