コード例 #1
0
ファイル: LConvolution.cs プロジェクト: wenzongmin/testcnn
        //protected LMatrix _forward(LMatrix x, WehtInfo info, int kerDepIndex)
        //{
        //    int kernelRow = info.W.Length;
        //    int kernelColumn = info.W[0].Length;
        //    int row = (x.Row + 2 * this.Padding - kernelRow) / this.Stride + 1;
        //    int column = (x.Column + 2 * this.Padding - kernelColumn) / this.Stride + 1;
        //    if (this._col[kerDepIndex] == null)
        //        this._col[kerDepIndex] = im2col(x.Matrix, row, column, kernelRow, kernelColumn, Stride);
        //    LMatrix tW = info.W.Flatten();
        //    LMatrix _out = _col[kerDepIndex] * tW.T;
        //    LMatrix res = _out.ReShape(row, column);
        //    return res;
        //}

        /// <summary>
        /// Computes the forward output for one kernel depth slice as
        /// im2col(<paramref name="mat"/>) multiplied by the transposed, flattened kernel weights.
        /// </summary>
        /// <param name="mat">Input slice expanded by im2col when no cached expansion exists.</param>
        /// <param name="info">Weight holder; <c>info.W</c> is flattened and used as the right-hand operand.</param>
        /// <param name="kerDepIndex">Slot in the per-depth im2col cache <c>_col</c>.</param>
        /// <param name="resRow">Row count of the result (also passed to im2col).</param>
        /// <param name="resCol">Column count of the result (also passed to im2col).</param>
        /// <returns>The matrix product reshaped to <paramref name="resRow"/> x <paramref name="resCol"/>.</returns>
        protected double[,] _forward(double[,] mat, WehtInfo info, int kerDepIndex, int resRow, int resCol)
        {
            var columns = this._col[kerDepIndex];

            // The im2col expansion is built only when this depth slot is empty, then reused.
            // NOTE(review): the cache is not keyed on `mat`; presumably the caller clears `_col`
            // before each new input — confirm, otherwise stale columns would be multiplied here.
            if (columns == null)
            {
                columns = im2col(mat, resRow, resCol, kernelRow, kernelColumn, Stride);
                this._col[kerDepIndex] = columns;
            }

            LMatrix flatWeights = info.W.Flatten();
            LMatrix product     = columns * flatWeights.T;

            return product.ReShape(resRow, resCol);
        }
コード例 #2
0
ファイル: LConvolution.cs プロジェクト: wenzongmin/testcnn
        //protected LMatrix _backward(LMatrix dout, WehtInfo info, int kerDepIndex)
        //{
        //    int kernelRow = info.W.Length;
        //    int kernelColumn = info.W[0].Length;
        //    LMatrix dx = dout.Flatten;
        //    LMatrix _dw = dx * this._col[kerDepIndex];
        //    if (info.DW == null)
        //        info.DW = _dw.ReShape(kernelRow, kernelColumn).Matrix;
        //    else
        //        info.DW = info.DW.Plus(_dw.ReShape(kernelRow, kernelColumn).Matrix);
        //    LMatrix tW = info.W.Flatten();
        //    LMatrix _out = dx.T * tW;
        //    return col2im(_out.Matrix, Input[kerDepIndex].Length, Input[kerDepIndex][0].Length,
        //        kernelRow, kernelColumn, Stride);
        //}

        /// <summary>
        /// Back-propagates the upstream gradient through one kernel depth slice.
        /// Records the weight gradient in <c>info.DW</c> and returns the gradient mapped back
        /// to input shape via col2im.
        /// </summary>
        /// <param name="dout">Upstream gradient for this slice; it is flattened before use.</param>
        /// <param name="info">Weight holder; <c>info.W</c> is read and <c>info.DW</c> is set or accumulated.</param>
        /// <param name="kerDepIndex">Slot in the per-depth im2col cache <c>_col</c> filled by the forward pass.</param>
        /// <returns>col2im of (flattened dout)^T * flatten(W), sized <c>_input_dim1</c> x <c>_input_dim2</c>.</returns>
        protected LMatrix _backward(LMatrix dout, WehtInfo info, int kerDepIndex)
        {
            // Kernel geometry is taken from the weight matrix of this slice.
            int kRows = info.W.GetLength(0);
            int kCols = info.W.GetLength(1);

            LMatrix gradFlat  = dout.Flatten();
            LMatrix wGradFlat = gradFlat * this._col[kerDepIndex];

            if (info.DW != null)
            {
                // NOTE(review): the value returned by Plus is discarded; this accumulates only if
                // Plus updates info.DW in place — confirm against its definition.
                info.DW.Plus(wGradFlat.ReShape(kRows, kCols), kRows, kCols);
            }
            else
            {
                info.DW = wGradFlat.ReShape(kRows, kCols);
            }

            LMatrix flatWeights = info.W.Flatten();
            LMatrix colGrad     = gradFlat.T * flatWeights;

            return col2im(colGrad.Matrix, _input_dim1, _input_dim2, kRows, kCols, Stride);
        }