示例#1
0
        //TODO: Change out arrays to lists
        /// <summary>
        /// Replays this cell's forward step on <paramref name="cellhiddenStates"/> and derives the
        /// gradients for the recurrent store/forget weights, the bias, the input weights and the
        /// previous activations.
        /// </summary>
        /// <param name="prevActivations">Inputs to the cell; element <c>i</c> is paired with <c>Weigths[i]</c>.</param>
        /// <param name="bias">Bias added to the hidden state before the weighted inputs.</param>
        /// <param name="cellhiddenStates">Cell and hidden state entering this step.</param>
        /// <param name="cellStateGrad">Gradient that seeds the cell-state chain.</param>
        /// <param name="prevCellStateGrad">Used as the derivative of the incoming cell state in the forget-gate product rule.</param>
        /// <param name="prevHiddenGrad">Factor applied to the bias, weight and previous-activation gradients.</param>
        /// <param name="outputWeigthGrad">Passed as the third constructor argument of <paramref name="LSTMWeigthsGrads"/>.</param>
        /// <param name="prevActivationsGrads">Receives <c>prevHiddenGrad * Weigths[i]</c> for every input.</param>
        /// <param name="weigthsGrads">Receives <c>prevHiddenGrad * prevActivations[i]</c> for every input.</param>
        /// <param name="biasGrad">Receives <c>prevHiddenGrad * bias</c>.</param>
        /// <param name="LSTMWeigthsGrads">Receives the store and forget weight gradients computed here.</param>
        public void GetGrads(double[] prevActivations, double bias, LSTMWeigths cellhiddenStates, double cellStateGrad, double prevCellStateGrad, double prevHiddenGrad, double outputWeigthGrad
                             , out List <double> prevActivationsGrads, out List <double> weigthsGrads, out double biasGrad, out LSTMWeigths LSTMWeigthsGrads)
        {
            LSTMWeigthsGrads = new LSTMWeigths(0, 0, outputWeigthGrad);

            // Snapshot the incoming cell state as a plain double before it is mutated below.
            // Copying the whole LSTMWeigths only snapshots it when the type is a struct; with a
            // class the "initial" value would silently track the updated cell state.
            double initialCellState = cellhiddenStates.cellState;

            // Linear pre-activation: incoming hidden state + bias + weighted inputs.
            cellhiddenStates.hiddenState += bias;
            for (int i = 0; i < prevActivations.Length; i++)
            {
                cellhiddenStates.hiddenState += prevActivations[i] * Weigths[i];
            }
            double linearFunc = cellhiddenStates.hiddenState;

            double hiddenStateSigmoid = SigmoidActivation(linearFunc);

            //forget gate
            double forgetMultiplication = hiddenStateSigmoid * recurrent.forgetWeigth;
            cellhiddenStates.cellState *= forgetMultiplication;
            //store gate
            cellhiddenStates.cellState += hiddenStateSigmoid * recurrent.storeWeigth * TanhActivation(linearFunc);
            //output gate
            cellhiddenStates.hiddenState = hiddenStateSigmoid * recurrent.outputWeigth * TanhActivation(cellhiddenStates.cellState);

            double currentGrad = cellStateGrad;

            // Both gate-weight derivatives are sigmoid(linearFunc) * sigmoid'(linearFunc).
            double sigmoidDerivative = Derivatives.SigmoidDerivative(linearFunc);
            double storeWeigthMultiplicationDerivative  = hiddenStateSigmoid * sigmoidDerivative;
            double forgetWeigthMultiplicationDerivative = sigmoidDerivative * hiddenStateSigmoid;

            double storeGateMultiplicationDerivative = Derivatives.MultiplicationDerivative(hiddenStateSigmoid * recurrent.storeWeigth, storeWeigthMultiplicationDerivative
                                                                                            , TanhActivation(linearFunc), Derivatives.TanhDerivative(linearFunc));
            double forgetGateMultiplicationDerivative = Derivatives.MultiplicationDerivative(initialCellState, prevCellStateGrad, forgetMultiplication, forgetWeigthMultiplicationDerivative);

            //To store addition
            currentGrad *= forgetGateMultiplicationDerivative + storeGateMultiplicationDerivative;

            LSTMWeigthsGrads.storeWeigth = currentGrad * storeGateMultiplicationDerivative * storeWeigthMultiplicationDerivative;

            //To Forget multiplication
            currentGrad *= forgetGateMultiplicationDerivative;
            LSTMWeigthsGrads.forgetWeigth = currentGrad * forgetWeigthMultiplicationDerivative;

            // NOTE(review): this scales by the bias value itself; the derivative of the linear
            // term with respect to the bias is 1 - confirm whether the multiplication is intended.
            biasGrad = prevHiddenGrad * bias;

            // Each out parameter gets its own list. Sharing a single instance between the two
            // would interleave activation and weight gradients in one list of twice the length.
            prevActivationsGrads = new List <double>(prevActivations.Length);
            weigthsGrads         = new List <double>(prevActivations.Length);
            for (int i = 0; i < prevActivations.Length; i++)
            {
                prevActivationsGrads.Add(prevHiddenGrad * Weigths[i]);
                weigthsGrads.Add(prevHiddenGrad * prevActivations[i]);
            }
        }
