/// <summary>
/// Runs the transformer decoder over a batch of target sentences. When training, it accumulates the
/// negative log-probability of every non-padded target word and writes the output gradients; otherwise
/// it appends one predicted word to each sentence in <paramref name="outInputSeqs"/>.
/// </summary>
/// <param name="outInputSeqs">Batch of tokenized target sentences. Passed to ParallelCorpus.PadSentences (presumably padded in place - confirm); when not training, one predicted word is appended to each sentence.</param>
/// <param name="g">Compute graph on which all operations are built.</param>
/// <param name="encOutputs">Encoder outputs handed to the decoder.</param>
/// <param name="encMask">Mask handed to the decoder together with the encoder outputs.</param>
/// <param name="decoder">Transformer decoder to run.</param>
/// <param name="tgtEmbedding">Target embedding matrix: rows are looked up as decoder inputs and its transpose projects the decoder output.</param>
/// <param name="batchSize">Number of sentences processed from <paramref name="outInputSeqs"/>.</param>
/// <param name="deviceId">Device id forwarded to the mask builders.</param>
/// <param name="isTraining">True to compute cost and gradients; false to predict the next word of each sentence.</param>
/// <returns>The accumulated cost when training; 0 when not training.</returns>
private float DecodeTransformer(List<List<string>> outInputSeqs, IComputeGraph g, IWeightTensor encOutputs, IWeightTensor encMask, TransformerDecoder decoder,
    IWeightTensor tgtEmbedding, int batchSize, int deviceId, bool isTraining = true)
{
    float cost = 0.0f;

    var originalInputLengths = ParallelCorpus.PadSentences(outInputSeqs);
    int tgtSeqLen = outInputSeqs[0].Count;

    IWeightTensor tgtDimMask = MaskUtils.BuildPadDimMask(g, tgtSeqLen, originalInputLengths, m_modelMetaData.HiddenDim, deviceId);
    using (IWeightTensor tgtSelfTriMask = MaskUtils.BuildPadSelfTriMask(g, tgtSeqLen, originalInputLengths, deviceId))
    {
        // Look up one embedding row per (sentence, position), laid out sentence by sentence.
        List<IWeightTensor> inputs = new List<IWeightTensor>(batchSize * tgtSeqLen);
        for (int i = 0; i < batchSize; i++)
        {
            for (int j = 0; j < tgtSeqLen; j++)
            {
                int inputWordIdx = m_modelMetaData.Vocab.GetTargetWordIndex(outInputSeqs[i][j], logUnk: true);
                inputs.Add(g.PeekRow(tgtEmbedding, inputWordIdx));
            }
        }

        IWeightTensor tgtInputEmbeddings = inputs.Count > 1 ? g.ConcatRows(inputs) : inputs[0];
        IWeightTensor decOutput = decoder.Decode(tgtInputEmbeddings, encOutputs, tgtSelfTriMask, encMask, tgtDimMask, batchSize, g);

        // Project the decoder output with the transposed target embedding matrix.
        decOutput = g.Mul(decOutput, g.Transpose(tgtEmbedding));

        using (IWeightTensor probs = g.Softmax(decOutput, runGradients: false, inPlace: true))
        {
            if (isTraining)
            {
                // The expected output at position j is the input shifted left by one, terminated by EOS.
                var leftShiftInputSeqs = ParallelCorpus.LeftShiftSnts(outInputSeqs, ParallelCorpus.EOS);
                var originalOutputLengths = ParallelCorpus.PadSentences(leftShiftInputSeqs, tgtSeqLen);

                for (int i = 0; i < batchSize; i++)
                {
                    for (int j = 0; j < tgtSeqLen; j++)
                    {
                        using (IWeightTensor rowProbs = g.PeekRow(probs, i * tgtSeqLen + j, runGradients: false))
                        {
                            if (j < originalOutputLengths[i])
                            {
                                int targetWordIdx = m_modelMetaData.Vocab.GetTargetWordIndex(leftShiftInputSeqs[i][j], logUnk: true);
                                float targetProb = rowProbs.GetWeightAt(targetWordIdx);

                                cost += (float)-Math.Log(targetProb);

                                // Subtract 1 at the expected word so the row holds (probs - one-hot target);
                                // the rows are copied into decOutput's gradients after the loops.
                                rowProbs.SetWeightAt(targetProb - 1, targetWordIdx);
                            }
                            else
                            {
                                // Padded position: contributes neither cost nor gradient.
                                rowProbs.CleanWeight();
                            }
                        }
                    }
                }

                decOutput.CopyWeightsToGradients(probs);
            }
            else
            {
                // Only the prediction at the last position of each sentence is appended.
                int[] targetIdx = g.Argmax(probs, 1);
                List<string> targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(targetIdx.ToList());

                for (int i = 0; i < batchSize; i++)
                {
                    outInputSeqs[i].Add(targetWords[i * tgtSeqLen + tgtSeqLen - 1]);
                }
            }
        }
    }

    return cost;
}
/// <summary>
/// Run forward part on given single device
/// </summary>
/// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
/// <param name="srcSnts">A batch of input tokenized sentences in source side</param>
/// <param name="tgtSnts">A batch of output tokenized sentences in target side. With a transformer decoder and <paramref name="isTraining"/> false, generated words are appended to these sentences step by step</param>
/// <param name="deviceIdIdx">The index of current device</param>
/// <param name="isTraining">True to run a training forward pass; false to generate output sentences</param>
/// <returns>The cost of forward part. Transformer decoding with <paramref name="isTraining"/> false always returns 0</returns>
/// <exception cref="NotSupportedException">The decoder on the device is neither an AttentionDecoder nor a TransformerDecoder</exception>
private float RunForwardOnSingleDevice(IComputeGraph computeGraph, List<List<string>> srcSnts, List<List<string>> tgtSnts, int deviceIdIdx, bool isTraining)
{
    (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

    // Reset networks
    encoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);
    decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);

    List<int> originalSrcLengths = ParallelCorpus.PadSentences(srcSnts);
    int srcSeqPaddedLen = srcSnts[0].Count;
    int batchSize = srcSnts.Count;

    IWeightTensor encSelfMask = MaskUtils.BuildPadSelfMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, deviceIdIdx);
    IWeightTensor encDimMask = MaskUtils.BuildPadDimMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, m_modelMetaData.HiddenDim, deviceIdIdx);

    // Encoding input source sentences
    IWeightTensor encOutput = Encode(computeGraph, srcSnts, encoder, srcEmbedding, encSelfMask, encDimMask);

    // Generate output decoder sentences
    if (decoder is AttentionDecoder attentionDecoder)
    {
        return DecodeAttentionLSTM(tgtSnts, computeGraph, encOutput, attentionDecoder, tgtEmbedding, srcSnts.Count, isTraining);
    }

    // Fail with a clear message instead of passing a null decoder further down.
    if (!(decoder is TransformerDecoder transformerDecoder))
    {
        throw new NotSupportedException($"Decoder type '{decoder.GetType().Name}' is not supported.");
    }

    if (isTraining)
    {
        List<int> originalTgtLengths = ParallelCorpus.PadSentences(tgtSnts);
        int tgtSeqPaddedLen = tgtSnts[0].Count;
        IWeightTensor encDecMask = MaskUtils.BuildSrcTgtMask(computeGraph, srcSeqPaddedLen, tgtSeqPaddedLen, originalSrcLengths, originalTgtLengths, deviceIdIdx);

        return DecodeTransformer(tgtSnts, computeGraph, encOutput, encDecMask, transformerDecoder, tgtEmbedding, batchSize, deviceIdIdx, isTraining);
    }

    // Each step runs in its own sub graph and appends one word to every target sentence.
    // Decoding stops once every sentence ends with EOS, or after m_maxTgtSntSize steps.
    for (int i = 0; i < m_maxTgtSntSize; i++)
    {
        using (var g = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{i}"))
        {
            // The target sentences grew in the previous step, so lengths and mask are rebuilt every step.
            List<int> originalTgtLengths = ParallelCorpus.PadSentences(tgtSnts);
            int tgtSeqPaddedLen = tgtSnts[0].Count;
            IWeightTensor encDecMask = MaskUtils.BuildSrcTgtMask(g, srcSeqPaddedLen, tgtSeqPaddedLen, originalSrcLengths, originalTgtLengths, deviceIdIdx);

            DecodeTransformer(tgtSnts, g, encOutput, encDecMask, transformerDecoder, tgtEmbedding, batchSize, deviceIdIdx, isTraining);

            bool allSntsEnd = true;
            for (int j = 0; j < tgtSnts.Count; j++)
            {
                if (tgtSnts[j][tgtSnts[j].Count - 1] != ParallelCorpus.EOS)
                {
                    allSntsEnd = false;
                    break;
                }
            }

            if (allSntsEnd)
            {
                break;
            }
        }
    }

    return 0.0f;
}