Example #1
0
        /// <summary>
        /// Given the top Blob error gradients, compute the bottom Blob error gradients.
        /// </summary>
        /// <remarks>
        /// The Backward function calls the overridden backward function implemented by each specific Layer derivative
        /// to compute the bottom (input) Blob diffs given the top (output) Blob diffs.
        /// 
        /// The time spent in the backward call is stored in m_dfBackwardTiming (milliseconds) and folded into
        /// m_dfBackwardAverageTiming via getAveTiming.  When an OnDebug handler is attached, the handler is asked
        /// for a work Blob, and each bottom Blob's data and diff are scanned with minmax_data/minmax_diff; if the
        /// scan reports NaN or infinite values (non-zero Item3/Item4), an exception is thrown.
        /// </remarks>
        /// <param name="colTop">Specifies a collection of top (output) Blobs, whose diff fields store the gradient of the
        /// error with respect to themselves.</param>
        /// <param name="rgbPropagateDown">Specifies a List with equal length to the bottom, with each element
        /// indicating whether or not to propagate the error gradients down to the bottom Blob at the corresponding
        /// index.</param>
        /// <param name="colBottom">Specifies a collection of bottom (input) Blobs, whose diff fields are filled with
        /// the gradient of the error with respect to themselves after the Backward function is run.</param>
        /// <exception cref="Exception">When the layer parameter is available, any error is re-thrown wrapped in an
        /// Exception whose message is prefixed with the layer name and type (the original is kept as the
        /// InnerException); otherwise the original exception is re-thrown with its stack trace intact.</exception>
        public void Backward(BlobCollection <T> colTop, List <bool> rgbPropagateDown, BlobCollection <T> colBottom)
        {
            try
            {
                m_swTiming.Restart();
                backward(colTop, rgbPropagateDown, colBottom);
                m_swTiming.Stop();
                m_dfBackwardTiming        = m_swTiming.Elapsed.TotalMilliseconds;
                m_dfBackwardAverageTiming = getAveTiming(m_dfAverageInterval, m_dfBackwardTiming, m_dfBackwardAverageTiming);

                if (OnDebug != null)
                {
                    // The debug handler supplies the work Blob used as scratch space by the min/max scans below.
                    GetWorkBlobArgs <T> args = new GetWorkBlobArgs <T>();
                    OnDebug(this, args);

                    foreach (Blob <T> b in colBottom)
                    {
                        Tuple <double, double, double, double> mm_data = b.minmax_data(args.Blob, true);
                        Tuple <double, double, double, double> mm_diff = b.minmax_diff(args.Blob, true);

                        // Item3/Item4 are treated as NaN/INF indicators (per the messages below).
                        if (mm_data.Item3 > 0 || mm_data.Item4 > 0)
                        {
                            throw new Exception("NAN or INF detected in the BOTTOM '" + b.Name + "' Data for layer '" + m_param.name + "' on the backward pass.");
                        }

                        if (mm_diff.Item3 > 0 || mm_diff.Item4 > 0)
                        {
                            throw new Exception("NAN or INF detected in the BOTTOM '" + b.Name + "' Diff for layer '" + m_param.name + "' on the backward pass.");
                        }
                    }
                }
            }
            catch (Exception excpt)
            {
                if (m_param != null)
                {
                    throw new Exception("Layer: '" + m_param.name + "' (" + m_param.type.ToString() + ") Error: " + excpt.Message, excpt);
                }
                else
                {
                    // Use 'throw;' (not 'throw excpt;') so the original stack trace is preserved.
                    throw;
                }
            }
        }
Example #2
0
        /// <summary>
        /// Given the bottom (input) Blobs, this function computes the top (output) Blobs and the loss.
        /// </summary>
        /// <remarks>
        /// The Forward function first calls Reshape, then calls the overridden forward function implemented by each
        /// specific Layer derivative to compute the top (output) Blob's values given the bottom (input) Blobs.  If the
        /// layer has any non-zero <code>loss_weights</code>, this function then computes and returns the loss.  For each
        /// top Blob whose loss(i) is non-zero, the loss contribution is the dot product of that Blob's data and diff
        /// (presumably the diff holds the per-element loss weights for loss tops; confirm with the layer setup code).
        /// 
        /// The time spent (including the loss computation) is stored in m_dfForwardTiming (milliseconds) and folded
        /// into m_dfForwardAverageTiming via getAveTiming.  When an OnDebug handler is attached, each top Blob's data
        /// and diff are scanned with minmax_data/minmax_diff, and an exception is thrown if NaN or infinite values are
        /// reported (non-zero Item3/Item4).
        /// </remarks>
        /// <param name="colBottom">Specifies the collection of bottom (input) Blobs, whose data fields
        /// store the input data for this layer's outputs.</param>
        /// <param name="colTop">Specifies the collection of preshaped top (output) Blobs, whose data fields
        /// will store this layer's outputs.</param>
        /// <returns>Returns the total loss from the Layer.</returns>
        /// <exception cref="Exception">When the layer parameter is available, any error is re-thrown wrapped in an
        /// Exception whose message is prefixed with the layer name and type (the original is kept as the
        /// InnerException); otherwise the original exception is re-thrown with its stack trace intact.</exception>
        public double Forward(BlobCollection <T> colBottom, BlobCollection <T> colTop)
        {
            try
            {
                m_swTiming.Restart();
                double dfLoss = 0;

                Reshape(colBottom, colTop);
                forward(colBottom, colTop);

                for (int i = 0; i < colTop.Count; i++)
                {
                    // Tops with a zero loss weight do not contribute to the loss.
                    if (loss(i) == 0)
                    {
                        continue;
                    }

                    int    nCount     = colTop[i].count();
                    long   hData      = colTop[i].gpu_data;
                    long   hDiff      = colTop[i].gpu_diff;
                    double dfBlobLoss = m_cuda.dot_double(nCount, hData, hDiff);

                    dfLoss += dfBlobLoss;
                }

                m_swTiming.Stop();
                m_dfForwardTiming        = m_swTiming.Elapsed.TotalMilliseconds;
                m_dfForwardAverageTiming = getAveTiming(m_dfAverageInterval, m_dfForwardTiming, m_dfForwardAverageTiming);

                if (OnDebug != null)
                {
                    // The debug handler supplies the work Blob used as scratch space by the min/max scans below.
                    GetWorkBlobArgs <T> args = new GetWorkBlobArgs <T>();
                    OnDebug(this, args);

                    foreach (Blob <T> b in colTop)
                    {
                        Tuple <double, double, double, double> mm_data = b.minmax_data(args.Blob, true);
                        Tuple <double, double, double, double> mm_diff = b.minmax_diff(args.Blob, true);

                        // Item3/Item4 are treated as NaN/INF indicators (per the messages below).
                        if (mm_data.Item3 > 0 || mm_data.Item4 > 0)
                        {
                            throw new Exception("NAN or INF detected in the TOP '" + b.Name + "' Data for layer '" + m_param.name + "' on the forward pass.");
                        }

                        if (mm_diff.Item3 > 0 || mm_diff.Item4 > 0)
                        {
                            throw new Exception("NAN or INF detected in TOP '" + b.Name + "' Diff for layer '" + m_param.name + "' on the forward pass.");
                        }
                    }
                }

                return(dfLoss);
            }
            catch (Exception excpt)
            {
                if (m_param != null)
                {
                    throw new Exception("Layer: '" + m_param.name + "' (" + m_param.type.ToString() + ") Error: " + excpt.Message, excpt);
                }
                else
                {
                    // Use 'throw;' (not 'throw excpt;') so the original stack trace is preserved.
                    throw;
                }
            }
        }