Exemplo n.º 1
0
 /// <summary>
 /// Takes one gradient step in which every coordinate is scaled by <c>1 / diag[i]</c>, then
 /// records the step/gradient-change pair (s, y) and passes it to <c>UpdateDiag</c>.
 /// </summary>
 /// <remarks>
 /// Writes the new point into <c>newX</c>; <c>x</c> itself is not modified here. When the history
 /// already holds <c>pairMem</c> pairs, the oldest pair is removed before the new one is appended.
 /// </remarks>
 /// <param name="dfunction">
 /// Function whose derivative is re-evaluated at <c>newX</c>. Its <c>recalculatePrevBatch</c> flag is
 /// set to <c>true</c> before that evaluation.
 /// </param>
 protected internal override void TakeStep(AbstractStochasticCachingDiffFunction dfunction)
 {
     // Neither k nor numBatches changes inside the loop, so the scheduled gain is computed once.
     double scheduledGain = fixedGain * GainSchedule(k, 5 * numBatches);
     for (int i = 0; i < x.Length; i++)
     {
         double thisGain = scheduledGain / (diag[i]);
         newX[i] = x[i] - thisGain * grad[i];
     }
     //Get a new pair...
     Say(" A ");
     double[] y;
     if (sList.Count == pairMem)
     {
         // History is full: drop the oldest pair. Only the y buffer is reused, because s is
         // replaced below by the freshly allocated result of PairwiseSubtract.
         sList.Remove(0);
         y = yList.Remove(0);
     }
     else
     {
         y = new double[x.Length];
     }
     double[] s = ArrayMath.PairwiseSubtract(newX, x);
     dfunction.recalculatePrevBatch = true;
     System.Array.Copy(dfunction.DerivativeAt(newX, bSize), 0, y, 0, grad.Length);
     // y = y - newGrad
     ArrayMath.PairwiseSubtractInPlace(y, newGrad);
     sList.Add(s);
     yList.Add(y);
     UpdateDiag(diag, s, y);
 }
        /// <summary>
        /// Computes a search direction from <c>newGrad</c>, steps along it into <c>newX</c>, and appends
        /// the resulting <c>s</c>, <c>y</c> and <c>ro</c> values to the history lists.
        /// </summary>
        /// <remarks>
        /// When the history already holds <c>M</c> entries, the oldest <c>s</c>/<c>y</c> buffers are
        /// removed from the lists and reused for the new pair.
        /// </remarks>
        /// <param name="dfunction">
        /// Function whose derivative is re-evaluated at <c>newX</c>. Its <c>recalculatePrevBatch</c> flag is
        /// set to <c>true</c> before that evaluation.
        /// </param>
        protected internal override void TakeStep(AbstractStochasticCachingDiffFunction dfunction)
        {
            try
            {
                ComputeDir(dir, newGrad);
            }
            catch (SQNMinimizer.SurpriseConvergence)
            {
                // The step below still runs, using whatever dir holds after ClearStuff().
                ClearStuff();
            }
            double thisGain = gain * GainSchedule(k, 5 * numBatches);

            for (int i = 0; i < x.Length; i++)
            {
                newX[i] = x[i] + thisGain * dir[i];
            }
            //Get a new pair...
            Say(" A ");
            if (sList.Count == M)
            {
                // History is full: recycle the buffers of the oldest pair.
                s = sList.Remove(0);
                y = yList.Remove(0);
            }
            else
            {
                s = new double[x.Length];
                y = new double[x.Length];
            }
            dfunction.recalculatePrevBatch = true;
            System.Array.Copy(dfunction.DerivativeAt(newX, bSize), 0, y, 0, grad.Length);
            // compute s_k = newX - x, y_k = (gradient at newX) - newGrad + lambda * s_k, and ro = 1 / (s_k . y_k)
            ro = 0;
            for (int i = 0; i < x.Length; i++)
            {
                s[i] = newX[i] - x[i];
                y[i] = y[i] - newGrad[i] + lambda * s[i];
                ro  += s[i] * y[i];
            }
            // NOTE(review): there is no guard for s . y == 0, in which case ro becomes infinite (or NaN).
            ro = 1.0 / ro;
            sList.Add(s);
            yList.Add(y);
            roList.Add(ro);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Tests that the stochastic values and gradients, summed over all batches, equal the full
        /// value and the full gradient at <paramref name="x"/>.
        /// </summary>
        /// <remarks>
        /// The test forces ordered sampling on the function while it runs and restores the previous
        /// sampling method and calculation method before returning, even if an evaluation throws.
        /// If the objective function itself randomizes its inputs, this test will likely fail.
        /// </remarks>
        /// <param name="x">is the point to evaluate the function at</param>
        /// <param name="functionTolerance">
        /// is the tolerance placed on the infinity norm of the gradient difference and on the absolute
        /// difference of the values
        /// </param>
        /// <returns>
        /// <c>true</c> only when both the gradient comparison and the value comparison are within
        /// tolerance; <c>false</c> otherwise.
        /// </returns>
        public virtual bool TestSumOfBatches(double[] x, double functionTolerance)
        {
            log.Info("Making sure that the sum of stochastic gradients equals the full gradient");
            AbstractStochasticCachingDiffFunction.SamplingMethod tmpSampleMethod = thisFunc.sampleMethod;
            StochasticCalculateMethods tmpMethod = thisFunc.method;

            bool gradientMatches;
            bool valueMatches;

            try
            {
                //Make sure that our function is using ordered sampling.  Otherwise we have no guarantees.
                thisFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Ordered;
                if (thisFunc.method == StochasticCalculateMethods.NoneSpecified)
                {
                    log.Info("No calculate method has been specified");
                }
                approxValue = 0;
                approxGrad  = new double[x.Length];
                curGrad     = new double[x.Length];
                fullGrad    = new double[x.Length];

                //This loop runs through all the batches and sums the calculations to compare against the full gradient
                for (int i = 0; i < numBatches; i++)
                {
                    double percent = 100 * ((double)i) / (numBatches);
                    //  update the value
                    approxValue += thisFunc.ValueAt(x, v, testBatchSize);
                    //  update the gradient
                    thisFunc.returnPreviousValues = true;
                    System.Array.Copy(thisFunc.DerivativeAt(x, v, testBatchSize), 0, curGrad, 0, curGrad.Length);
                    //Update Approximate
                    approxGrad = ArrayMath.PairwiseAdd(approxGrad, curGrad);
                    double norm = ArrayMath.Norm(approxGrad);
                    System.Console.Error.Printf("%5.1f percent complete  %6.2f \n", percent, norm);
                }
                // Get the full gradient and value, these should equal the approximates
                log.Info("About to calculate the full derivative and value");
                System.Array.Copy(thisFunc.DerivativeAt(x), 0, fullGrad, 0, fullGrad.Length);
                thisFunc.returnPreviousValues = true;
                fullValue = thisFunc.ValueAt(x);

                // Gradient check: infinity norm of (full - summed batches) must be below the tolerance.
                diff = ArrayMath.PairwiseSubtract(fullGrad, approxGrad);
                gradientMatches = ArrayMath.Norm_inf(diff) < functionTolerance;
                Sayln(string.Empty);
                if (gradientMatches)
                {
                    Sayln("Success: sum of batch gradients equals full gradient");
                }
                else
                {
                    diffNorm = ArrayMath.Norm(diff);
                    Sayln("Failure: sum of batch gradients minus full gradient has norm " + diffNorm);
                }

                // Value check: tracked separately so that it cannot mask a failed gradient check.
                double valueDiff = System.Math.Abs(approxValue - fullValue);
                valueMatches = valueDiff < functionTolerance;
                Sayln(string.Empty);
                if (valueMatches)
                {
                    Sayln("Success: sum of batch values equals full value");
                }
                else
                {
                    Sayln("Failure: sum of batch values minus full value has norm " + valueDiff);
                }
            }
            finally
            {
                // Always hand the function back with the settings it had on entry.
                thisFunc.sampleMethod = tmpSampleMethod;
                thisFunc.method       = tmpMethod;
            }
            return gradientMatches && valueMatches;
        }
        /// <summary>
        /// Runs the stochastic minimization loop: for each iteration it evaluates a batch gradient at the
        /// current point, smooths it over the stored gradients, and calls <c>TakeStep</c> to produce the
        /// next point.
        /// </summary>
        /// <remarks>
        /// The loop stops after <c>maxIterations</c> iterations (see the parameter description), when the
        /// elapsed time reaches <c>maxTime</c>, or when a non-finite gradient or point is detected. In the
        /// last case every entry of the result is set to NaN.
        /// The working point <c>x</c> is the <paramref name="initial"/> array itself, so the caller's
        /// array is overwritten as the optimization proceeds.
        /// </remarks>
        /// <param name="function">
        /// The function to minimize; must be an <c>AbstractStochasticCachingDiffFunction</c>.
        /// </param>
        /// <param name="functionTolerance">Not read anywhere in this method.</param>
        /// <param name="initial">Starting point; used directly as the working array.</param>
        /// <param name="maxIterations">
        /// Upper limit that is combined with <c>numPasses</c>: the larger of the two is multiplied by the
        /// number of batches to obtain the iteration limit of the loop.
        /// </param>
        /// <returns>
        /// The final point: normally the <paramref name="initial"/> array, or <c>newX</c> when the loop
        /// was stopped by one of the in-loop stopping checks.
        /// </returns>
        /// <exception cref="NotSupportedException">
        /// Thrown when <paramref name="function"/> is not an <c>AbstractStochasticCachingDiffFunction</c>,
        /// or when neither <paramref name="maxIterations"/> nor <c>numPasses</c> is positive.
        /// </exception>
        public virtual double[] Minimize(Func function, double functionTolerance, double[] initial, int maxIterations)
        {
            // check for stochastic derivatives
            if (!(function is AbstractStochasticCachingDiffFunction))
            {
                throw new NotSupportedException();
            }
            AbstractStochasticCachingDiffFunction dfunction = (AbstractStochasticCachingDiffFunction)function;

            dfunction.method = StochasticCalculateMethods.GradientOnly;

            /* ---
            *  StochasticDiffFunctionTester sdft = new StochasticDiffFunctionTester(dfunction);
            *  ArrayMath.add(initial, gen.nextDouble() ); // to make sure that priors are working.
            *  sdft.testSumOfBatches(initial, 1e-4);
            *  System.exit(1);
            *  --- */
            // x aliases the caller's array: the Array.Copy into x and the NaN fill below write into `initial`.
            x               = initial;
            grad            = new double[x.Length];
            newX            = new double[x.Length];
            gradList        = new List <double[]>();
            // NOTE(review): presumably integer division, so a partial trailing batch is not counted — confirm.
            numBatches      = dfunction.DataDimension() / bSize;
            // Overwrites the outputFrequency field with ceil(numBatches / outputFrequency); the new value
            // is what the "k % outputFrequency" check in the loop uses.
            outputFrequency = (int)System.Math.Ceil(((double)numBatches) / ((double)outputFrequency));
            Init(dfunction);
            InitFiles();
            bool have_max = (maxIterations > 0 || numPasses > 0);

            if (!have_max)
            {
                throw new NotSupportedException("No maximum number of iterations has been specified.");
            }
            else
            {
                // The larger of the two limits, multiplied by the number of batches.
                maxIterations = System.Math.Max(maxIterations, numPasses) * numBatches;
            }
            Sayln("       Batchsize of: " + bSize);
            Sayln("       Data dimension of: " + dfunction.DataDimension());
            Sayln("       Batches per pass through data:  " + numBatches);
            Sayln("       Max iterations is = " + maxIterations);
            if (outputIterationsToFile)
            {
                infoFile.Println(function.DomainDimension() + "; DomainDimension ");
                infoFile.Println(bSize + "; batchSize ");
                infoFile.Println(maxIterations + "; maxIterations");
                infoFile.Println(numBatches + "; numBatches ");
                infoFile.Println(outputFrequency + "; outputFrequency");
            }
            //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            //            Loop
            //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            // `total` times the whole run; `current` is restarted every iteration to time a single step.
            Timing total   = new Timing();
            Timing current = new Timing();

            total.Start();
            current.Start();
            // k is not declared here, so the iteration counter lives outside this method.
            for (k = 0; k < maxIterations; k++)
            {
                try
                {
                    // Evaluate every evaluateIters iterations, but never on the very first one.
                    bool doEval = (k > 0 && evaluateIters > 0 && k % evaluateIters == 0);
                    if (doEval)
                    {
                        DoEvaluation(x);
                    }
                    int pass  = k / numBatches;
                    int batch = k % numBatches;
                    Say("Iter: " + k + " pass " + pass + " batch " + batch);
                    // restrict number of saved gradients
                    //  (recycle memory of first gradient in list for new gradient)
                    if (k > 0 && gradList.Count >= memory)
                    {
                        newGrad = gradList.Remove(0);
                    }
                    else
                    {
                        newGrad = new double[grad.Length];
                    }
                    dfunction.hasNewVals = true;
                    System.Array.Copy(dfunction.DerivativeAt(x, v, bSize), 0, newGrad, 0, newGrad.Length);
                    // Throws ArrayMath.InvalidElementException (handled below) presumably when an entry is NaN or infinite.
                    ArrayMath.AssertFinite(newGrad, "newGrad");
                    gradList.Add(newGrad);
                    grad = Smooth(gradList);
                    //Get the next X
                    TakeStep(dfunction);
                    ArrayMath.AssertFinite(newX, "newX");
                    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                    // THIS IS FOR DEBUG ONLY
                    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                    if (outputIterationsToFile && (k % outputFrequency == 0) && k != 0)
                    {
                        double curVal = dfunction.ValueAt(x);
                        Say(" TrueValue{ " + curVal + " } ");
                        file.Println(k + " , " + curVal + " , " + total.Report());
                    }
                    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                    // END OF DEBUG STUFF
                    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                    // NOTE(review): the loop condition already requires k < maxIterations, so this branch
                    // looks unreachable unless k is changed by a call made inside the loop body — confirm.
                    if (k >= maxIterations)
                    {
                        Sayln("Stochastic Optimization complete.  Stopped after max iterations");
                        x = newX;
                        break;
                    }
                    if (total.Report() >= maxTime)
                    {
                        Sayln("Stochastic Optimization complete.  Stopped after max time");
                        // From here on x refers to the newX array rather than to `initial`.
                        x = newX;
                        break;
                    }
                    // Accept the step by copying it into the working array.
                    System.Array.Copy(newX, 0, x, 0, x.Length);
                    Say("[" + (total.Report()) / 1000.0 + " s ");
                    Say("{" + (current.Restart() / 1000.0) + " s}] ");
                    Say(" " + dfunction.LastValue());
                    if (quiet)
                    {
                        log.Info(".");
                    }
                    else
                    {
                        Sayln(string.Empty);
                    }
                }
                catch (ArrayMath.InvalidElementException e)
                {
                    // A finiteness check failed: log it, mark the whole result as NaN and stop iterating.
                    log.Info(e.ToString());
                    for (int i = 0; i < x.Length; i++)
                    {
                        x[i] = double.NaN;
                    }
                    break;
                }
            }
            if (evaluateIters > 0)
            {
                // do final evaluation
                DoEvaluation(x);
            }
            if (outputIterationsToFile)
            {
                infoFile.Println(k + "; Iterations");
                infoFile.Println((total.Report()) / 1000.0 + "; Completion Time");
                infoFile.Println(dfunction.ValueAt(x) + "; Finalvalue");
                // NOTE(review): the files are closed only on this path; an exception other than
                // InvalidElementException thrown in the loop would skip these Close() calls.
                infoFile.Close();
                file.Close();
                log.Info("Output Files Closed");
            }
            //System.exit(1);
            Say("Completed in: " + (total.Report()) / 1000.0 + " s");
            return(x);
        }