示例#1
0
        /// <summary>
        /// Resets one LSTM cell: clears <paramref name="c"/>'s cell state and zeroes the
        /// peephole and cell partial-derivative accumulators held in <paramref name="deri"/>.
        /// </summary>
        /// <param name="c">Cell whose <c>cellState</c> is set to 0.</param>
        /// <param name="cw">Not referenced by this method.</param>
        /// <param name="deri">Accumulators whose <c>dSWPeephole*</c> / <c>dSWCell*</c> terms are zeroed.</param>
        // NOTE(review): assumes LSTMCell and LSTMCellWeightDeri are reference types; if either is a
        // struct these writes only touch the local copy — TODO confirm.
        private void InitializeLSTMCell(LSTMCell c, LSTMCellWeight cw, LSTMCellWeightDeri deri)
        {
            // Cell-weight partial derivatives.
            deri.dSWCellIn     = 0;
            deri.dSWCellForget = 0;
            deri.dSWCellState  = 0;

            // Peephole-weight partial derivatives (only the In and Forget terms are reset here).
            deri.dSWPeepholeIn     = 0;
            deri.dSWPeepholeForget = 0;

            // Clear the cell's internal state.
            c.cellState = 0;
        }
示例#2
0
        /// <summary>
        /// Allocates <c>CellWeights</c> and <c>CellWeightsDeri</c> with <c>LayerSize</c> entries each and
        /// fills every cell's peephole and cell weights. Values come from <paramref name="br"/> when it is
        /// non-null; otherwise each one is drawn from <c>RNNHelper.RandInitWeight()</c>.
        /// </summary>
        /// <param name="br">
        /// Source of serialized weights, or <c>null</c> for random initialization. When non-null, seven
        /// doubles are consumed per cell in this order: wPeepholeIn, wPeepholeForget, wPeepholeOut,
        /// wCellIn, wCellForget, wCellState, wCellOut.
        /// </param>
        private void InitializeCellWeights(BinaryReader br)
        {
            CellWeights     = new LSTMCellWeight[LayerSize];
            CellWeightsDeri = new LSTMCellWeightDeri[LayerSize];

            for (var cell = 0; cell < LayerSize; cell++)
            {
                CellWeights[cell] = new LSTMCellWeight();

                // The statement order below is the order values are consumed from br,
                // so it must not be rearranged.
                CellWeights[cell].wPeepholeIn     = ReadOrInitCellWeight(br);
                CellWeights[cell].wPeepholeForget = ReadOrInitCellWeight(br);
                CellWeights[cell].wPeepholeOut    = ReadOrInitCellWeight(br);

                CellWeights[cell].wCellIn     = ReadOrInitCellWeight(br);
                CellWeights[cell].wCellForget = ReadOrInitCellWeight(br);
                CellWeights[cell].wCellState  = ReadOrInitCellWeight(br);
                CellWeights[cell].wCellOut    = ReadOrInitCellWeight(br);

                CellWeightsDeri[cell] = new LSTMCellWeightDeri();
            }
        }

        /// <summary>
        /// Produces the next cell weight: the next double from <paramref name="br"/> when it is
        /// non-null, otherwise a fresh value from <c>RNNHelper.RandInitWeight()</c>.
        /// </summary>
        /// <param name="br">Weight source, or <c>null</c> for random initialization.</param>
        /// <returns>The weight value to assign.</returns>
        private double ReadOrInitCellWeight(BinaryReader br)
        {
            return br != null ? br.ReadDouble() : RNNHelper.RandInitWeight();
        }
示例#3
0
        /// <summary>
        /// Applies the gradients accumulated since the previous call to every trainable parameter of
        /// the layer. For the peephole, cell and sparse-feature weights, each accumulated delta is
        /// divided by <c>RNNHelper.MiniBatchSize</c>, clamped to the gradient bounds, scaled by the
        /// rate from <c>ComputeLearningRate</c> and added to its weight. The accumulator is then left
        /// at zero.
        /// </summary>
        public override void UpdateWeights()
        {
            // The four dense gate weight sets perform their own update.
            wDenseInputGate.UpdateWeights();
            wDenseForgetGate.UpdateWeights();
            wDenseCellGate.UpdateWeights();
            wDenseOutputGate.UpdateWeights();

            for (var cell = 0; cell < LayerSize; cell++)
            {
                // Take the accumulated deltas for this cell and reset the accumulators.
                // Peephole components: X = In, Y = Forget, Z = Out.
                // Cell components:     X = In, Y = Forget, Z = State, W = Out.
                var peephole = peepholeDelta[cell];
                peepholeDelta[cell] = Vector3.Zero;
                var internalDelta = cellDelta[cell];
                cellDelta[cell] = Vector4.Zero;

                // Average over the mini-batch, then bound each component.
                peephole      = Vector3.Clamp(peephole / RNNHelper.MiniBatchSize, vecMinGrad3, vecMaxGrad3);
                internalDelta = Vector4.Clamp(internalDelta / RNNHelper.MiniBatchSize, vecMinGrad, vecMaxGrad);

                // Scale by the learning rate; the stored per-cell rate state is passed by ref.
                peephole      = ComputeLearningRate(peephole, ref peepholeLearningRate[cell]) * peephole;
                internalDelta = ComputeLearningRate(internalDelta, ref cellLearningRate[cell]) * internalDelta;

                LSTMCellWeight weight = CellWeights[cell];
                weight.wPeepholeIn     += peephole.X;
                weight.wPeepholeForget += peephole.Y;
                weight.wPeepholeOut    += peephole.Z;
                weight.wCellIn         += internalDelta.X;
                weight.wCellForget     += internalDelta.Y;
                weight.wCellState      += internalDelta.Z;
                weight.wCellOut        += internalDelta.W;

                // Sparse-feature weights: only features with a non-zero accumulated delta are touched.
                var rates   = sparseFeatureLearningRate[cell];
                var deltas  = sparseFeatureWeightsDelta[cell];
                var weights = sparseFeatureWeights[cell];
                for (var feature = 0; feature < SparseFeatureSize; feature++)
                {
                    Vector4 delta = deltas[feature];
                    if (delta == Vector4.Zero)
                    {
                        continue;
                    }

                    deltas[feature] = Vector4.Zero;
                    delta = Vector4.Clamp(delta / RNNHelper.MiniBatchSize, vecMinGrad, vecMaxGrad);

                    var rate = ComputeLearningRate(delta, ref rates[feature]);
                    weights[feature] += delta * rate;
                }
            }
        }