示例#2
0
        /// <summary>
        /// Runs one forward step of the cell on <paramref name="cellhiddenStates"/> and turns
        /// <paramref name="initialCost"/> into the starting gradients for the cell state and the
        /// recurrent output weight.
        /// </summary>
        /// <param name="initialCost">Cost value that is pushed back through the output gate.</param>
        /// <param name="prevActivations">Inputs to the cell; element <c>i</c> is paired with <c>Weigths[i]</c>.</param>
        /// <param name="bias">Bias added to the hidden state before the weighted inputs.</param>
        /// <param name="cellhiddenStates">Cell and hidden state entering this step; updated in place inside the method.</param>
        /// <param name="cellStateGrad">Receives the cost scaled by the output-gate derivative and the tanh derivative of the new cell state.</param>
        /// <param name="outputWeigthGrad">Receives the cost scaled by the output-gate derivative and the output-weight derivative.</param>
        public void GetInitialGrads(double initialCost, double[] prevActivations, double bias, LSTMWeigths cellhiddenStates, out double cellStateGrad, out double outputWeigthGrad)
        {
            // Linear pre-activation: incoming hidden state + bias + weighted inputs.
            cellhiddenStates.hiddenState += bias;
            int inputIndex = 0;
            while (inputIndex < prevActivations.Length)
            {
                cellhiddenStates.hiddenState += prevActivations[inputIndex] * Weigths[inputIndex];
                inputIndex++;
            }

            double preActivation = cellhiddenStates.hiddenState;
            double gateSigmoid   = SigmoidActivation(preActivation);

            // Forget gate: scale the carried cell state.
            double forgetFactor = gateSigmoid * recurrent.forgetWeigth;
            cellhiddenStates.cellState *= forgetFactor;

            // Store gate: add the new candidate contribution.
            double storeContribution = SigmoidActivation(preActivation) * recurrent.storeWeigth * TanhActivation(preActivation);
            cellhiddenStates.cellState += storeContribution;

            // Output gate: the hidden state is replaced by the gated tanh of the new cell state.
            double scaledOutputGate = gateSigmoid * recurrent.outputWeigth;
            cellhiddenStates.hiddenState = scaledOutputGate * TanhActivation(cellhiddenStates.cellState);

            // Derivative of the output-weight product.
            // NOTE(review): the first operand is the hidden state after the output gate has
            // already overwritten it, not the sigmoid value - confirm this is intended.
            double outputWeigthDerivative = Derivatives.MultiplicationDerivative(
                cellhiddenStates.hiddenState,
                Derivatives.SigmoidDerivative(preActivation),
                recurrent.outputWeigth,
                0);

            // Push the cost through the output gate.
            double outputGateDerivative = Derivatives.MultiplicationDerivative(
                scaledOutputGate,
                outputWeigthDerivative,
                TanhActivation(cellhiddenStates.cellState),
                Derivatives.TanhDerivative(cellhiddenStates.cellState));
            double costAfterOutputGate = initialCost * outputGateDerivative;

            outputWeigthGrad = costAfterOutputGate * outputWeigthDerivative;
            cellStateGrad    = costAfterOutputGate * Derivatives.TanhDerivative(cellhiddenStates.cellState);
        }