/// <summary>
/// First (output-side) backward step of the cell. It does three things:
/// 1. Replays this step's forward pass.
/// 2. Derives the gradient that reaches the cell state.
/// 3. Derives the gradient of the output-gate weight.
/// </summary>
/// <param name="initialCost">Incoming gradient at the cell output (presumably the cost derivative — confirm with caller).</param>
/// <param name="prevActivations">Inputs to the cell; paired element-wise with <c>Weigths</c>.</param>
/// <param name="bias">Bias added to the linear pre-activation.</param>
/// <param name="cellhiddenStates">
/// State at the start of this step. The replay overwrites its hiddenState and cellState.
/// NOTE(review): those writes are visible to the caller only if LSTMWeigths is a reference type — confirm.
/// </param>
/// <param name="cellStateGrad">Gradient propagated to the cell state.</param>
/// <param name="outputWeigthGrad">Gradient of <c>recurrent.outputWeigth</c>.</param>
public void GetInitialGrads(double initialCost, double[] prevActivations, double bias, LSTMWeigths cellhiddenStates, out double cellStateGrad, out double outputWeigthGrad)
{
    // Linear pre-activation, accumulated into hiddenState: bias first, then each weighted input.
    cellhiddenStates.hiddenState += bias;
    for (int k = 0; k < prevActivations.Length; k++)
    {
        cellhiddenStates.hiddenState += prevActivations[k] * Weigths[k];
    }

    double z = cellhiddenStates.hiddenState;
    double gate = SigmoidActivation(z);

    // Forget gate: scale the carried cell state.
    double forgetFactor = gate * recurrent.forgetWeigth;
    cellhiddenStates.cellState *= forgetFactor;

    // Store gate: add new content to the cell state.
    double stored = SigmoidActivation(z) * recurrent.storeWeigth * TanhActivation(z);
    cellhiddenStates.cellState += stored;

    // Output gate: produce the new hidden state from the updated cell state.
    cellhiddenStates.hiddenState = gate * recurrent.outputWeigth * TanhActivation(cellhiddenStates.cellState);

    // Derivative of the output-gate product with respect to the output weight.
    // NOTE(review): the trailing 0 is presumably the derivative of recurrent.outputWeigth,
    // treated as constant here — confirm against Derivatives.MultiplicationDerivative.
    double outputWeigthDerivative = Derivatives.MultiplicationDerivative(
        cellhiddenStates.hiddenState,
        Derivatives.SigmoidDerivative(z),
        recurrent.outputWeigth,
        0);

    // Chain the incoming gradient through the output-gate product.
    double cellTanh = TanhActivation(cellhiddenStates.cellState);
    double cellTanhSlope = Derivatives.TanhDerivative(cellhiddenStates.cellState);
    double upstream = initialCost * Derivatives.MultiplicationDerivative(
        gate * recurrent.outputWeigth, outputWeigthDerivative, cellTanh, cellTanhSlope);

    outputWeigthGrad = upstream * outputWeigthDerivative;
    cellStateGrad = upstream * Derivatives.TanhDerivative(cellhiddenStates.cellState);
}
/// <summary>
/// Backward step for a non-initial time step. It replays this step's forward pass, then
/// produces gradients for:
/// - the forget and store gate weights;
/// - the bias;
/// - each incoming activation;
/// - each connection weight.
/// </summary>
/// <param name="prevActivations">Inputs to the cell; paired element-wise with <c>Weigths</c>.</param>
/// <param name="bias">Bias added to the linear pre-activation.</param>
/// <param name="cellhiddenStates">
/// State at the start of this step. The replay overwrites its hiddenState and cellState.
/// NOTE(review): the copy in <c>initialState</c> is a true snapshot only if LSTMWeigths is a value type — confirm.
/// </param>
/// <param name="cellStateGrad">Gradient arriving at the cell state.</param>
/// <param name="prevCellStateGrad">
/// Passed to <c>Derivatives.MultiplicationDerivative</c> alongside the previous cell state in the forget-gate product.
/// </param>
/// <param name="prevHiddenGrad">Gradient applied at the linear pre-activation for the bias, weight and activation gradients.</param>
/// <param name="outputWeigthGrad">Output-weight gradient, copied into <paramref name="LSTMWeigthsGrads"/>.</param>
/// <param name="prevActivationsGrads">One entry per input: <c>prevHiddenGrad * Weigths[i]</c>.</param>
/// <param name="weigthsGrads">One entry per input: <c>prevHiddenGrad * prevActivations[i]</c>.</param>
/// <param name="biasGrad">Gradient of the bias.</param>
/// <param name="LSTMWeigthsGrads">Forget/store/output gate weight gradients.</param>
public void GetGrads(double[] prevActivations, double bias, LSTMWeigths cellhiddenStates, double cellStateGrad, double prevCellStateGrad, double prevHiddenGrad, double outputWeigthGrad, out List<double> prevActivationsGrads, out List<double> weigthsGrads, out double biasGrad, out LSTMWeigths LSTMWeigthsGrads)
{
    LSTMWeigthsGrads = new LSTMWeigths(0, 0, outputWeigthGrad);
    // Copy of the incoming state; its cellState is used as the previous cell state below.
    LSTMWeigths initialState = cellhiddenStates;

    // ---- Forward replay ----
    cellhiddenStates.hiddenState += bias;
    for (int i = 0; i < prevActivations.Length; i++)
    {
        cellhiddenStates.hiddenState += prevActivations[i] * Weigths[i];
    }
    double linearFunc = cellhiddenStates.hiddenState;
    double hiddenStateSigmoid = SigmoidActivation(linearFunc);

    // Forget gate.
    double forgetMultiplication = hiddenStateSigmoid * recurrent.forgetWeigth;
    cellhiddenStates.cellState *= forgetMultiplication;
    // Store gate. hiddenState still equals linearFunc here, so the cached sigmoid is reused.
    cellhiddenStates.cellState += hiddenStateSigmoid * recurrent.storeWeigth * TanhActivation(linearFunc);
    // Output gate.
    // NOTE(review): the resulting hiddenState is not read again in this method.
    cellhiddenStates.hiddenState = hiddenStateSigmoid * recurrent.outputWeigth * TanhActivation(cellhiddenStates.cellState);

    // ---- Backward ----
    double sigmoidDerivative = Derivatives.SigmoidDerivative(linearFunc);
    double storeWeigthMultiplicationDerivative = hiddenStateSigmoid * sigmoidDerivative;
    double storeGateMultiplicationDerivative = Derivatives.MultiplicationDerivative(
        hiddenStateSigmoid * recurrent.storeWeigth, storeWeigthMultiplicationDerivative,
        TanhActivation(linearFunc), Derivatives.TanhDerivative(linearFunc));
    double forgetWeigthMultiplicationDerivative = sigmoidDerivative * hiddenStateSigmoid;
    double forgetGateMultiplicationDerivative = Derivatives.MultiplicationDerivative(
        initialState.cellState, prevCellStateGrad,
        forgetMultiplication, forgetWeigthMultiplicationDerivative);

    // Through the addition of the forget and store branches.
    double currentGrad = cellStateGrad * (forgetGateMultiplicationDerivative + storeGateMultiplicationDerivative);

    // Store-gate weight.
    LSTMWeigthsGrads.storeWeigth = currentGrad * storeGateMultiplicationDerivative * storeWeigthMultiplicationDerivative;

    // Forget-gate weight.
    currentGrad *= forgetGateMultiplicationDerivative;
    LSTMWeigthsGrads.forgetWeigth = currentGrad * forgetWeigthMultiplicationDerivative;

    // Linear layer z = bias + sum(a_i * w_i): dz/dbias = 1, dz/dw_i = a_i, dz/da_i = w_i.
    // (Previously this multiplied by the bias value, and both output lists aliased one list.)
    biasGrad = prevHiddenGrad;
    prevActivationsGrads = new List<double>(prevActivations.Length);
    weigthsGrads = new List<double>(prevActivations.Length);
    for (int i = 0; i < prevActivations.Length; i++)
    {
        prevActivationsGrads.Add(prevHiddenGrad * Weigths[i]);
        weigthsGrads.Add(prevHiddenGrad * prevActivations[i]);
    }
}