示例#1
0
        /// <summary>
        /// Max-style forward pass over a 4-D batch.
        /// For every sample and every depth slice, it does three things:
        /// expands the feature map with <c>im2col</c>,
        /// records <c>MaxIndex()</c> of the expansion in <c>SingleMaxIndex</c>,
        /// and writes <c>Max(1)</c> of the expansion, reshaped to the output size, into the result.
        /// </summary>
        /// <param name="input">Batch indexed as [sample, depth, row, column].</param>
        /// <returns>
        /// Array of shape [sampleCount, depth, row, column].
        /// Here row = (inputRow - PadRow) / Stride + 1 and column = (inputColumn - PadColumn) / Stride + 1.
        /// </returns>
        public double[,,,] forward(double[,,,] input)
        {
            // Dimensions of the batch: [sample, depth, row (height), column (width)].
            int sampleCount       = input.GetLength(0);
            int sampleSingleDepth = input.GetLength(1);
            int inputRow          = input.GetLength(2);
            int inputColumn       = input.GetLength(3);

            // Per sample and per depth slice: the max indices reported by the im2col expansion.
            SingleMaxIndex = new int[sampleCount][][];

            // PadRow/PadColumn are passed to im2col in the window-size position; no border padding is applied.
            int row    = (inputRow - PadRow) / Stride + 1;
            int column = (inputColumn - PadColumn) / Stride + 1;
            var result = new double[sampleCount, sampleSingleDepth, row, column];

            // The feature-map shape is the same for every sample, so record it once.
            // Setting it before the loop also sets it for an empty batch.
            InputShape = new MatrixShape(inputRow, inputColumn);

            for (int sampleIndex = 0; sampleIndex < sampleCount; sampleIndex++)
            {
                SingleMaxIndex[sampleIndex] = new int[sampleSingleDepth][];

                for (int depth = 0; depth < sampleSingleDepth; depth++)
                {
                    // im2col expansion of this sample's depth slice.
                    LMatrix cols = im2col(input.GetNextDimVal(sampleIndex, depth, inputRow, inputColumn), row, column, PadRow, PadColumn, Stride);

                    // NOTE(review): presumably kept so a backward pass can route gradients to the max positions — confirm.
                    SingleMaxIndex[sampleIndex][depth] = cols.MaxIndex();

                    // Max along axis 1 of the expansion, reshaped to the output map size.
                    LMatrix maxima = cols.Max(1);
                    result.SetDimVal(maxima.ReShape(row, column), sampleIndex, depth, row, column);
                }
            }

            return result;
        }
示例#2
0
文件: LBN.cs 项目: wenzongmin/testcnn
        /// <summary>
        /// Stores the input's row and column counts in <c>Row</c> and <c>Column</c>.
        /// Then runs <c>_forward</c> and reshapes its output back to the input's dimensions.
        /// </summary>
        /// <param name="x">Input matrix.</param>
        /// <param name="train_flg">
        /// Forwarded unchanged to <c>_forward</c>.
        /// NOTE(review): presumably selects training vs. inference behaviour — confirm.
        /// </param>
        /// <returns>The result of <c>_forward</c>, reshaped to <c>x.Row</c> x <c>x.Column</c>.</returns>
        public override LMatrix forward(LMatrix x, bool train_flg)
        {
            // Remember the incoming shape on the layer.
            Row    = x.Row;
            Column = x.Column;

            // x.Row / x.Column are read again after _forward runs, exactly as before.
            return _forward(x, train_flg).ReShape(x.Row, x.Column);
        }
