/// <summary>
/// Resets a single LSTM cell to its initial state: the cell state and every
/// accumulated weight derivative on <paramref name="deri"/> are set to zero.
/// </summary>
/// <param name="c">Cell whose <c>cellState</c> is cleared.</param>
/// <param name="cw">Weights of the cell. Not read or modified by this method.</param>
/// <param name="deri">Derivative accumulators that are cleared.</param>
private void InitializeLSTMCell(LSTMCell c, LSTMCellWeight cw, LSTMCellWeightDeri deri)
{
    // Internal state of the cell.
    c.cellState = 0;

    // Partial derivatives with respect to the peephole weights.
    deri.dSWPeepholeIn = 0;
    deri.dSWPeepholeForget = 0;

    // Partial derivatives with respect to the cell weights.
    deri.dSWCellIn = 0;
    deri.dSWCellForget = 0;
    deri.dSWCellState = 0;
}
/// <summary>
/// Allocates <c>CellWeights</c> and <c>CellWeightsDeri</c> with <c>LayerSize</c>
/// entries each and fills in the per-cell weights.
/// </summary>
/// <param name="br">
/// When not null, seven doubles are read per cell, in this order:
/// wPeepholeIn, wPeepholeForget, wPeepholeOut, wCellIn, wCellForget,
/// wCellState, wCellOut. When null, every weight is taken from
/// <c>RNNHelper.RandInitWeight()</c> instead.
/// </param>
private void InitializeCellWeights(BinaryReader br)
{
    CellWeights = new LSTMCellWeight[LayerSize];
    CellWeightsDeri = new LSTMCellWeightDeri[LayerSize];

    var loadFromReader = br != null;

    for (var cellIdx = 0; cellIdx < LayerSize; cellIdx++)
    {
        if (loadFromReader)
        {
            // Load weights from the input file. Initializer members are
            // evaluated top to bottom, which fixes the read order.
            CellWeights[cellIdx] = new LSTMCellWeight
            {
                wPeepholeIn = br.ReadDouble(),
                wPeepholeForget = br.ReadDouble(),
                wPeepholeOut = br.ReadDouble(),
                wCellIn = br.ReadDouble(),
                wCellForget = br.ReadDouble(),
                wCellState = br.ReadDouble(),
                wCellOut = br.ReadDouble()
            };
        }
        else
        {
            // No saved weights available: start from random values.
            CellWeights[cellIdx] = new LSTMCellWeight
            {
                wPeepholeIn = RNNHelper.RandInitWeight(),
                wPeepholeForget = RNNHelper.RandInitWeight(),
                wPeepholeOut = RNNHelper.RandInitWeight(),
                wCellIn = RNNHelper.RandInitWeight(),
                wCellForget = RNNHelper.RandInitWeight(),
                wCellState = RNNHelper.RandInitWeight(),
                wCellOut = RNNHelper.RandInitWeight()
            };
        }

        // Derivative accumulators always start from a fresh instance.
        CellWeightsDeri[cellIdx] = new LSTMCellWeightDeri();
    }
}
/// <summary>
/// Applies the accumulated weight deltas to the layer's weights.
/// The four dense gate matrices update themselves; for every cell the
/// peephole, cell and sparse-feature deltas are reset to zero, divided by
/// <c>RNNHelper.MiniBatchSize</c>, clamped to the gradient bounds, scaled by
/// the value returned from <c>ComputeLearningRate</c>, and added to the weights.
/// </summary>
public override void UpdateWeights()
{
    wDenseInputGate.UpdateWeights();
    wDenseForgetGate.UpdateWeights();
    wDenseCellGate.UpdateWeights();
    wDenseOutputGate.UpdateWeights();

    for (var cellIdx = 0; cellIdx < LayerSize; cellIdx++)
    {
        LSTMCellWeight weights = CellWeights[cellIdx];

        // Peephole delta: take the accumulated value, clear the accumulator,
        // then average over the mini-batch and clamp.
        var peepholeStep = peepholeDelta[cellIdx];
        peepholeDelta[cellIdx] = Vector3.Zero;
        peepholeStep /= RNNHelper.MiniBatchSize;
        peepholeStep = Vector3.Clamp(peepholeStep, vecMinGrad3, vecMaxGrad3);

        // Cell delta: same treatment with the four-component bounds.
        var cellStep = cellDelta[cellIdx];
        cellDelta[cellIdx] = Vector4.Zero;
        cellStep /= RNNHelper.MiniBatchSize;
        cellStep = Vector4.Clamp(cellStep, vecMinGrad, vecMaxGrad);

        // Scale each delta by its learning rate; ComputeLearningRate receives
        // the stored per-cell rate by reference.
        var peepholeRate = ComputeLearningRate(peepholeStep, ref peepholeLearningRate[cellIdx]);
        peepholeStep = peepholeRate * peepholeStep;

        var cellRate = ComputeLearningRate(cellStep, ref cellLearningRate[cellIdx]);
        cellStep = cellRate * cellStep;

        // Peephole components: X = in, Y = forget, Z = out.
        weights.wPeepholeIn += peepholeStep.X;
        weights.wPeepholeForget += peepholeStep.Y;
        weights.wPeepholeOut += peepholeStep.Z;

        // Cell components: X = in, Y = forget, Z = state, W = out.
        weights.wCellIn += cellStep.X;
        weights.wCellForget += cellStep.Y;
        weights.wCellState += cellStep.Z;
        weights.wCellOut += cellStep.W;

        // Update weights for sparse features.
        var sparseRates = sparseFeatureLearningRate[cellIdx];
        var sparseDeltas = sparseFeatureWeightsDelta[cellIdx];
        var sparseWeights = sparseFeatureWeights[cellIdx];

        for (var featureIdx = 0; featureIdx < SparseFeatureSize; featureIdx++)
        {
            Vector4 sparseStep = sparseDeltas[featureIdx];
            if (sparseStep == Vector4.Zero)
            {
                // Nothing accumulated for this feature; leave its weight alone.
                continue;
            }

            sparseDeltas[featureIdx] = Vector4.Zero;
            sparseStep /= RNNHelper.MiniBatchSize;
            sparseStep = Vector4.Clamp(sparseStep, vecMinGrad, vecMaxGrad);

            var sparseRate = ComputeLearningRate(sparseStep, ref sparseRates[featureIdx]);
            sparseWeights[featureIdx] += sparseStep * sparseRate;
        }
    }
}