/// <summary>
/// Forward pass of the cell for one step.
/// Stores every intermediate result on the instance so <c>Learn</c> can reuse it later.
/// </summary>
/// <param name="input">The input vector for this step.</param>
/// <param name="gatesForCell">Weight matrices and biases for the forget, input, tanh and output gates.</param>
/// <returns>The cell output: tanh(cell state) multiplied element-wise by the output gate.</returns>
public Vector CalculateAndGetOutput(Vector input, LstmGatesForCell gatesForCell)
{
    Input = input;

    // The gates all see the same vector: this step's input followed by the previous output.
    var concatenated = Vector.Union(Input, OutputFromPreviousLayer);
    InputConcatenated = concatenated;

    ForgetGateResultF = Sigmoid.Func(gatesForCell.ForgetLayer * concatenated + gatesForCell.BiasForgetLayer);
    InputLayerGateResultI = Sigmoid.Func(gatesForCell.InputLayer * concatenated + gatesForCell.BiasInputLayer);
    TanhLayerGateResultG = Tanh.Func(gatesForCell.TanhLayer * concatenated + gatesForCell.BiasTanhLayer);
    OutputLayerGateResultO = Sigmoid.Func(gatesForCell.OutputLayer * concatenated + gatesForCell.BiasOutputLayer);

    // Cell state: the previous cell state scaled by the forget gate,
    // plus the candidate (tanh gate) scaled by the input gate.
    var retained = ForgetFromPreviousLayer * ForgetGateResultF;
    var added = TanhLayerGateResultG * InputLayerGateResultI;
    Forget = retained + added;

    Output = Tanh.Func(Forget) * OutputLayerGateResultO;
    return Output;
}
/// <summary>
/// Backward pass of the cell for one step.
/// Uses the intermediate values stored by <c>CalculateAndGetOutput</c>.
/// </summary>
/// <param name="diffInputFromNextCell">Gradient contribution for this cell's output; summed with <paramref name="diffOutputFromNextLayer"/>.</param>
/// <param name="diffOutputFromNextLayer">Gradient contribution for this cell's output; summed with <paramref name="diffInputFromNextCell"/>.</param>
/// <param name="diffForgetFromNextLayer">Incoming gradient for this cell's cell state (<c>Forget</c>).</param>
/// <param name="gatesForCell">Gate weights. They receive their gradient via <c>CalculateDiff</c> and are used to back-propagate into the concatenated input.</param>
/// <returns>
/// A tuple with three gradients:
/// <list type="bullet">
/// <item><c>diffOutput</c>: gradient for the previous output, i.e. the part of the concatenated input after <c>Input</c>.</item>
/// <item><c>diffForget</c>: gradient for the previous cell state.</item>
/// <item><c>diffInput</c>: gradient for this step's <c>Input</c>.</item>
/// </list>
/// </returns>
public (Vector diffOutput, Vector diffForget, Vector diffInput) Learn(Vector diffInputFromNextCell, Vector diffOutputFromNextLayer, Vector diffForgetFromNextLayer, LstmGatesForCell gatesForCell)
{
    // Total gradient arriving at this cell's Output.
    var diffOutput = diffInputFromNextCell + diffOutputFromNextLayer;

    var one = new Vector(Forget.Length, () => 1);
    var tanhOfCellState = Tanh.Func(Forget);

    // Output = tanh(c) * o, so dL/dc = dOutput * o * (1 - tanh(c)^2), plus the incoming cell-state gradient.
    // The square is written as an explicit element-wise product. The previous form,
    // `one - Tanh.Func(Forget) ^ 2`, parsed as `(one - tanh(c)) ^ 2`, because C#'s `^`
    // (even when overloaded) has lower precedence than binary `-`.
    var diffForget = diffOutput * OutputLayerGateResultO * (one - tanhOfCellState * tanhOfCellState) + diffForgetFromNextLayer;

    // Per-gate gradients; their exact definition lives in GetDiffForGates (not shown here).
    var (diffTanhGate, diffInputGate, diffForgetGate, diffOutputGate) = GetDiffForGates(diffForget, diffOutput);

    // Hand the gate gradients and the concatenated input to the gates object
    // (presumably to compute weight/bias gradients - see CalculateDiff).
    gatesForCell.CalculateDiff(diffInputGate, diffForgetGate, diffOutputGate, diffTanhGate, InputConcatenated);

    // Back-propagate into the concatenated vector [Input ; OutputFromPreviousLayer] and split it back apart.
    var diffInputConcataneted = GetDiffInputConcataneted(diffInputGate, diffForgetGate, diffOutputGate, diffTanhGate, gatesForCell);
    var diffOutputOnNext = diffInputConcataneted.Skip(Input.Length);
    var diffInputOnNext = diffInputConcataneted.Take(Input.Length);

    // Forget = ForgetFromPreviousLayer * f + ..., so the previous cell state's gradient is dL/dc * f.
    var diffForgetOnNext = diffForget * ForgetGateResultF;

    return (diffOutputOnNext, diffForgetOnNext, diffInputOnNext);
}
/// <summary>
/// Back-propagates the four gate gradients into the concatenated gate input.
/// For each gate, it multiplies the transposed weight matrix by that gate's gradient, then sums the four results.
/// </summary>
/// <param name="diffInputGate">Gradient for the input gate.</param>
/// <param name="diffForgetGate">Gradient for the forget gate.</param>
/// <param name="diffOutputGate">Gradient for the output gate.</param>
/// <param name="diffTanhGate">Gradient for the tanh (candidate) gate.</param>
/// <param name="gatesForCell">Holds the weight matrices of the four gates.</param>
/// <returns>The gradient with respect to the concatenated input vector.</returns>
private Vector GetDiffInputConcataneted(Vector diffInputGate, Vector diffForgetGate, Vector diffOutputGate, Vector diffTanhGate, LstmGatesForCell gatesForCell)
{
    var viaInputGate = gatesForCell.InputLayer.GetTransposed() * diffInputGate;
    var viaForgetGate = gatesForCell.ForgetLayer.GetTransposed() * diffForgetGate;
    var viaOutputGate = gatesForCell.OutputLayer.GetTransposed() * diffOutputGate;
    var viaTanhGate = gatesForCell.TanhLayer.GetTransposed() * diffTanhGate;

    // Summed left to right in the same order as the original accumulation.
    return viaInputGate + viaForgetGate + viaOutputGate + viaTanhGate;
}