/// <summary>
/// Backward step that turns the gradient stored in <paramref name="y"/> into gradients for
/// <paramref name="gamma"/>, <paramref name="beta"/> and <paramref name="x"/> (and, when not
/// training, for <paramref name="avgMean"/> and <paramref name="avgVar"/>).
/// NOTE(review): the parameter names suggest a batch-normalization layer - confirm against the caller.
/// </summary>
/// <param name="y">Output array whose <c>Grad</c> is read.</param>
/// <param name="x">Input array whose <c>Grad</c> is accumulated into (<c>+=</c>).</param>
/// <param name="train">Selects the training formula (<c>true</c>) or the inference formula (<c>false</c>).</param>
/// <param name="gamma">Per-channel scale; its <c>Grad</c> is re-initialized and then accumulated.</param>
/// <param name="beta">Per-channel shift; its <c>Grad</c> is re-initialized and then accumulated.</param>
/// <param name="avgMean">Only written when <paramref name="train"/> is <c>false</c>: its <c>Grad</c> is overwritten per channel.</param>
/// <param name="avgVar">Only used when <paramref name="train"/> is <c>false</c>: its <c>Data</c> is read and its <c>Grad</c> overwritten per channel.</param>
/// <param name="std">Per-channel divisor applied to <paramref name="gamma"/>.</param>
/// <param name="xhat">Per-element values, indexed exactly like <c>y.Grad</c>.</param>
/// <param name="channelSize">Number of channels; <c>x.Length / channelSize</c> elements belong to each channel.</param>
public static void SingleOutputBackward(NdArray<Real> y, NdArray<Real> x, bool train, NdArray<Real> gamma, NdArray<Real> beta, NdArray<Real> avgMean, NdArray<Real> avgVar, Real[] std, Real[] xhat, int channelSize)
{
    beta.InitGrad();
    gamma.InitGrad();

    int dataSize = x.Length / channelSize;

    // Reduce the output gradient over batch and position for every channel.
    for (int ch = 0; ch < channelSize; ch++)
    {
        int channelOffset = ch * dataSize;

        for (int batch = 0; batch < x.BatchCount; batch++)
        {
            int batchOffset = batch * y.Length;

            for (int pos = 0; pos < dataSize; pos++)
            {
                int idx = batchOffset + channelOffset + pos;

                beta.Grad[ch] += y.Grad[idx];
                gamma.Grad[ch] += y.Grad[idx] * xhat[idx];
            }
        }
    }

    if (!train)
    {
        // Inference mode: the stored mean/variance receive gradients and x gets a plain rescale.
        for (int ch = 0; ch < channelSize; ch++)
        {
            Real gs = gamma.Data[ch] / std[ch];

            avgMean.Grad[ch] = -gs * beta.Grad[ch];
            avgVar.Grad[ch] = -0.5f * gamma.Data[ch] / avgVar.Data[ch] * gamma.Grad[ch];

            int channelOffset = ch * dataSize;

            for (int batch = 0; batch < y.BatchCount; batch++)
            {
                int batchOffset = batch * y.Length;

                for (int pos = 0; pos < dataSize; pos++)
                {
                    int idx = batchOffset + channelOffset + pos;
                    x.Grad[idx] += gs * y.Grad[idx];
                }
            }
        }

        return;
    }

    // Training mode: subtract the per-channel correction term before rescaling.
    for (int ch = 0; ch < channelSize; ch++)
    {
        Real gs = gamma.Data[ch] / std[ch];
        int channelOffset = ch * dataSize;

        for (int batch = 0; batch < y.BatchCount; batch++)
        {
            int batchOffset = batch * y.Length;

            for (int pos = 0; pos < dataSize; pos++)
            {
                int idx = batchOffset + channelOffset + pos;

                Real val = (xhat[idx] * gamma.Grad[ch] + beta.Grad[ch]) / (y.BatchCount * dataSize);
                x.Grad[idx] += gs * (y.Grad[idx] - val);
            }
        }
    }
}
/// <summary>
/// Performs one forward step of the LSTM cell: combines the upward (and, after the first step,
/// lateral) projections, applies the gate activations and produces the new hidden state.
/// </summary>
/// <param name="x">Input for this step; its <c>BatchCount</c> sizes every buffer.</param>
/// <param name="upward">Function applied to <paramref name="x"/>; the first element of its result is used.</param>
/// <param name="lateral">Function applied to a clone of the previous hidden state when one exists.</param>
/// <param name="paramList">Receives this step's buffers in the order cPrev, a, i, f, o, c.</param>
/// <param name="hPrevParams">Receives the cloned previous hidden state when one exists.</param>
/// <param name="hParam">Previous hidden state on entry (<c>null</c> on the first step); replaced by the new hidden state.</param>
/// <param name="lcPrev">Previous cell state; allocated here on the first step and replaced by the new cell state.</param>
/// <param name="outputCount">Number of output units per batch item.</param>
/// <param name="lstm">Passed through to the constructor of the returned array.</param>
/// <returns>The new hidden state (the same instance stored in <paramref name="hParam"/>).</returns>
public static NdArray<Real> SingleInputForward(NdArray<Real> x, IFunction<Real> upward, IFunction<Real> lateral, List<Real[][]> paramList, List<NdArray<Real>> hPrevParams, ref NdArray<Real> hParam, ref Real[] lcPrev, int outputCount, IFunction<Real> lstm)
{
    int outputDataSize = x.BatchCount * outputCount;

    NdArray<Real> lstmIn = upward.Forward(x)[0];

    if (hParam != null)
    {
        // Work on a copy so the stored hidden state is not touched by this step.
        NdArray<Real> previousHidden = hParam.Clone();

        if (previousHidden.Grad != null)
        {
            previousHidden.InitGrad();
        }

        lstmIn += lateral.Forward(previousHidden)[0];
        hPrevParams.Add(previousHidden);
    }
    else
    {
        // First step: no previous state yet, so start from a fresh cell-state buffer.
        lcPrev = new Real[outputDataSize];
    }

    Real[] cellPrev = lcPrev;
    Real[] gateA = new Real[outputDataSize];
    Real[] gateI = new Real[outputDataSize];
    Real[] gateF = new Real[outputDataSize];
    Real[] gateO = new Real[outputDataSize];
    Real[] cellNext = new Real[outputDataSize];
    Real[] hidden = new Real[outputDataSize];

    // lstmIn.Data holds four consecutive values (a, i, f, o) for every output element.
    for (int outIndex = 0, src = 0; outIndex < hidden.Length; outIndex++, src += 4)
    {
        gateA[outIndex] = Math.Tanh(lstmIn.Data[src]);
        gateI[outIndex] = Sigmoid(lstmIn.Data[src + 1]);
        gateF[outIndex] = Sigmoid(lstmIn.Data[src + 2]);
        gateO[outIndex] = Sigmoid(lstmIn.Data[src + 3]);

        cellNext[outIndex] = gateA[outIndex] * gateI[outIndex] + gateF[outIndex] * cellPrev[outIndex];
        hidden[outIndex] = gateO[outIndex] * Math.Tanh(cellNext[outIndex]);
    }

    // 0:cPrev 1:a 2:i 3:f 4:o 5:c
    Real[][] stepParams = { cellPrev, gateA, gateI, gateF, gateO, cellNext };
    paramList.Add(stepParams);

    // Kept separately so it is not lost during Backward.
    lcPrev = cellNext;

    hParam = new NdArray<Real>(hidden, new[] { outputCount }, x.BatchCount, lstm);
    return hParam;
}
/// <summary>
/// Performs one forward step of this LSTM: combines the upward (and, after the first step,
/// lateral) projections, applies the gate activations, records the intermediate buffers for
/// the backward pass and stores the new hidden state.
/// </summary>
/// <param name="x">Input for this step; its <c>BatchCount</c> sizes every buffer.</param>
/// <returns>The new hidden state (the same instance stored in <c>hParam</c>).</returns>
public override NdArray SingleInputForward(NdArray x)
{
    NdArray lstmIn = this.upward.Forward(x)[0]; // a

    int outputDataSize = x.BatchCount * this.OutputCount;

    if (this.hParam != null)
    {
        // Work on a copy so the stored hidden state is not touched by this step.
        NdArray previousHidden = this.hParam.Clone();

        if (previousHidden.Grad != null)
        {
            previousHidden.InitGrad();
        }

        lstmIn += this.lateral.Forward(previousHidden)[0];
        this.hPrevParams.Add(previousHidden);
    }
    else
    {
        // No hidden state yet: (re)create every history list and the initial cell state.
        this.aParam = new List<Real[]>();
        this.iParam = new List<Real[]>();
        this.fParam = new List<Real[]>();
        this.oParam = new List<Real[]>();
        this.cNextParam = new List<Real[]>();
        this.cPrevParam = new List<Real[]>();
        this.hPrevParams = new List<NdArray>();

        this.aUsedParam = new List<Real[]>();
        this.iUsedParam = new List<Real[]>();
        this.fUsedParam = new List<Real[]>();
        this.oUsedParam = new List<Real[]>();
        this.cUsedNextParam = new List<Real[]>();
        this.cUsedPrevParam = new List<Real[]>();
        this.hUsedPrevParams = new List<NdArray>();

        this.gxPrevGrads = new List<Real[]>();

        this.cPrev = new Real[outputDataSize];
    }

    Real[] gateA = new Real[outputDataSize];
    Real[] gateI = new Real[outputDataSize];
    Real[] gateF = new Real[outputDataSize];
    Real[] gateO = new Real[outputDataSize];
    Real[] cellNext = new Real[outputDataSize];
    Real[] hidden = new Real[outputDataSize];

    for (int batch = 0; batch < x.BatchCount; batch++)
    {
        // Each batch item starts at its own offset; four consecutive values (a, i, f, o) are read per output.
        int src = batch * lstmIn.Length;
        int outBase = batch * this.OutputCount;

        for (int unit = 0; unit < this.OutputCount; unit++, src += 4)
        {
            int outIndex = outBase + unit;

            gateA[outIndex] = Math.Tanh(lstmIn.Data[src]);
            gateI[outIndex] = Sigmoid(lstmIn.Data[src + 1]);
            gateF[outIndex] = Sigmoid(lstmIn.Data[src + 2]);
            gateO[outIndex] = Sigmoid(lstmIn.Data[src + 3]);

            cellNext[outIndex] = gateA[outIndex] * gateI[outIndex] + gateF[outIndex] * this.cPrev[outIndex];
            hidden[outIndex] = gateO[outIndex] * Math.Tanh(cellNext[outIndex]);
        }
    }

    // Record this step's buffers.
    this.cPrevParam.Add(this.cPrev);
    this.cNextParam.Add(cellNext);
    this.aParam.Add(gateA);
    this.iParam.Add(gateI);
    this.fParam.Add(gateF);
    this.oParam.Add(gateO);

    // Kept separately so it is not lost during Backward.
    this.cPrev = cellNext;

    this.hParam = new NdArray(hidden, new[] { this.OutputCount }, x.BatchCount, this);
    return this.hParam;
}
/// <summary>
/// Joins <paramref name="a"/> and <paramref name="b"/> along <paramref name="axis"/>, batch by batch,
/// copying <c>Data</c> and (where present) <c>Grad</c> into a newly allocated array.
/// </summary>
/// <param name="a">First array; its elements keep their dimension indices in the result.</param>
/// <param name="b">Second array; its elements are shifted by <c>a.Shape[axis]</c> along <paramref name="axis"/>.</param>
/// <param name="axis">Index into <c>Shape</c> of the dimension along which the arrays are joined.</param>
/// <returns>A new array whose shape equals <paramref name="a"/>'s with the <paramref name="axis"/> extent enlarged by <paramref name="b"/>'s.</returns>
/// <exception cref="ArgumentException">
/// DEBUG builds only: the ranks differ, a non-axis dimension differs, the batch counts differ,
/// or only one of the inputs has a <c>Grad</c>.
/// </exception>
public static NdArray Concatenate(NdArray a, NdArray b, int axis)
{
#if DEBUG
    // Validate before touching b.Shape[axis], so a mismatch is reported as an argument error
    // instead of surfacing as an index failure.
    if (a.Shape.Length != b.Shape.Length)
    {
        throw new ArgumentException("配列の大きさがマッチしていません");
    }

    for (int i = 0; i < a.Shape.Length; i++)
    {
        if (i != axis && a.Shape[i] != b.Shape[i])
        {
            throw new ArgumentException("配列の大きさがマッチしていません");
        }
    }

    if (a.BatchCount != b.BatchCount)
    {
        throw new ArgumentException("バッチの大きさがマッチしていません");
    }

    if ((a.Grad != null) != (b.Grad != null))
    {
        throw new ArgumentException("Grad値の有無が揃っていません");
    }
#endif

    int[] shapeList = a.Shape.ToArray();
    shapeList[axis] += b.Shape[axis];

    NdArray result = new NdArray(shapeList, a.BatchCount);

    bool aHasGrad = a.Grad != null;
    bool bHasGrad = b.Grad != null;

    if (aHasGrad || bHasGrad)
    {
        result.InitGrad();
    }

    // Elements of b land after a's extent along the joined axis.
    int axisOffset = a.Shape[axis];

    for (int batchCount = 0; batchCount < a.BatchCount; batchCount++)
    {
        int aInputBatchOffset = batchCount * a.Length;
        int bInputBatchOffset = batchCount * b.Length;

        for (int i = 0; i < a.Length; i++)
        {
            int resultIndex = result.GetLocalIndex(batchCount, a.GetDimensionsIndex(i));

            result.Data[resultIndex] = a.Data[i + aInputBatchOffset];

            if (aHasGrad)
            {
                result.Grad[resultIndex] = a.Grad[i + aInputBatchOffset];
            }
        }

        for (int i = 0; i < b.Length; i++)
        {
            int[] tmpIndex = b.GetDimensionsIndex(i);
            tmpIndex[axis] += axisOffset;

            int resultIndex = result.GetLocalIndex(batchCount, tmpIndex);

            result.Data[resultIndex] = b.Data[i + bInputBatchOffset];

            if (bHasGrad)
            {
                result.Grad[resultIndex] = b.Grad[i + bInputBatchOffset];
            }
        }
    }

    return result;
}