示例#1
0
        /*
         * @Theory
         * public void testOptimizeLogLikelihoodWithConstraints(AbstractBatchOptimizer optimizer,
         * @ForAll(sampleSize = 5) @From(LogLikelihoodFunctionTest.GraphicalModelDatasetGenerator.class) GraphicalModel[] dataset,
         * @ForAll(sampleSize = 2) @From(LogLikelihoodFunctionTest.WeightsGenerator.class) ConcatVector initialWeights,
         * @ForAll(sampleSize = 2) @InRange(minDouble = 0.0, maxDouble = 5.0) double l2regularization) throws Exception {
         * Random r = new Random(42);
         *
         * int constraintComponent = r.nextInt(initialWeights.getNumberOfComponents());
         * double constraintValue = r.nextDouble();
         *
         * if (r.nextBoolean()) {
         * optimizer.addSparseConstraint(constraintComponent, 0, constraintValue);
         * } else {
         * optimizer.addDenseConstraint(constraintComponent, new double[]{constraintValue});
         * }
         *
         * // Put in some constraints
         *
         * AbstractDifferentiableFunction<GraphicalModel> ll = new LogLikelihoodDifferentiableFunction();
         * ConcatVector finalWeights = optimizer.optimize(dataset, ll, initialWeights, l2regularization, 1.0e-9, false);
         * System.err.println("Finished optimizing");
         *
         * assertEquals(constraintValue, finalWeights.getValueAt(constraintComponent, 0), 1.0e-9);
         *
         * double logLikelihood = getValueSum(dataset, finalWeights, ll, l2regularization);
         *
         * // Check in a whole bunch of random directions really nearby that there is no nearby point with a higher log
         * // likelihood
         * for (int i = 0; i < 1000; i++) {
         * int size = finalWeights.getNumberOfComponents();
         * ConcatVector randomDirection = new ConcatVector(size);
         * for (int j = 0; j < size; j++) {
         * if (j == constraintComponent) continue;
         * double[] dense = new double[finalWeights.isComponentSparse(j) ? finalWeights.getSparseIndex(j) + 1 : finalWeights.getDenseComponent(j).length];
         * for (int k = 0; k < dense.length; k++) {
         * dense[k] = (r.nextDouble() - 0.5) * 1.0e-3;
         * }
         * randomDirection.setDenseComponent(j, dense);
         * }
         *
         * ConcatVector randomPerturbation = finalWeights.deepClone();
         * randomPerturbation.addVectorInPlace(randomDirection, 1.0);
         *
         * double randomPerturbedLogLikelihood = getValueSum(dataset, randomPerturbation, ll, l2regularization);
         *
         * // Check that we're within a very small margin of error (around 3 decimal places) of the randomly
         * // discovered value
         *
         * if (logLikelihood < randomPerturbedLogLikelihood - (1.0e-3 * Math.max(1.0, Math.abs(logLikelihood)))) {
         * System.err.println("Thought optimal point was: " + logLikelihood);
         * System.err.println("Discovered better point: " + randomPerturbedLogLikelihood);
         * }
         *
         * assertTrue(logLikelihood >= randomPerturbedLogLikelihood - (1.0e-3 * Math.max(1.0, Math.abs(logLikelihood))));
         * }
         * }
         */
        /// <summary>
        /// Computes the regularized objective value of <paramref name="weights"/> on
        /// <paramref name="dataset"/>: the average of <c>fn.GetSummaryForInstance</c> over all
        /// instances, minus <c>weights.DotProduct(weights) * l2regularization</c>.
        /// </summary>
        /// <param name="dataset">Instances to evaluate.</param>
        /// <param name="weights">Weights passed to every per-instance evaluation.</param>
        /// <param name="fn">Function whose per-instance summary values are accumulated.</param>
        /// <param name="l2regularization">Multiplier applied to the weights' self dot product.</param>
        /// <returns>The averaged summary value minus the regularization term.</returns>
        private double GetValueSum <T>(T[] dataset, ConcatVector weights, AbstractDifferentiableFunction <T> fn, double l2regularization)
        {
            double total = 0.0;

            for (int i = 0; i < dataset.Length; i++)
            {
                // A new ConcatVector(0) is supplied for each instance and is not read again here;
                // only the returned value contributes to the sum.
                ConcatVector scratch = new ConcatVector(0);
                total += fn.GetSummaryForInstance(dataset[i], weights, scratch);
            }

            double average = total / dataset.Length;
            double penalty = weights.DotProduct(weights) * l2regularization;
            return(average - penalty);
        }
 /// <summary>
 /// Creates a training worker bound to the optimizer that owns it.
 /// </summary>
 /// <param name="_enclosing">Optimizer this worker belongs to; also supplies the fresh optimization state.</param>
 /// <param name="dataset">Training instances.</param>
 /// <param name="fn">Function being optimized.</param>
 /// <param name="initialWeights">Starting weights; the worker keeps the result of <c>DeepClone()</c>, not this instance.</param>
 /// <param name="l2regularization">Regularization value stored for later use.</param>
 /// <param name="convergenceDerivativeNorm">Convergence threshold stored for later use.</param>
 /// <param name="quiet">Quiet flag stored for later use.</param>
 public TrainingWorker(AbstractBatchOptimizer _enclosing, T[] dataset, AbstractDifferentiableFunction <T> fn, ConcatVector initialWeights, double l2regularization, double convergenceDerivativeNorm, bool quiet)
 {
     // The owner has to be recorded first: the optimization state is obtained from it.
     this._enclosing = _enclosing;

     // State derived from the initial weights (state first, then the private copy).
     this.optimizationState = _enclosing.GetFreshOptimizationState(initialWeights);
     this.weights = initialWeights.DeepClone();

     // Inputs of the optimization problem.
     this.fn      = fn;
     this.dataset = dataset;

     // Run settings.
     this.quiet = quiet;
     this.convergenceDerivativeNorm = convergenceDerivativeNorm;
     this.l2regularization          = l2regularization;
 }
 /// <summary>
 /// Creates a gradient worker for one thread of a training run.
 /// </summary>
 /// <param name="mainWorker">Training worker this gradient worker reports to.</param>
 /// <param name="threadIdx">Index of this worker among the gradient threads.</param>
 /// <param name="numThreads">Total number of gradient threads.</param>
 /// <param name="queue">Instances assigned to this worker.</param>
 /// <param name="fn">Function being optimized.</param>
 /// <param name="weights">Weights to evaluate at; also the template for the local derivative vector.</param>
 public GradientWorker(AbstractBatchOptimizer.TrainingWorker <T> mainWorker, int threadIdx, int numThreads, IList <T> queue, AbstractDifferentiableFunction <T> fn, ConcatVector weights)
 {
     // What to evaluate, and on which instances.
     this.fn      = fn;
     this.queue   = queue;
     this.weights = weights;

     // This is to help the dynamic re-balancing of work queues
     this.mainWorker = mainWorker;
     this.numThreads = numThreads;
     this.threadIdx  = threadIdx;

     // Local derivative storage starts out as an empty clone of the weights.
     localDerivative = weights.NewEmptyClone();
 }
        /// <summary>
        /// Starts a <c>TrainingWorker</c> on a new thread and blocks until it reports completion
        /// or, in non-quiet mode, until input becomes available on stdin.
        /// </summary>
        /// <param name="dataset">Training instances handed to the worker.</param>
        /// <param name="fn">Function to optimize.</param>
        /// <param name="initialWeights">Starting weights handed to the worker.</param>
        /// <param name="l2regularization">Regularization value handed to the worker.</param>
        /// <param name="convergenceDerivativeNorm">Convergence threshold handed to the worker.</param>
        /// <param name="quiet">
        /// When false, logs progress notes and polls stdin so the user can stop training early.
        /// When true, waits on the worker's termination barrier instead.
        /// </param>
        /// <returns>The worker's <c>weights</c> at the moment this method stops waiting.</returns>
        public virtual ConcatVector Optimize <T>(T[] dataset, AbstractDifferentiableFunction <T> fn, ConcatVector initialWeights, double l2regularization, double convergenceDerivativeNorm, bool quiet)
        {
            if (!quiet)
            {
                log.Info("\n**************\nBeginning training\n");
            }
            else
            {
                log.Info("[Beginning quiet training]");
            }
            AbstractBatchOptimizer.TrainingWorker <T> mainWorker = new AbstractBatchOptimizer.TrainingWorker <T>(this, dataset, fn, initialWeights, l2regularization, convergenceDerivativeNorm, quiet);
            new Thread(mainWorker).Start();
            BufferedReader br = new BufferedReader(new InputStreamReader(Runtime.@in));

            if (!quiet)
            {
                log.Info("NOTE: you can press any key (and maybe ENTER afterwards to jog stdin) to terminate learning early.");
                log.Info("The convergence criteria are quite aggressive if left uninterrupted, and will run for a while");
                log.Info("if left to their own devices.\n");
                // NOTE(review): this loop polls isFinished and br.Ready() without pausing, so the
                // calling thread spins for the whole training run.
                while (true)
                {
                    if (mainWorker.isFinished)
                    {
                        log.Info("training completed without interruption");
                        return(mainWorker.weights);
                    }
                    try
                    {
                        if (br.Ready())
                        {
                            log.Info("received quit command: quitting");
                            log.Info("training completed by interruption");
                            // Flag the worker as finished and return right away.
                            // NOTE(review): the worker thread is not joined here, so it may still
                            // be running when the weights are returned - confirm this is intended.
                            mainWorker.isFinished = true;
                            return(mainWorker.weights);
                        }
                    }
                    catch (IOException e)
                    {
                        // A failed readiness check is reported and polling continues.
                        Sharpen.Runtime.PrintStackTrace(e);
                    }
                }
            }
            else
            {
                // Test the flag while holding the barrier's monitor. Checking it before taking the
                // lock leaves a window in which the worker can finish and signal the barrier
                // between the check and the wait; that wake-up would be lost and this thread would
                // block forever.
                // NOTE(review): assumes the worker sets isFinished before signalling the barrier -
                // confirm in TrainingWorker.
                lock (mainWorker.naturalTerminationBarrier)
                {
                    while (!mainWorker.isFinished)
                    {
                        try
                        {
                            Sharpen.Runtime.Wait(mainWorker.naturalTerminationBarrier);
                        }
                        catch (Exception e)
                        {
                            throw new RuntimeInterruptedException(e);
                        }
                    }
                }
                log.Info("[Quiet training complete]");
                return(mainWorker.weights);
            }
        }
 /// <summary>
 /// Convenience overload: optimizes with a <c>new ConcatVector(0)</c> as initial weights,
 /// an L2 regularization of 0.0, a convergence derivative norm of 1.0e-5, and quiet mode off.
 /// </summary>
 /// <param name="dataset">Training instances.</param>
 /// <param name="fn">Function to optimize.</param>
 /// <returns>The weights returned by the full <c>Optimize</c> overload.</returns>
 public virtual ConcatVector Optimize <T>(T[] dataset, AbstractDifferentiableFunction <T> fn)
 {
     const double defaultL2Regularization = 0.0;
     const double defaultConvergenceNorm  = 1.0e-5;
     const bool   runQuietly = false;

     ConcatVector startingWeights = new ConcatVector(0);

     return(Optimize(dataset, fn, startingWeights, defaultL2Regularization, defaultConvergenceNorm, runQuietly));
 }