/// <summary>
/// Forward pass over a 4-D input. For every sample and every depth slice the 2-D
/// feature map is expanded with <c>im2col</c>, reduced with <c>Max(1)</c>, and the
/// reshaped result is written into the matching slice of the output.
/// </summary>
/// <param name="input">Input indexed as [sample, depth, row, column].</param>
/// <returns>
/// A new array of shape [sample, depth, outRows, outCols], where
/// outRows = (rows - PadRow) / Stride + 1 and outCols = (columns - PadColumn) / Stride + 1
/// (integer division).
/// </returns>
/// <remarks>
/// Side effects: <c>SingleMaxIndex</c> is replaced and filled with the result of
/// <c>MaxIndex()</c> for every (sample, depth) pair, and <c>InputShape</c> is set to
/// the input's row/column counts whenever at least one sample is present.
/// NOTE(review): PadRow/PadColumn are passed to im2col in the window-size position,
/// so they presumably hold the pooling window size rather than a padding amount - confirm.
/// </remarks>
public double[,,,] forward(double[,,,] input)
{
    // Total number of samples.
    int batchSize = input.GetLength(0);
    // Depth of a single sample.
    int channelCount = input.GetLength(1);
    // Feature-map row count (height).
    int mapRows = input.GetLength(2);
    // Feature-map column count (width).
    int mapColumns = input.GetLength(3);

    // Container for the max-value indices, one slot per sample.
    SingleMaxIndex = new int[batchSize][][];

    int outRows = (mapRows - PadRow) / Stride + 1;
    int outColumns = (mapColumns - PadColumn) / Stride + 1;
    double[,,,] output = new double[batchSize, channelCount, outRows, outColumns];

    for (int n = 0; n < batchSize; n++)
    {
        // Per-sample container for the max indices of each depth slice.
        SingleMaxIndex[n] = new int[channelCount][];

        // Remember the row/column dimensions of the incoming feature maps.
        InputShape = new MatrixShape(mapRows, mapColumns);

        for (int c = 0; c < channelCount; c++)
        {
            LMatrix windows = im2col(
                input.GetNextDimVal(n, c, mapRows, mapColumns),
                outRows,
                outColumns,
                PadRow,
                PadColumn,
                Stride);

            SingleMaxIndex[n][c] = windows.MaxIndex();

            // Presumably one maximum per expanded window - confirm against LMatrix.Max.
            LMatrix pooled = windows.Max(1);
            output.SetDimVal(pooled.ReShape(outRows, outColumns), n, c, outRows, outColumns);
        }
    }

    return output;
}
/// <summary>
/// Stores the dimensions of <paramref name="x"/> in <c>Row</c>/<c>Column</c>,
/// delegates the computation to <c>_forward</c>, and returns its result reshaped
/// to the dimensions of <paramref name="x"/>.
/// </summary>
/// <param name="x">Input matrix.</param>
/// <param name="train_flg">Passed through unchanged to <c>_forward</c>.</param>
/// <returns>The output of <c>_forward</c>, reshaped to x.Row by x.Column.</returns>
public override LMatrix forward(LMatrix x, bool train_flg)
{
    // Remember the input dimensions on the instance before running the layer.
    Row = x.Row;
    Column = x.Column;

    LMatrix output = _forward(x, train_flg);

    // Dimensions are read from x again here, after _forward has run.
    return output.ReShape(x.Row, x.Column);
}
/// <summary>
/// Multiplies the im2col expansion of <paramref name="mat"/> by the transposed,
/// flattened weights of <paramref name="info"/> and reshapes the product to
/// <paramref name="resRow"/> by <paramref name="resCol"/>.
/// </summary>
/// <param name="mat">2-D input to expand.</param>
/// <param name="info">Weight holder; only <c>info.W</c> is read here.</param>
/// <param name="kerDepIndex">Index into the <c>_col</c> cache.</param>
/// <param name="resRow">Row count of the result.</param>
/// <param name="resCol">Column count of the result.</param>
/// <returns>The product reshaped to resRow by resCol.</returns>
/// <remarks>
/// The im2col expansion is computed only when <c>_col[kerDepIndex]</c> is null;
/// otherwise the cached entry is reused and <paramref name="mat"/> is not read.
/// NOTE(review): the cache presumably has to be cleared elsewhere when the input
/// changes - confirm. <c>kernelRow</c>/<c>kernelColumn</c> are not declared in this
/// method and come from the enclosing scope.
/// </remarks>
protected double[,] _forward(double[,] mat, WehtInfo info, int kerDepIndex, int resRow, int resCol)
{
    // Build the column expansion lazily, once per cache slot.
    if (_col[kerDepIndex] == null)
    {
        _col[kerDepIndex] = im2col(mat, resRow, resCol, kernelRow, kernelColumn, Stride);
    }

    LMatrix flatWeights = info.W.Flatten();
    LMatrix product = _col[kerDepIndex] * flatWeights.T;

    return product.ReShape(resRow, resCol);
}
/// <summary>
/// Backward step for one kernel: derives the weight gradient from the flattened
/// upstream gradient and the cached <c>_col[kerDepIndex]</c> matrix, stores or
/// accumulates it in <c>info.DW</c>, and returns the result of <c>col2im</c> applied
/// to the product of the transposed gradient and the flattened weights.
/// </summary>
/// <param name="dout">Upstream gradient.</param>
/// <param name="info">Weight holder; <c>info.W</c> is read and <c>info.DW</c> is written.</param>
/// <param name="kerDepIndex">Index into the <c>_col</c> cache.</param>
/// <returns>The value produced by <c>col2im</c> for this kernel.</returns>
/// <remarks>
/// Reads <c>_col[kerDepIndex]</c> without a null check, so it relies on that entry
/// having been populated beforehand.
/// </remarks>
protected LMatrix _backward(LMatrix dout, WehtInfo info, int kerDepIndex)
{
    // Kernel dimensions are taken from the weight array itself.
    int kRows = info.W.GetLength(0);
    int kColumns = info.W.GetLength(1);

    LMatrix flatGradient = dout.Flatten();
    LMatrix weightGradient = flatGradient * _col[kerDepIndex];
    var shapedGradient = weightGradient.ReShape(kRows, kColumns);

    if (info.DW != null)
    {
        // NOTE(review): the return value of Plus is discarded, so this presumably
        // accumulates into info.DW in place - confirm against the Plus implementation.
        info.DW.Plus(shapedGradient, kRows, kColumns);
    }
    else
    {
        // First gradient for this kernel: store it directly.
        info.DW = shapedGradient;
    }

    LMatrix flatWeights = info.W.Flatten();
    LMatrix columnGradient = flatGradient.T * flatWeights;

    return col2im(columnGradient.Matrix, _input_dim1, _input_dim2, kRows, kColumns, Stride);
}
/// <summary>
/// Click handler for button3. Builds two small sample matrices and round-trips the
/// first one through <c>ReShape</c> (3x4 to 4x3 and back to 3x4). All results are
/// kept in locals only; nothing is stored on the form or displayed.
/// </summary>
/// <param name="sender">Event source (unused).</param>
/// <param name="e">Event data (unused).</param>
private void button3_Click(object sender, EventArgs e)
{
    // testCustomNet.ResetSGD(this.textBox2.Text);

    // 3x4 sample holding the values 1..12 in row order.
    LMatrix grid = new double[3, 4]
    {
        { 1, 2, 3, 4 },
        { 5, 6, 7, 8 },
        { 9, 10, 11, 12 }
    };

    // 4x1 sample; constructed but not used further in this handler.
    LMatrix columnVector = new double[4, 1]
    {
        { 1 },
        { 5 },
        { 9 },
        { 15 },
    };

    LMatrix reshaped = grid.ReShape(4, 3);
    double[,] restored = reshaped.ReShape(3, 4);
}