예제 #1
0
        /// <summary>
        /// Evaluates the L2-regularized objective on one sample:
        /// <c>function.ValueAt(w, wscale, sample) + 0.5 * sample.Length * lambda * GetNorm(w) * wscale * wscale</c>.
        /// </summary>
        /// <param name="function">Function whose (unregularized) value is computed on the sample.</param>
        /// <param name="w">Weight vector, passed together with <paramref name="wscale"/> to the callees
        /// (presumably the effective weights are <c>w * wscale</c> — confirm against <c>ValueAt</c>).</param>
        /// <param name="wscale">Scale factor for <paramref name="w"/>; the norm term is multiplied by its square.</param>
        /// <param name="sample">Sample handed to <c>ValueAt</c> (presumably data-point indices); its length scales the penalty.</param>
        /// <returns>The sample's function value plus the L2 penalty term.</returns>
        public virtual double GetObjective(AbstractStochasticCachingDiffUpdateFunction function, double[] w, double wscale, int[] sample)
        {
            // The norm is computed first, exactly as before, so any side effects keep their order.
            double scaledNorm = GetNorm(w) * wscale * wscale;
            double dataTerm   = function.ValueAt(w, wscale, sample);
            double penalty    = 0.5 * sample.Length * lambda * scaledNorm;
            return dataTerm + penalty;
        }
        /// <summary>
        /// Minimizes <paramref name="f"/> by stochastic gradient descent.
        /// </summary>
        /// <remarks>
        /// Each step uses per-coordinate learning rates from <c>ComputeLearningRate</c> and is followed by a
        /// FOBOS-style proximal step chosen by <c>prior</c>:
        /// <list type="bullet">
        /// <item><c>None</c> / <c>Gaussian</c>: plain gradient step (a Gaussian prior is handled in the objective).</item>
        /// <item><c>Lasso</c>: per-coordinate soft thresholding by <c>rate * lambda</c>.</item>
        /// <item><c>Ridge</c>: every regularized coordinate is scaled by <c>pospart(1 - rate * lambda / ||v||)</c>,
        /// where <c>v</c> is the vector of tentative updates over the regularized coordinates.</item>
        /// <item>Any other prior (<c>gLASSO</c>, <c>aeLASSO</c>, <c>sgLASSO</c>): group-wise shrinkage over the
        /// function's feature groups.</item>
        /// </list>
        /// The run lasts <c>numPasses</c> passes of <c>numBatches</c> batches each. <c>numPasses &lt;= 0</c> means
        /// there is no pass limit. It stops early on <paramref name="maxIterations"/>, <c>maxTime</c>, non-finite
        /// weights, or the evaluation / average-improvement criteria when those are enabled.
        /// Side effects: may clamp the <c>bSize</c> field and re-initializes the optimizer state arrays.
        /// </remarks>
        /// <param name="f">Function to minimize. Gradients come from an <c>AbstractStochasticCachingDiffUpdateFunction</c>
        /// (mini-batch, or full batch when the batch covers all samples) or an <c>AbstractCachingDiffFunction</c>.</param>
        /// <param name="functionTolerance">Not read by this method; the average-improvement test uses <c>Tol</c>.</param>
        /// <param name="initial">Starting point; copied, never modified.</param>
        /// <param name="maxIterations">Upper bound on batch updates. The effective bound is
        /// <c>max(maxIterations, numPasses * numBatches)</c>. This or <c>numPasses</c> must be positive.</param>
        /// <returns>The final weights, or <c>xBest</c> when <c>useEvalImprovement</c> is set. If the weights become
        /// non-finite, <c>x</c> is filled with NaN.</returns>
        /// <exception cref="NotSupportedException">AdaDelta is combined with a prior other than None/Gaussian; a group
        /// prior is used with a function that has no feature grouping; no iteration limit is given; or
        /// <paramref name="f"/> provides no gradients.</exception>
        public virtual double[] Minimize(IDiffFunction f, double functionTolerance, double[] initial, int maxIterations)
        {
            int totalSamples = 0;

            Sayln("Using lambda=" + lambda);
            // Resolve the gradient source once; exactly one of these is used inside the batch loop.
            AbstractStochasticCachingDiffUpdateFunction stochasticFunc = f as AbstractStochasticCachingDiffUpdateFunction;
            AbstractCachingDiffFunction cachingFunc = stochasticFunc == null ? f as AbstractCachingDiffFunction : null;
            if (stochasticFunc != null)
            {
                stochasticFunc.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Shuffled;
                totalSamples = stochasticFunc.DataDimension();
                if (bSize > totalSamples)
                {
                    log.Info("WARNING: Total number of samples=" + totalSamples + " is smaller than requested batch size=" + bSize + "!!!");
                    bSize = totalSamples;
                    Sayln("Using batch size=" + bSize);
                }
                if (bSize <= 0)
                {
                    log.Info("WARNING: Requested batch size=" + bSize + " <= 0 !!!");
                    bSize = totalSamples;
                    Sayln("Using batch size=" + bSize);
                }
            }
            x             = new double[initial.Length];
            sumGradSquare = new double[initial.Length];
            prevGrad      = new double[initial.Length];
            prevDeltaX    = new double[initial.Length];

            // None/Gaussian take a plain gradient step, Lasso/Ridge a coordinate/global proximal step,
            // and every remaining prior a group-wise proximal step over featureGrouping.
            bool plainGradientStep = prior == SGDWithAdaGradAndFOBOS.Prior.None || prior == SGDWithAdaGradAndFOBOS.Prior.Gaussian;
            bool elementwisePrior  = prior == SGDWithAdaGradAndFOBOS.Prior.Lasso || prior == SGDWithAdaGradAndFOBOS.Prior.Ridge;
            bool groupPrior        = !plainGradientStep && !elementwisePrior;
            if (useAdaDelta)
            {
                sumDeltaXSquare = new double[initial.Length];
                if (!plainGradientStep)
                {
                    throw new NotSupportedException("useAdaDelta is currently only supported for Prior.NONE or Prior.GAUSSIAN");
                }
            }
            double[] testUpdateCache  = null;
            double[] currentRateCache = null;
            double[] bCache           = null;
            int[][]  featureGrouping  = null;
            // Ridge and the group priors stage tentative updates and rates before shrinking them.
            if (prior == SGDWithAdaGradAndFOBOS.Prior.Ridge || groupPrior)
            {
                testUpdateCache  = new double[initial.Length];
                currentRateCache = new double[initial.Length];
            }
            // Only the group priors read featureGrouping; None/Gaussian/Lasso/Ridge must not require it.
            if (groupPrior)
            {
                if (!(f is IHasFeatureGrouping))
                {
                    throw new NotSupportedException("prior is specified to be ae-lasso or g-lasso, but function does not support feature grouping");
                }
                featureGrouping = ((IHasFeatureGrouping)f).GetFeatureGrouping();
            }
            if (prior == SGDWithAdaGradAndFOBOS.Prior.sgLASSO)
            {
                bCache = new double[initial.Length];
            }
            System.Array.Copy(initial, 0, x, 0, x.Length);

            int numBatches = 1;
            if (stochasticFunc != null && totalSamples > 0)
            {
                // bSize is in [1, totalSamples] here, so numBatches >= 1; a trailing partial batch is not used.
                numBatches = totalSamples / bSize;
            }
            if (maxIterations <= 0 && numPasses <= 0)
            {
                throw new NotSupportedException("No maximum number of iterations has been specified.");
            }
            maxIterations = Math.Max(maxIterations, numPasses * numBatches);
            Sayln("       Batch size of: " + bSize);
            Sayln("       Data dimension of: " + totalSamples);
            Sayln("       Batches per pass through data:  " + numBatches);
            Sayln("       Number of passes is = " + numPasses);
            Sayln("       Max iterations is = " + maxIterations);

            Timing total = new Timing();
            total.Start();
            int           iters     = 0;
            IList<double> values    = null;   // full-batch objective history (only filled when bSize == totalSamples)
            double        oldObjVal = 0;
            List<int>     allParamIndices = null;   // built once when f exposes no regularizer range

            // numPasses <= 0 means "no pass limit": maxIterations (> 0 by the check above), maxTime,
            // or a convergence test then ends the loop. Each pass runs at least one batch.
            for (int pass = 0; numPasses <= 0 || pass < numPasses; pass++)
            {
                bool   doEval    = (pass > 0 && evaluateIters > 0 && pass % evaluateIters == 0);
                double evalScore = double.NegativeInfinity;
                if (doEval)
                {
                    evalScore = DoEvaluation(x);
                    if (useEvalImprovement && !ToContinue(x, evalScore))
                    {
                        break;
                    }
                }
                // TODO: objVal is only computed for full-batch runs and never includes the regularizer
                // of non-Gaussian priors, so it reflects the un-regularized objective for those.
                double objVal   = double.NegativeInfinity;
                double objDelta = double.NegativeInfinity;
                Say("Iter: " + iters + " pass " + pass + " batch 1 ... ");
                int numOfNonZero      = 0;
                int numOfNonZeroGroup = 0;
                for (int batch = 0; batch < numBatches; batch++)
                {
                    iters++;
                    double[] gradients = null;
                    if (stochasticFunc != null)
                    {
                        if (bSize == totalSamples)
                        {
                            objVal    = stochasticFunc.ValueAt(x);
                            gradients = stochasticFunc.GetDerivative();
                            objDelta  = objVal - oldObjVal;
                            oldObjVal = objVal;
                            if (values == null)
                            {
                                values = new List<double>();
                            }
                            values.Add(objVal);
                        }
                        else
                        {
                            stochasticFunc.CalculateStochasticGradient(x, bSize);
                            gradients = stochasticFunc.GetDerivative();
                        }
                    }
                    else if (cachingFunc != null)
                    {
                        gradients = cachingFunc.DerivativeAt(x);
                    }
                    if (gradients == null)
                    {
                        throw new NotSupportedException("Function type " + f.GetType().Name + " does not provide gradients for this minimizer");
                    }

                    if (plainGradientStep)
                    {
                        for (int index = 0; index < x.Length; index++)
                        {
                            double gValue      = gradients[index];
                            double currentRate = ComputeLearningRate(index, gValue);
                            // arrive at x(t+1/2); no proximal step for these priors
                            double testUpdate = x[index] - (currentRate * gValue);
                            UpdateX(x, index, testUpdate);
                        }
                    }
                    else if (elementwisePrior)
                    {
                        ICollection<int> paramRange;
                        if (f is IHasRegularizerParamRange)
                        {
                            paramRange = ((IHasRegularizerParamRange)f).GetRegularizerParamRange(x);
                        }
                        else
                        {
                            if (allParamIndices == null)
                            {
                                allParamIndices = new List<int>(x.Length);
                                for (int i = 0; i < x.Length; i++)
                                {
                                    allParamIndices.Add(i);
                                }
                            }
                            paramRange = allParamIndices;
                        }
                        double testUpdateSquaredSum = 0;
                        foreach (int index in paramRange)
                        {
                            double gValue      = gradients[index];
                            double currentRate = ComputeLearningRate(index, gValue);
                            // arrive at x(t+1/2)
                            double testUpdate = x[index] - (currentRate * gValue);
                            if (prior == SGDWithAdaGradAndFOBOS.Prior.Lasso)
                            {
                                // FOBOS step for L1: soft-threshold by currentRate * lambda
                                double realUpdate = Math.Signum(testUpdate) * Pospart(Math.Abs(testUpdate) - currentRate * lambda);
                                UpdateX(x, index, realUpdate);
                                if (realUpdate != 0)
                                {
                                    numOfNonZero++;
                                }
                            }
                            else
                            {
                                // Ridge needs the norm over all tentative updates before shrinking any of them.
                                testUpdateSquaredSum   += testUpdate * testUpdate;
                                testUpdateCache[index]  = testUpdate;
                                currentRateCache[index] = currentRate;
                            }
                        }
                        if (prior == SGDWithAdaGradAndFOBOS.Prior.Ridge)
                        {
                            double testUpdateNorm = Math.Sqrt(testUpdateSquaredSum);
                            // Only coordinates staged above are shrunk; the others keep their current value
                            // (iterating the whole cache would overwrite them with stale zeros).
                            foreach (int index in paramRange)
                            {
                                double realUpdate = testUpdateCache[index] * Pospart(1 - currentRateCache[index] * lambda / testUpdateNorm);
                                UpdateX(x, index, realUpdate);
                                if (realUpdate != 0)
                                {
                                    numOfNonZero++;
                                }
                            }
                        }
                    }
                    else
                    {
                        foreach (int[] gFeatureIndices in featureGrouping)
                        {
                            double testUpdateSquaredSum = 0;
                            double testUpdateAbsSum     = 0;
                            double M  = gFeatureIndices.Length;
                            double dm = Math.Log(M);
                            foreach (int index in gFeatureIndices)
                            {
                                double gValue      = gradients[index];
                                double currentRate = ComputeLearningRate(index, gValue);
                                // arrive at x(t+1/2)
                                double testUpdate = x[index] - (currentRate * gValue);
                                testUpdateSquaredSum   += testUpdate * testUpdate;
                                testUpdateAbsSum       += Math.Abs(testUpdate);
                                testUpdateCache[index]  = testUpdate;
                                currentRateCache[index] = currentRate;
                            }
                            bool groupHasNonZero = false;
                            if (prior == SGDWithAdaGradAndFOBOS.Prior.gLASSO)
                            {
                                double testUpdateNorm = Math.Sqrt(testUpdateSquaredSum);
                                foreach (int index in gFeatureIndices)
                                {
                                    double realUpdate = testUpdateCache[index] * Pospart(1 - currentRateCache[index] * lambda * dm / testUpdateNorm);
                                    UpdateX(x, index, realUpdate);
                                    if (realUpdate != 0)
                                    {
                                        numOfNonZero++;
                                        groupHasNonZero = true;
                                    }
                                }
                            }
                            else if (prior == SGDWithAdaGradAndFOBOS.Prior.aeLASSO)
                            {
                                foreach (int index in gFeatureIndices)
                                {
                                    double tau = currentRateCache[index] * lambda / (1 + currentRateCache[index] * lambda * M) * testUpdateAbsSum;
                                    double realUpdate = Math.Signum(testUpdateCache[index]) * Pospart(Math.Abs(testUpdateCache[index]) - tau);
                                    UpdateX(x, index, realUpdate);
                                    if (realUpdate != 0)
                                    {
                                        numOfNonZero++;
                                        groupHasNonZero = true;
                                    }
                                }
                            }
                            else if (prior == SGDWithAdaGradAndFOBOS.Prior.sgLASSO)
                            {
                                // Sparse group lasso: L1 soft-threshold (weight alpha), then group shrink (weight 1 - alpha).
                                double bSquaredSum = 0;
                                foreach (int index in gFeatureIndices)
                                {
                                    double b = Math.Signum(testUpdateCache[index]) * Pospart(Math.Abs(testUpdateCache[index]) - currentRateCache[index] * alpha * lambda);
                                    bCache[index] = b;
                                    bSquaredSum  += b * b;
                                }
                                double bNorm = Math.Sqrt(bSquaredSum);
                                foreach (int index in gFeatureIndices)
                                {
                                    double realUpdate = bCache[index] * Pospart(1 - currentRateCache[index] * (1.0 - alpha) * lambda * dm / bNorm);
                                    UpdateX(x, index, realUpdate);
                                    if (realUpdate != 0)
                                    {
                                        numOfNonZero++;
                                        groupHasNonZero = true;
                                    }
                                }
                            }
                            if (groupHasNonZero)
                            {
                                numOfNonZeroGroup++;
                            }
                        }
                    }
                    // Remember this batch's gradient for the learning-rate bookkeeping.
                    System.Array.Copy(gradients, 0, prevGrad, 0, x.Length);
                }
                try
                {
                    ArrayMath.AssertFinite(x, "x");
                }
                catch (ArrayMath.InvalidElementException e)
                {
                    log.Info(e.ToString());
                    for (int i = 0; i < x.Length; i++)
                    {
                        x[i] = double.NaN;
                    }
                    break;
                }
                string groupCountText = (prior != SGDWithAdaGradAndFOBOS.Prior.Lasso && prior != SGDWithAdaGradAndFOBOS.Prior.Ridge) ? ", n0-gCount:" + numOfNonZeroGroup : string.Empty;
                string evalText       = (evalScore != double.NegativeInfinity) ? ", evalScore:" + evalScore : string.Empty;
                string objText        = objVal != double.NegativeInfinity ? ", obj_val:" + nf.Format(objVal) + ", obj_delta:" + objDelta : string.Empty;
                Sayln(numBatches.ToString() + ", n0-fCount:" + numOfNonZero + groupCountText + evalText + objText);
                if (values != null && useAvgImprovement && iters > 5)
                {
                    // Relative average improvement over (up to) the last 10 recorded objective values.
                    int    size               = values.Count;
                    double previousVal        = (size >= 10 ? values[size - 10] : values[0]);
                    double averageImprovement = (previousVal - objVal) / (size >= 10 ? 10 : size);
                    if (System.Math.Abs(averageImprovement / objVal) < Tol)
                    {
                        Sayln("Online Optmization completed, due to average improvement: | newest_val - previous_val | / |newestVal| < TOL ");
                        break;
                    }
                }
                if (iters >= maxIterations)
                {
                    Sayln("Online Optimization complete.  Stopped after max iterations");
                    break;
                }
                if (total.Report() >= maxTime)
                {
                    Sayln("Online Optimization complete.  Stopped after max time");
                    break;
                }
            }
            if (evaluateIters > 0)
            {
                // do final evaluation
                double evalScore = (useEvalImprovement ? DoEvaluation(xBest) : DoEvaluation(x));
                Sayln("final evalScore is: " + evalScore);
            }
            Sayln("Completed in: " + Timing.ToSecondsString(total.Report()) + " s");
            return(useEvalImprovement ? xBest : x);
        }