示例#3
0
        //protected LMatrix _forward(LMatrix x, WehtInfo info, int kerDepIndex)
        //{
        //    int kernelRow = info.W.Length;
        //    int kernelColumn = info.W[0].Length;
        //    int row = (x.Row + 2 * this.Padding - kernelRow) / this.Stride + 1;
        //    int column = (x.Column + 2 * this.Padding - kernelColumn) / this.Stride + 1;
        //    if (this._col[kerDepIndex] == null)
        //        this._col[kerDepIndex] = im2col(x.Matrix, row, column, kernelRow, kernelColumn, Stride);
        //    LMatrix tW = info.W.Flatten();
        //    LMatrix _out = _col[kerDepIndex] * tW.T;
        //    LMatrix res = _out.ReShape(row, column);
        //    return res;
        //}

        /// <summary>
        /// Computes one kernel-depth slot's output.
        /// Multiplies the im2col expansion of <paramref name="mat"/> by the transposed, flattened weights
        /// <c>info.W</c>, then reshapes the product.
        /// </summary>
        /// <param name="mat">Input map to expand with <c>im2col</c>.</param>
        /// <param name="info">Holds the weights <c>W</c>.</param>
        /// <param name="kerDepIndex">Slot in <c>_col</c> used to cache the im2col expansion.</param>
        /// <param name="resRow">Output row count (also passed to <c>im2col</c>).</param>
        /// <param name="resCol">Output column count (also passed to <c>im2col</c>).</param>
        /// <returns>The product reshaped to <paramref name="resRow"/> x <paramref name="resCol"/>.</returns>
        protected double[,] _forward(double[,] mat, WehtInfo info, int kerDepIndex, int resRow, int resCol)
        {
            // The expansion is built only while the slot is empty and is reused on later calls.
            // NOTE(review): nothing here clears the slot, so a later call with a different `mat`
            // reuses the old expansion unless the caller resets _col — confirm that is intended.
            if (_col[kerDepIndex] == null)
            {
                _col[kerDepIndex] = im2col(mat, resRow, resCol, kernelRow, kernelColumn, Stride);
            }

            LMatrix flatWeights = info.W.Flatten();
            var     expanded    = _col[kerDepIndex];
            LMatrix product     = expanded * flatWeights.T;

            return product.ReShape(resRow, resCol);
        }
示例#4
0
        //protected LMatrix _backward(LMatrix dout, WehtInfo info, int kerDepIndex)
        //{
        //    int kernelRow = info.W.Length;
        //    int kernelColumn = info.W[0].Length;
        //    LMatrix dx = dout.Flatten;
        //    LMatrix _dw = dx * this._col[kerDepIndex];
        //    if (info.DW == null)
        //        info.DW = _dw.ReShape(kernelRow, kernelColumn).Matrix;
        //    else
        //        info.DW = info.DW.Plus(_dw.ReShape(kernelRow, kernelColumn).Matrix);
        //    LMatrix tW = info.W.Flatten();
        //    LMatrix _out = dx.T * tW;
        //    return col2im(_out.Matrix, Input[kerDepIndex].Length, Input[kerDepIndex][0].Length,
        //        kernelRow, kernelColumn, Stride);
        //}

        /// <summary>
        /// Backward step for one kernel-depth slot.
        /// Computes the weight gradient into <c>info.DW</c> and returns the input gradient
        /// rebuilt with <c>col2im</c>.
        /// </summary>
        /// <param name="dout">Upstream gradient for this slot.</param>
        /// <param name="info">Holds the weights <c>W</c> and the accumulated weight gradient <c>DW</c>.</param>
        /// <param name="kerDepIndex">Slot in <c>_col</c> holding the im2col expansion cached by the forward pass.</param>
        /// <returns>
        /// The result of <c>col2im</c>, sized by <c>_input_dim1</c> x <c>_input_dim2</c>
        /// and the kernel dimensions.
        /// </returns>
        protected LMatrix _backward(LMatrix dout, WehtInfo info, int kerDepIndex)
        {
            // Kernel size comes from the weight array's two dimensions.
            int kRows = info.W.GetLength(0);
            int kCols = info.W.GetLength(1);

            LMatrix gradFlat   = dout.Flatten();
            LMatrix weightGrad = gradFlat * this._col[kerDepIndex];

            // The first call initialises DW; later calls add onto it.
            if (info.DW == null)
            {
                info.DW = weightGrad.ReShape(kRows, kCols);
            }
            else
            {
                // NOTE(review): the return value of Plus is discarded, so this relies on Plus updating
                // info.DW in place. The earlier commented-out version assigned the result back — confirm.
                info.DW.Plus(weightGrad.ReShape(kRows, kCols), kRows, kCols);
            }

            // Input gradient: transposed upstream gradient times the flattened weights, folded back by col2im.
            LMatrix flatWeights = info.W.Flatten();
            LMatrix colGrad     = gradFlat.T * flatWeights;

            return col2im(colGrad.Matrix, _input_dim1, _input_dim2, kRows, kCols, Stride);
        }
示例#5
0
        /// <summary>
        /// Scratch handler that round-trips a 3x4 matrix through <c>ReShape</c>
        /// (3x4 -> 4x3 -> 3x4).
        /// The results are not used afterwards, so they are only inspectable in a debugger.
        /// </summary>
        private void button3_Click(object sender, EventArgs e)
        {
            // testCustomNet.ResetSGD(this.textBox2.Text);
            LMatrix a = new double[3, 4] {
                { 1, 2, 3, 4 },
                { 5, 6, 7, 8 },
                { 9, 10, 11, 12 }
            };

            // A 4x1 matrix was previously built here and never used; it has been removed.
            LMatrix b = a.ReShape(4, 3);

            double[,] c = b.ReShape(3, 4);
        }