/// <summary>
/// Runs the Transformer decoder over a batch of target-side sequences.
/// In training mode it accumulates the cross-entropy cost against the left-shifted targets and writes
/// the softmax gradient into <c>decOutput</c>; otherwise it performs one greedy decoding step by
/// appending the arg-max word predicted at the last position of each sequence.
/// </summary>
/// <param name="outInputSeqs">Target-side token sequences. Padded in place by <c>ParallelCorpus.PadSentences</c>;
/// when <paramref name="isTraining"/> is false, one predicted token is appended to each sequence.</param>
/// <param name="g">Compute graph used to build all operations.</param>
/// <param name="encOutputs">Encoder outputs passed to the decoder.</param>
/// <param name="encMask">Encoder-side mask passed to the decoder.</param>
/// <param name="decoder">Transformer decoder.</param>
/// <param name="tgtEmbedding">Target embedding matrix: rows are looked up by target word index, and its
/// transpose is also used as the output projection to vocabulary logits.</param>
/// <param name="batchSize">Number of sequences in <paramref name="outInputSeqs"/>.</param>
/// <param name="deviceId">Device id passed to the mask builders.</param>
/// <param name="isTraining">True to compute cost and gradients; false to run a greedy decoding step.</param>
/// <returns>Sum of -log(p(target)) over positions <c>j &lt; originalOutputLengths[i]</c>; 0 when not training.</returns>
private float DecodeTransformer(List<List<string>> outInputSeqs, IComputeGraph g, IWeightTensor encOutputs, IWeightTensor encMask, TransformerDecoder decoder, IWeightTensor tgtEmbedding, int batchSize, int deviceId, bool isTraining = true)
{
    float cost = 0.0f;

    var originalInputLengths = ParallelCorpus.PadSentences(outInputSeqs);
    int tgtSeqLen = outInputSeqs[0].Count;

    // NOTE(review): unlike tgtSelfTriMask below, this mask is not wrapped in a using block;
    // confirm whether the compute graph owns its lifetime.
    IWeightTensor tgtDimMask = MaskUtils.BuildPadDimMask(g, tgtSeqLen, originalInputLengths, m_modelMetaData.HiddenDim, deviceId);

    using (IWeightTensor tgtSelfTriMask = MaskUtils.BuildPadSelfTriMask(g, tgtSeqLen, originalInputLengths, deviceId))
    {
        // One embedding row per (sentence i, position j), in the order i * tgtSeqLen + j.
        List<IWeightTensor> inputs = new List<IWeightTensor>();
        for (int i = 0; i < batchSize; i++)
        {
            for (int j = 0; j < tgtSeqLen; j++)
            {
                int ix_targets_k = m_modelMetaData.Vocab.GetTargetWordIndex(outInputSeqs[i][j], logUnk: true);
                inputs.Add(g.PeekRow(tgtEmbedding, ix_targets_k));
            }
        }

        IWeightTensor tgtInputEmbeddings = inputs.Count > 1 ? g.ConcatRows(inputs) : inputs[0];
        IWeightTensor decOutput = decoder.Decode(tgtInputEmbeddings, encOutputs, tgtSelfTriMask, encMask, tgtDimMask, batchSize, g);

        // Project to vocabulary logits by reusing the target embedding matrix (weight tying).
        decOutput = g.Mul(decOutput, g.Transpose(tgtEmbedding));

        using (IWeightTensor probs = g.Softmax(decOutput, runGradients: false, inPlace: true))
        {
            if (isTraining)
            {
                // Targets are the inputs shifted left by one token, with EOS supplied to LeftShiftSnts.
                var leftShiftInputSeqs = ParallelCorpus.LeftShiftSnts(outInputSeqs, ParallelCorpus.EOS);
                var originalOutputLengths = ParallelCorpus.PadSentences(leftShiftInputSeqs, tgtSeqLen);

                for (int i = 0; i < batchSize; i++)
                {
                    for (int j = 0; j < tgtSeqLen; j++)
                    {
                        using (IWeightTensor probs_i_j = g.PeekRow(probs, i * tgtSeqLen + j, runGradients: false))
                        {
                            if (j < originalOutputLengths[i])
                            {
                                int ix_targets_i_j = m_modelMetaData.Vocab.GetTargetWordIndex(leftShiftInputSeqs[i][j], logUnk: true);
                                float score_i_j = probs_i_j.GetWeightAt(ix_targets_i_j);
                                cost += (float)-Math.Log(score_i_j);

                                // Turn the softmax row into softmax - one_hot(target), the cross-entropy gradient
                                // w.r.t. the logits (assumes PeekRow returns a view into probs — verify).
                                probs_i_j.SetWeightAt(score_i_j - 1, ix_targets_i_j);
                            }
                            else
                            {
                                // Positions past the original length: CleanWeight presumably zeroes the row so
                                // padding contributes no gradient.
                                probs_i_j.CleanWeight();
                            }
                        }
                    }
                }

                decOutput.CopyWeightsToGradients(probs);
            }
            else
            {
                // Greedy step: take the arg-max word at the last position of each sentence and append it.
                int[] targetIdx = g.Argmax(probs, 1);
                List<string> targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(targetIdx.ToList());
                for (int i = 0; i < batchSize; i++)
                {
                    outInputSeqs[i].Add(targetWords[i * tgtSeqLen + tgtSeqLen - 1]);
                }
            }
        }
    }

    return cost;
}
/// <summary>
/// Runs the Transformer decoder (with position embeddings and a source/target attention mask) over a
/// batch of target-side sequences.
/// In training mode it accumulates the cross-entropy cost against the left-shifted targets and writes
/// the softmax gradient into <c>decOutput</c>; otherwise it performs one greedy decoding step by
/// appending the arg-max word predicted at the last position of each sequence.
/// </summary>
/// <param name="tgtSeqs">Target-side token sequences. Padded in place by <c>ParallelCorpus.PadSentences</c>;
/// when <paramref name="isTraining"/> is false, one predicted token is appended to each sequence.</param>
/// <param name="g">Compute graph used to build all operations.</param>
/// <param name="encOutputs">Encoder outputs passed to the decoder; its row count is divided by
/// <paramref name="batchSize"/> to obtain the source length.</param>
/// <param name="decoder">Transformer decoder.</param>
/// <param name="tgtEmbedding">Target embedding matrix whose rows are looked up by target word index.</param>
/// <param name="posEmbedding">Position embedding passed to <c>AddPositionEmbedding</c>.</param>
/// <param name="batchSize">Number of sequences in <paramref name="tgtSeqs"/>.</param>
/// <param name="deviceId">Device id passed to the mask builders.</param>
/// <param name="srcOriginalLenghts">Per-sentence source lengths passed to <c>MaskUtils.BuildSrcTgtMask</c>
/// (name kept as-is for compatibility with callers using named arguments).</param>
/// <param name="isTraining">True to compute cost and gradients; false to run a greedy decoding step.</param>
/// <returns>Sum of -log(p(target)) over positions <c>j &lt; tgtOriginalLengths[i]</c>; 0 when not training.</returns>
private float DecodeTransformer(List<List<string>> tgtSeqs, IComputeGraph g, IWeightTensor encOutputs, TransformerDecoder decoder, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding, int batchSize, int deviceId, List<int> srcOriginalLenghts, bool isTraining = true)
{
    float cost = 0.0f;

    var tgtOriginalLengths = ParallelCorpus.PadSentences(tgtSeqs);
    int tgtSeqLen = tgtSeqs[0].Count;

    // Assumes encOutputs stacks batchSize equally-long (padded) source sequences by rows — verify.
    int srcSeqLen = encOutputs.Rows / batchSize;

    using (IWeightTensor srcTgtMask = MaskUtils.BuildSrcTgtMask(g, srcSeqLen, tgtSeqLen, tgtOriginalLengths, srcOriginalLenghts, deviceId))
    using (IWeightTensor tgtSelfTriMask = MaskUtils.BuildPadSelfTriMask(g, tgtSeqLen, tgtOriginalLengths, deviceId))
    {
        // One embedding row per (sentence i, position j), in the order i * tgtSeqLen + j.
        List<IWeightTensor> inputs = new List<IWeightTensor>();
        for (int i = 0; i < batchSize; i++)
        {
            for (int j = 0; j < tgtSeqLen; j++)
            {
                int ix_targets_k = m_modelMetaData.Vocab.GetTargetWordIndex(tgtSeqs[i][j], logUnk: true);

                // Only positions within the original (unpadded) length request gradients for the embedding row.
                inputs.Add(g.PeekRow(tgtEmbedding, ix_targets_k, runGradients: j < tgtOriginalLengths[i]));
            }
        }

        IWeightTensor inputEmbs = inputs.Count > 1 ? g.ConcatRows(inputs) : inputs[0];
        inputEmbs = AddPositionEmbedding(g, posEmbedding, batchSize, tgtSeqLen, inputEmbs);

        IWeightTensor decOutput = decoder.Decode(inputEmbs, encOutputs, tgtSelfTriMask, srcTgtMask, batchSize, g);

        using (IWeightTensor probs = g.Softmax(decOutput, runGradients: false, inPlace: true))
        {
            if (isTraining)
            {
                // Targets are the inputs shifted left by one token, with EOS supplied to LeftShiftSnts.
                var leftShiftInputSeqs = ParallelCorpus.LeftShiftSnts(tgtSeqs, ParallelCorpus.EOS);

                for (int i = 0; i < batchSize; i++)
                {
                    for (int j = 0; j < tgtSeqLen; j++)
                    {
                        using (IWeightTensor probs_i_j = g.PeekRow(probs, i * tgtSeqLen + j, runGradients: false))
                        {
                            if (j < tgtOriginalLengths[i])
                            {
                                int ix_targets_i_j = m_modelMetaData.Vocab.GetTargetWordIndex(leftShiftInputSeqs[i][j], logUnk: true);
                                float score_i_j = probs_i_j.GetWeightAt(ix_targets_i_j);
                                cost += (float)-Math.Log(score_i_j);

                                // Turn the softmax row into softmax - one_hot(target), the cross-entropy gradient
                                // w.r.t. the logits (assumes PeekRow returns a view into probs — verify).
                                probs_i_j.SetWeightAt(score_i_j - 1, ix_targets_i_j);
                            }
                            else
                            {
                                // Padding positions: CleanWeight presumably zeroes the row so they contribute no gradient.
                                probs_i_j.CleanWeight();
                            }
                        }
                    }
                }

                decOutput.CopyWeightsToGradients(probs);
            }
            else
            {
                // Greedy step: take the arg-max word at the last position of each sentence and append it.
                int[] targetIdx = g.Argmax(probs, 1);
                List<string> targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(targetIdx.ToList());
                for (int i = 0; i < batchSize; i++)
                {
                    tgtSeqs[i].Add(targetWords[i * tgtSeqLen + tgtSeqLen - 1]);
                }
            }
        }
    }

    return cost;
}