/// <summary>
/// Run forward part on given single device: encodes the source batch, then decodes it with
/// either the attention (LSTM) decoder or the Transformer decoder configured for this device.
/// </summary>
/// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
/// <param name="srcSnts">A batch of input tokenized sentences in source side. Padded in place by ParallelCorpus.PadSentences.</param>
/// <param name="tgtSnts">
/// A batch of output tokenized sentences in target side. They are padded in place. In Transformer inference mode
/// they are presumably extended in place by DecodeTransformer at each step (the loop below stops once every
/// sentence ends with EOS).
/// </param>
/// <param name="deviceIdIdx">The index of current device</param>
/// <param name="isTraining">
/// True for training: the Transformer path decodes the given target sentences once and returns its cost.
/// False: the Transformer path decodes step by step for at most m_maxTgtSntSize steps.
/// The attention decoder receives the flag unchanged.
/// </param>
/// <returns>The cost of forward part; always 0 for Transformer inference.</returns>
/// <exception cref="NotSupportedException">
/// The decoder on this device is neither an <see cref="AttentionDecoder"/> nor a <see cref="TransformerDecoder"/>.
/// </exception>
private float RunForwardOnSingleDevice(IComputeGraph computeGraph, List<List<string>> srcSnts, List<List<string>> tgtSnts, int deviceIdIdx, bool isTraining)
{
    (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

    // Reset networks
    encoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);
    decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);

    // PadSentences pads srcSnts in place and returns the lengths each sentence had before padding.
    List<int> originalSrcLengths = ParallelCorpus.PadSentences(srcSnts);
    int srcSeqPaddedLen = srcSnts[0].Count;
    int batchSize = srcSnts.Count;

    IWeightTensor encSelfMask = MaskUtils.BuildPadSelfMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, deviceIdIdx);
    IWeightTensor encDimMask = MaskUtils.BuildPadDimMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, m_modelMetaData.HiddenDim, deviceIdIdx);

    // Encoding input source sentences
    IWeightTensor encOutput = Encode(computeGraph, srcSnts, encoder, srcEmbedding, encSelfMask, encDimMask);

    // Generate output decoder sentences
    if (decoder is AttentionDecoder attentionDecoder)
    {
        return DecodeAttentionLSTM(tgtSnts, computeGraph, encOutput, attentionDecoder, tgtEmbedding, batchSize, isTraining);
    }

    // Cast once and fail fast: passing a null decoder on would only surface later as a NullReferenceException.
    TransformerDecoder transformerDecoder = decoder as TransformerDecoder;
    if (transformerDecoder == null)
    {
        throw new NotSupportedException($"Unsupported decoder type '{decoder.GetType().Name}'.");
    }

    if (isTraining)
    {
        List<int> originalTgtLengths = ParallelCorpus.PadSentences(tgtSnts);
        int tgtSeqPaddedLen = tgtSnts[0].Count;
        IWeightTensor encDecMask = MaskUtils.BuildSrcTgtMask(computeGraph, srcSeqPaddedLen, tgtSeqPaddedLen, originalSrcLengths, originalTgtLengths, deviceIdIdx);

        return DecodeTransformer(tgtSnts, computeGraph, encOutput, encDecMask, transformerDecoder, tgtEmbedding, batchSize, deviceIdIdx, isTraining);
    }

    // Inference: decode step by step in a fresh sub-graph per step, for at most m_maxTgtSntSize steps,
    // stopping early once every sentence ends with EOS.
    for (int i = 0; i < m_maxTgtSntSize; i++)
    {
        using (var g = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{i}"))
        {
            // tgtSnts may change between steps, so its lengths and the source-target mask are rebuilt every step.
            List<int> originalTgtLengths = ParallelCorpus.PadSentences(tgtSnts);
            int tgtSeqPaddedLen = tgtSnts[0].Count;
            IWeightTensor encDecMask = MaskUtils.BuildSrcTgtMask(g, srcSeqPaddedLen, tgtSeqPaddedLen, originalSrcLengths, originalTgtLengths, deviceIdIdx);

            DecodeTransformer(tgtSnts, g, encOutput, encDecMask, transformerDecoder, tgtEmbedding, batchSize, deviceIdIdx, isTraining);

            bool allSntsEnd = tgtSnts.All(snt => snt[snt.Count - 1] == ParallelCorpus.EOS);
            if (allSntsEnd)
            {
                break;
            }
        }
    }

    return 0.0f;
}
/// <summary>
/// Runs the Transformer decoder over a batch of target sentences.
/// In training mode it does two things:
/// - sums the negative log-probability of the left-shifted target tokens over non-padded positions;
/// - writes (softmax - one_hot) into the decoder output's gradients.
/// In inference mode it appends, to every sentence in <paramref name="tgtSeqs"/>, the arg-max word predicted
/// at that sentence's last position.
/// </summary>
/// <param name="tgtSeqs">Target sentences. Padded in place; in inference mode one word is appended to each.</param>
/// <param name="g">Compute graph the decoding operations are built on.</param>
/// <param name="encOutputs">
/// Encoder outputs. The padded source length is derived as <c>Rows / batchSize</c>, so this assumes
/// batchSize equally padded sentences stacked row-wise.
/// </param>
/// <param name="decoder">The Transformer decoder.</param>
/// <param name="tgtEmbedding">Target word embedding table, indexed by row with target vocabulary ids.</param>
/// <param name="posEmbedding">Position embedding passed to AddPositionEmbedding.</param>
/// <param name="batchSize">Number of sentences in the batch.</param>
/// <param name="deviceId">Device the masks are built on.</param>
/// <param name="srcOriginalLenghts">Source sentence lengths before padding.</param>
/// <param name="isTraining">True to compute cost and gradients; false to decode one more word per sentence.</param>
/// <returns>The summed negative log-likelihood in training mode; 0 in inference mode.</returns>
private float DecodeTransformer(List<List<string>> tgtSeqs, IComputeGraph g, IWeightTensor encOutputs, TransformerDecoder decoder, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding, int batchSize, int deviceId, List<int> srcOriginalLenghts, bool isTraining = true)
{
    float cost = 0.0f;

    var tgtOriginalLengths = ParallelCorpus.PadSentences(tgtSeqs);
    int tgtSeqLen = tgtSeqs[0].Count;
    int srcSeqLen = encOutputs.Rows / batchSize;

    using (IWeightTensor srcTgtMask = MaskUtils.BuildSrcTgtMask(g, srcSeqLen, tgtSeqLen, tgtOriginalLengths, srcOriginalLenghts, deviceId))
    using (IWeightTensor tgtSelfTriMask = MaskUtils.BuildPadSelfTriMask(g, tgtSeqLen, tgtOriginalLengths, deviceId))
    {
        // Gather the embedding of every target token, sentence by sentence (row i * tgtSeqLen + j).
        // Padded positions are looked up without gradients.
        List<IWeightTensor> inputs = new List<IWeightTensor>();
        for (int i = 0; i < batchSize; i++)
        {
            for (int j = 0; j < tgtSeqLen; j++)
            {
                int ix_targets_k = m_modelMetaData.Vocab.GetTargetWordIndex(tgtSeqs[i][j], logUnk: true);
                var emb = g.PeekRow(tgtEmbedding, ix_targets_k, runGradients: j < tgtOriginalLengths[i]);
                inputs.Add(emb);
            }
        }

        IWeightTensor inputEmbs = inputs.Count > 1 ? g.ConcatRows(inputs) : inputs[0];
        inputEmbs = AddPositionEmbedding(g, posEmbedding, batchSize, tgtSeqLen, inputEmbs);

        IWeightTensor decOutput = decoder.Decode(inputEmbs, encOutputs, tgtSelfTriMask, srcTgtMask, batchSize, g);

        // Softmax runs in place and without gradients; the training branch writes the
        // cross-entropy gradient into decOutput by hand.
        using (IWeightTensor probs = g.Softmax(decOutput, runGradients: false, inPlace: true))
        {
            if (isTraining)
            {
                // Presumably shifts every sentence left by one (filling with EOS), so position j is scored against token j + 1.
                var leftShiftInputSeqs = ParallelCorpus.LeftShiftSnts(tgtSeqs, ParallelCorpus.EOS);
                for (int i = 0; i < batchSize; i++)
                {
                    for (int j = 0; j < tgtSeqLen; j++)
                    {
                        using (IWeightTensor probs_i_j = g.PeekRow(probs, i * tgtSeqLen + j, runGradients: false))
                        {
                            if (j < tgtOriginalLengths[i])
                            {
                                int ix_targets_i_j = m_modelMetaData.Vocab.GetTargetWordIndex(leftShiftInputSeqs[i][j], logUnk: true);
                                float score_i_j = probs_i_j.GetWeightAt(ix_targets_i_j);

                                // NOTE(review): if the softmax underflows to exactly 0 this adds +Infinity to the cost.
                                cost += (float)-Math.Log(score_i_j);

                                // probs - one_hot(target): the softmax + cross-entropy gradient for this row.
                                probs_i_j.SetWeightAt(score_i_j - 1, ix_targets_i_j);
                            }
                            else
                            {
                                // Padded position: zero the row so it contributes no gradient.
                                probs_i_j.CleanWeight();
                            }
                        }
                    }
                }

                decOutput.CopyWeightsToGradients(probs);
            }
            else
            {
                // Output "i"th target word: the prediction at the last row of each sentence
                // (row i * tgtSeqLen + tgtSeqLen - 1). Assumes Argmax(probs, 1) yields one id per row.
                int[] targetIdx = g.Argmax(probs, 1);
                List<string> targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(targetIdx.ToList());
                for (int i = 0; i < batchSize; i++)
                {
                    tgtSeqs[i].Add(targetWords[i * tgtSeqLen + tgtSeqLen - 1]);
                }
            }
        }
    }

    return cost;
}