Beispiel #1
0
        /// <summary>
        /// Finds an approximate minimum of <paramref name="function"/> by repeated line searches,
        /// starting from <paramref name="initial"/>.
        /// </summary>
        /// <remarks>
        /// Each iteration runs <c>LineSearch.Minimize</c> over <c>lineFunc</c> and reads the resulting
        /// <c>lineFunc.NewPoint</c>. The loop stops when either of these holds:
        /// <list type="bullet">
        /// <item><c>TerminateTester.ShouldTerminate</c> reports convergence against the previous point.
        /// This check is skipped on the first iteration.</item>
        /// <item><c>Terminate</c> accepts the new point.</item>
        /// </list>
        /// Otherwise the new point becomes the previous point and <c>lineFunc.ChangeDir()</c> is called
        /// before the next search. When <c>_maxSteps</c> is 0 there is no iteration limit.
        /// </remarks>
        /// <param name="function">Function to minimize</param>
        /// <param name="initial">Initial point; must contain only finite values (checked on entry)</param>
        /// <param name="result">Receives a copy of the last point produced by the line search (the approximate minimum)</param>
        public void Minimize(DifferentiableFunction function, ref VBuffer<Float> initial, ref VBuffer<Float> result)
        {
            Contracts.Check(FloatUtils.IsFinite(initial.Values, initial.Count), "The initial vector contains NaNs or infinite values.");
            LineFunc lineFunc = new LineFunc(function, ref initial, UseCG);

            // Keep an independent copy of the previous iterate for the convergence test.
            VBuffer<Float> prev = default(VBuffer<Float>);
            initial.CopyTo(ref prev);

            for (int n = 0; _maxSteps == 0 || n < _maxSteps; ++n)
            {
                // The returned step length is not needed here. The search is run for its effect on
                // lineFunc: it presumably moves lineFunc.NewPoint via lineFunc.Eval, which is read next.
                LineSearch.Minimize(lineFunc.Eval, lineFunc.Value, lineFunc.Deriv);
                var newPoint = lineFunc.NewPoint;

                // On the first iteration prev is still the initial point, so only the Terminate
                // callback is consulted. Short-circuiting means Terminate is skipped when the
                // tester already reports convergence.
                if ((n > 0 && TerminateTester.ShouldTerminate(ref newPoint, ref prev)) || Terminate(ref newPoint))
                {
                    break;
                }

                newPoint.CopyTo(ref prev);
                lineFunc.ChangeDir();
            }

            lineFunc.NewPoint.CopyTo(ref result);
        }
Beispiel #2
0
        /// <summary>
        /// Minimizes the function represented by <paramref name="f"/> using stochastic gradient descent.
        /// </summary>
        /// <remarks>
        /// Each iteration proceeds as follows:
        /// <list type="number">
        /// <item>The accumulated step is reset when <c>_momentum</c> is 0; otherwise it is scaled by
        /// <c>_momentum</c>.</item>
        /// <item><c>_batchSize</c> gradient samples from <paramref name="f"/> are added to the step, each
        /// weighted by <c>(1 - _momentum) / _batchSize</c>.</item>
        /// <item>The iterate moves against the step, scaled by a learning rate chosen by
        /// <c>_rateSchedule</c>:
        /// <c>1 / _t0</c> (Constant), <c>1 / (_t0 + sqrt(n))</c> (Sqrt), or <c>1 / (_t0 + n)</c> (Linear).</item>
        /// </list>
        /// When <c>_averaging</c> is set, the running mean of the iterates is tracked and returned
        /// instead of the last iterate. Iteration stops in any of these cases:
        /// <list type="bullet">
        /// <item><c>TerminateTester.ShouldTerminate</c> reports convergence (checked from the second
        /// iteration on).</item>
        /// <item><c>_terminate</c> accepts the current vector.</item>
        /// <item><c>_maxSteps</c> iterations have run. A value of 0 means no limit.</item>
        /// </list>
        /// </remarks>
        /// <param name="f">Stochastic gradients of function to minimize</param>
        /// <param name="initial">Initial point; must contain only finite values (checked on entry)</param>
        /// <param name="result">Approximate minimum of <paramref name="f"/></param>
        public void Minimize(DStochasticGradient f, ref VBuffer<Float> initial, ref VBuffer<Float> result)
        {
            Contracts.Check(FloatUtils.IsFinite(initial.Values, initial.Count), "The initial vector contains NaNs or infinite values.");
            int dim = initial.Length;

            VBuffer<Float> grad = VBufferUtils.CreateEmpty<Float>(dim);
            VBuffer<Float> step = VBufferUtils.CreateEmpty<Float>(dim);
            VBuffer<Float> x = default(VBuffer<Float>);
            initial.CopyTo(ref x);
            VBuffer<Float> prev = default(VBuffer<Float>);
            VBuffer<Float> avg = VBufferUtils.CreateEmpty<Float>(dim);

            // Weight of each gradient sample in the step. It does not change between iterations,
            // so compute it once (float division: never throws, even if _batchSize is 0).
            Float scale = (1 - _momentum) / _batchSize;

            for (int n = 0; _maxSteps == 0 || n < _maxSteps; ++n)
            {
                if (_momentum == 0)
                {
                    // No momentum: reset the step to zero explicit entries, reusing its arrays.
                    step = new VBuffer<Float>(step.Length, 0, step.Values, step.Indices);
                }
                else
                {
                    VectorUtils.ScaleBy(ref step, _momentum);
                }

                Float stepSize;
                switch (_rateSchedule)
                {
                case RateScheduleType.Constant:
                    stepSize = 1 / _t0;
                    break;

                case RateScheduleType.Sqrt:
                    stepSize = 1 / (_t0 + MathUtils.Sqrt(n));
                    break;

                case RateScheduleType.Linear:
                    stepSize = 1 / (_t0 + n);
                    break;

                default:
                    throw Contracts.Except("Unsupported rate schedule type: " + _rateSchedule);
                }

                for (int i = 0; i < _batchSize; ++i)
                {
                    f(ref x, ref grad);
                    VectorUtils.AddMult(ref grad, scale, ref step);
                }

                if (_averaging)
                {
                    // After the swap, prev holds the previous average and avg's buffer is reused as scratch.
                    Utils.Swap(ref avg, ref prev);
                    // The n/(n+1) and 1/(n+1) weights keep avg equal to the running mean of the iterates.
                    VectorUtils.ScaleBy(prev, ref avg, (Float)n / (n + 1));
                    VectorUtils.AddMult(ref step, -stepSize, ref x);
                    VectorUtils.AddMult(ref x, (Float)1 / (n + 1), ref avg);

                    if ((n > 0 && TerminateTester.ShouldTerminate(ref avg, ref prev)) || _terminate(ref avg))
                    {
                        result = avg;
                        return;
                    }
                }
                else
                {
                    // After the swap, prev holds the current iterate and x receives the updated one.
                    Utils.Swap(ref x, ref prev);
                    VectorUtils.AddMult(ref step, -stepSize, ref prev, ref x);
                    if ((n > 0 && TerminateTester.ShouldTerminate(ref x, ref prev)) || _terminate(ref x))
                    {
                        result = x;
                        return;
                    }
                }
            }

            result = _averaging ? avg : x;
        }