/// <summary>
/// Backward pass for a single LSTM step whose output <paramref name="y"/> has received a gradient.
/// Consumes the most recently saved step parameters and computes the gradients of the four gates
/// (a, i, f, o), written interleaved per unit. It then sends those gradients through
/// <paramref name="lateral"/> (only when a previous hidden state is pending) and through
/// <paramref name="upward"/>. Finally, it recycles the bookkeeping lists once they have been fully
/// consumed.
/// </summary>
/// <param name="y">Output of this step; <c>y.Grad</c> is read for indices <c>[0, y.BatchCount * outputCount)</c>.</param>
/// <param name="upward">Function whose <c>Backward</c> receives the oldest queued gate gradient from <paramref name="gxPrevGrads"/>.</param>
/// <param name="lateral">Function whose <c>Backward</c> receives this step's gate gradient when <paramref name="hPrevParams"/> is non-empty.</param>
/// <param name="paramLists">Stack of saved per-step values. The last entry is consumed; each entry is indexed 0:cPrev 1:a 2:i 3:f 4:o 5:c.</param>
/// <param name="hPrevParams">Stack of previous hidden states. The last entry is popped and passed to <paramref name="backward"/> when the stack is non-empty.</param>
/// <param name="usedParamLists">Receives consumed <paramref name="paramLists"/> entries; they are moved back once <paramref name="paramLists"/> is empty.</param>
/// <param name="hUsedPrevParams">Receives consumed <paramref name="hPrevParams"/> entries; they are moved back once <paramref name="hPrevParams"/> is empty.</param>
/// <param name="gxPrevGrads">Queue of gate gradients. This step's gradient is appended, and the oldest entry is removed and passed to <paramref name="upward"/>.</param>
/// <param name="outputCount">Number of output units per batch item.</param>
/// <param name="backward">Callback invoked with the popped previous hidden state.</param>
public static void SingleOutputBackward(NdArray<Real> y, IFunction<Real> upward, IFunction<Real> lateral, List<Real[][]> paramLists, List<NdArray<Real>> hPrevParams, List<Real[][]> usedParamLists, List<NdArray<Real>> hUsedPrevParams, List<Real[]> gxPrevGrads, int outputCount, ActionOptional<Real> backward)
{
    Real[] gxPrevGrad = new Real[y.BatchCount * outputCount * 4];

    // NOTE(review): gcPrev is a fresh, zero-filled local on every call. The "*= f" at the end of the
    // loop therefore never reaches the previous time step, so no cell-state gradient is carried
    // backward through time. If that is intended, this buffer probably has to be owned by the
    // caller — TODO confirm.
    Real[] gcPrev = new Real[y.BatchCount * outputCount];

    // Saved values for this step, indexed 0:cPrev 1:a 2:i 3:f 4:o 5:c.
    Real[][] param = paramLists[paramLists.Count - 1];
    paramLists.RemoveAt(paramLists.Count - 1);
    usedParamLists.Add(param);

    // Gate gradients are written interleaved per unit: [a, i, f, o, a, i, f, o, ...].
    // GradTanh/GradSigmoid presumably take the already-activated value (e.g. 1 - t*t, s*(1 - s))
    // — TODO confirm against their definitions.
    int index = 0;
    for (int prevOutputIndex = 0; prevOutputIndex < gcPrev.Length; prevOutputIndex++)
    {
        Real co = Math.Tanh(param[5][prevOutputIndex]);

        // Gradient flowing into the cell state from y through o * tanh(c).
        gcPrev[prevOutputIndex] += y.Grad[prevOutputIndex] * param[4][prevOutputIndex] * GradTanh(co);

        gxPrevGrad[index++] = gcPrev[prevOutputIndex] * param[2][prevOutputIndex] * GradTanh(param[1][prevOutputIndex]);    // a
        gxPrevGrad[index++] = gcPrev[prevOutputIndex] * param[1][prevOutputIndex] * GradSigmoid(param[2][prevOutputIndex]); // i
        gxPrevGrad[index++] = gcPrev[prevOutputIndex] * param[0][prevOutputIndex] * GradSigmoid(param[3][prevOutputIndex]); // f
        gxPrevGrad[index++] = y.Grad[prevOutputIndex] * co * GradSigmoid(param[4][prevOutputIndex]);                       // o

        // Scale by f to get the gradient w.r.t. cPrev (not retained past this call; see NOTE above).
        gcPrev[prevOutputIndex] *= param[3][prevOutputIndex];
    }

    gxPrevGrads.Add(gxPrevGrad);

    if (hPrevParams.Count > 0)
    {
        // Per the original comment, lateral's Backward only reads gxPrev.Grad, so gxPrev.Data is left empty.
        NdArray<Real> gxPrev = new NdArray<Real>(new[] { outputCount * 4 }, y.BatchCount);
        gxPrev.Grad = gxPrevGrad;
        lateral.Backward(gxPrev);

        NdArray<Real> hPrevParam = hPrevParams[hPrevParams.Count - 1];
        hPrevParams.RemoveAt(hPrevParams.Count - 1);
        hUsedPrevParams.Add(hPrevParam);

        // Backward for the previous hidden state h.
        backward(hPrevParam);

        RestoreWhenExhausted(hPrevParams, hUsedPrevParams);
    }

    // Takes the OLDEST queued gradient (index 0), not necessarily the one appended above. Presumably
    // this is because backward(hPrevParam) can re-enter this method and enqueue more — TODO confirm.
    // Per the original comment, upward's Backward only reads gy.Grad, so gy.Data is left empty.
    NdArray<Real> gy = new NdArray<Real>(new[] { outputCount * 4 }, y.BatchCount);
    gy.Grad = gxPrevGrads[0];
    gxPrevGrads.RemoveAt(0);
    upward.Backward(gy);

    RestoreWhenExhausted(paramLists, usedParamLists);
}

/// <summary>
/// Once <paramref name="pending"/> has been emptied, moves every entry of <paramref name="used"/>
/// back into it in the order <paramref name="pending"/> had before its entries were popped from the
/// end, then clears <paramref name="used"/>. Does nothing while <paramref name="pending"/> still has
/// entries.
/// </summary>
/// <param name="pending">Stack that entries are popped from (from its end).</param>
/// <param name="used">List that popped entries were appended to, in pop order.</param>
private static void RestoreWhenExhausted<T>(List<T> pending, List<T> used)
{
    if (pending.Count != 0)
    {
        return;
    }

    // Popping from the end of pending fills 'used' in reverse. Flip it back before re-adding;
    // a plain AddRange would leave pending reversed for the next pass.
    used.Reverse();
    pending.AddRange(used);
    used.Clear();
}