/// <summary>
/// Build the [CLS] representation of the source token group at the given group index.
/// </summary>
public static IWeightTensor BuildTensorForSourceTokenGroupAt(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, ShuffleEnums shuffleType, IEncoder encoder, IModel modelMetaData,
    IWeightTensor srcEmbedding, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding, int groupId)
{
    var contextTokens = InsertCLSToken(sntPairBatch.GetSrcTokens(groupId));
    var originalSrcContextLength = BuildInTokens.PadSentences(contextTokens);
    var contextTokenIds = modelMetaData.SrcVocab.GetWordIndex(contextTokens);

    IWeightTensor encContextOutput = InnerRunner(computeGraph, contextTokenIds, originalSrcContextLength, shuffleType, encoder, modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding);

    // The encoder output is flattened to shape (batchSize * contextPaddedLen, dim) and the [CLS] token
    // sits at position 0 of every sentence, so the [CLS] row of batch item j is j * contextPaddedLen.
    int contextPaddedLen = contextTokens[0].Count;
    float[] contextCLSIdxs = new float[sntPairBatch.BatchSize];
    for (int j = 0; j < sntPairBatch.BatchSize; j++)
    {
        contextCLSIdxs[j] = j * contextPaddedLen;
    }

    IWeightTensor contextCLSOutput = computeGraph.IndexSelect(encContextOutput, contextCLSIdxs);

    return contextCLSOutput;
}
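// Illustration only (not part of the model code above): a minimal sketch of the flat-index arithmetic
// behind contextCLSIdxs. Assuming the encoder output is laid out as (batchSize * paddedLen, dim) and
// [CLS] is inserted at position 0 of every sentence, the [CLS] row of batch item j is j * paddedLen.
// The helper name below is hypothetical.
public static float[] BuildClsRowIndicesSketch(int batchSize, int paddedLen)
{
    // e.g. batchSize = 3, paddedLen = 5  =>  { 0f, 5f, 10f }
    float[] idxs = new float[batchSize];
    for (int j = 0; j < batchSize; j++)
    {
        idxs[j] = j * paddedLen; // row of sentence j's [CLS] token in the flattened encoder output
    }
    return idxs;
}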
/// <summary> /// Run forward part on given single device /// </summary> /// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param> /// <param name="srcSnts">A batch of input tokenized sentences in source side</param> /// <param name="tgtSnts">A batch of output tokenized sentences in target side</param> /// <param name="deviceIdIdx">The index of current device</param> /// <returns>The cost of forward part</returns> public override List <NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions) { List <NetworkResult> nrs = new List <NetworkResult>(); (IEncoder encoder, IWeightTensor srcEmbedding, List <IFeedForwardLayer> encoderFFLayer, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx); var srcSnts = sntPairBatch.GetSrcTokens(0); var originalSrcLengths = BuildInTokens.PadSentences(srcSnts); var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts); IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths); int srcSeqPaddedLen = srcSnts[0].Count; int batchSize = srcSnts.Count; float[] clsIdxs = new float[batchSize]; for (int i = 0; i < batchSize; i++) { for (int j = 0; j < srcSnts[i].Count; j++) { if (srcSnts[i][j] == BuildInTokens.CLS) { clsIdxs[i] = i * srcSeqPaddedLen + j; break; } } } IWeightTensor clsWeightTensor = computeGraph.IndexSelect(encOutput, clsIdxs); for (int i = 0; i < m_encoderFFLayer.Length; i++) { float cost = 0.0f; NetworkResult nr = new NetworkResult { Output = new List <List <List <string> > >() }; IWeightTensor ffLayer = encoderFFLayer[i].Process(clsWeightTensor, batchSize, computeGraph); using (IWeightTensor probs = computeGraph.Softmax(ffLayer, runGradients: false, inPlace: true)) { if (isTraining) { var tgtSnts = sntPairBatch.GetTgtTokens(i); for (int k = 0; k < batchSize; k++) { int ix_targets_k_j = m_modelMetaData.ClsVocabs[i].GetWordIndex(tgtSnts[k][0]); float score_k = probs.GetWeightAt(new long[] { k, ix_targets_k_j }); cost += (float)-Math.Log(score_k); probs.SetWeightAt(score_k - 1, new long[] { k, ix_targets_k_j }); } ffLayer.CopyWeightsToGradients(probs); nr.Cost = cost / batchSize; } else { // Output "i"th target word using var targetIdxTensor = computeGraph.Argmax(probs, 1); float[] targetIdx = targetIdxTensor.ToWeightArray(); List <string> targetWords = m_modelMetaData.ClsVocabs[i].ConvertIdsToString(targetIdx.ToList()); nr.Output.Add(new List <List <string> >()); for (int k = 0; k < batchSize; k++) { nr.Output[0].Add(new List <string>()); nr.Output[0][k].Add(targetWords[k]); } } } nrs.Add(nr); } return(nrs); }
/// <summary> /// Create input embedding from token embeddings, segment embeddings /// </summary> /// <param name="seqs"></param> /// <param name="g"></param> /// <param name="embeddingsTensor"></param> /// <param name="seqOriginalLengths"></param> /// <param name="segmentEmbedding"></param> /// <param name="vocab"></param> /// <returns>The embedding tensor. shape: (batchsize * seqLen, embedding_dim) </returns> public static IWeightTensor CreateTokensEmbeddings(List <List <int> > seqs, IComputeGraph g, IWeightTensor embeddingsTensor, IWeightTensor segmentEmbedding, Vocab vocab, float scaleFactor = 1.0f, bool enableTagEmbedding = false) { int batchSize = seqs.Count; int seqLen = seqs[0].Count; float[] idxs = new float[batchSize * seqLen]; float[] segIdxs = new float[batchSize * seqLen]; List <float[]> tagIdxsList = new List <float[]>(); //float[] tagIdxs = new float[batchSize * seqLen]; for (int i = 0; i < batchSize; i++) { int segIdx = 0; List <int> currTagIdxs = new List <int>(); int currTagLevel = 0; for (int j = 0; j < seqLen; j++) { idxs[i * seqLen + j] = seqs[i][j]; segIdxs[i * seqLen + j] = segIdx; string token = vocab.GetString(seqs[i][j]); if (token == BuildInTokens.SEP) { //A new segment segIdx++; } if (enableTagEmbedding) { if (token.StartsWith("<") && token.EndsWith(">") && BuildInTokens.IsPreDefinedToken(token) == false) { if (token[1] == '/') { currTagLevel--; currTagIdxs[currTagLevel] = -1; } else { //A new opening tag while (tagIdxsList.Count <= currTagLevel) { float[] tagIdxs = new float[batchSize * seqLen]; Array.Fill(tagIdxs, -1.0f); tagIdxsList.Add(tagIdxs); } while (currTagIdxs.Count <= currTagLevel) { currTagIdxs.Add(-1); } currTagIdxs[currTagLevel] = seqs[i][j]; currTagLevel++; } } else { for (int k = 0; k < currTagLevel; k++) { tagIdxsList[k][i * seqLen + j] = currTagIdxs[k]; //Logger.WriteLine($"Add tag embeddings: '{currTagIdxs[k]}'"); } } } } } IWeightTensor tagEmbeddings = null; if (enableTagEmbedding) { for (int k = 0; k < tagIdxsList.Count; k++) { var tagEmbeddings_k = g.IndexSelect(embeddingsTensor, tagIdxsList[k], clearWeights: true); if (tagEmbeddings == null) { tagEmbeddings = tagEmbeddings_k; } else { tagEmbeddings = g.Add(tagEmbeddings, tagEmbeddings_k); } } } IWeightTensor embeddingRst = g.IndexSelect(embeddingsTensor, idxs); if (scaleFactor != 1.0f) { embeddingRst = g.Mul(embeddingRst, scaleFactor, inPlace: true); } // Apply segment embeddings to the input sequence embeddings if (segmentEmbedding != null) { embeddingRst = g.Add(embeddingRst, g.IndexSelect(segmentEmbedding, segIdxs)); } if (tagEmbeddings != null) { embeddingRst = g.Add(embeddingRst, tagEmbeddings); } return(embeddingRst); }
/// <summary> /// Run forward part on given single device /// </summary> /// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param> /// <param name="srcSnts">A batch of input tokenized sentences in source side</param> /// <param name="tgtSnts">A batch of output tokenized sentences in target side</param> /// <param name="deviceIdIdx">The index of current device</param> /// <returns>The cost of forward part</returns> public override List <NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions) { (IEncoder encoder, IDecoder decoder, IFeedForwardLayer encoderFFLayer, IFeedForwardLayer decoderFFLayer, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx); var srcSnts = sntPairBatch.GetSrcTokens(0); var originalSrcLengths = BuildInTokens.PadSentences(srcSnts); var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts); IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths); List <NetworkResult> nrs = new List <NetworkResult>(); int srcSeqPaddedLen = srcSnts[0].Count; int batchSize = srcSnts.Count; float[] clsIdxs = new float[batchSize]; for (int i = 0; i < batchSize; i++) { for (int j = 0; j < srcSnts[i].Count; j++) { if (srcSnts[i][j] == BuildInTokens.CLS) { clsIdxs[i] = i * srcSeqPaddedLen + j; break; } } } IWeightTensor clsWeightTensor = computeGraph.IndexSelect(encOutput, clsIdxs); float cost = 0.0f; NetworkResult nrCLS = new NetworkResult { Output = new List <List <List <string> > >() }; IWeightTensor ffLayer = encoderFFLayer.Process(clsWeightTensor, batchSize, computeGraph); using (IWeightTensor probs = computeGraph.Softmax(ffLayer, runGradients: false, inPlace: true)) { if (isTraining) { var clsSnts = sntPairBatch.GetTgtTokens(0); for (int k = 0; k < batchSize; k++) { int ix_targets_k_j = m_modelMetaData.ClsVocab.GetWordIndex(clsSnts[k][0]); float score_k = probs.GetWeightAt(new long[] { k, ix_targets_k_j }); cost += (float)-Math.Log(score_k); probs.SetWeightAt(score_k - 1, new long[] { k, ix_targets_k_j }); } ffLayer.CopyWeightsToGradients(probs); nrCLS.Cost = cost / batchSize; } else { // Output "i"th target word using var targetIdxTensor = computeGraph.Argmax(probs, 1); float[] targetIdx = targetIdxTensor.ToWeightArray(); List <string> targetWords = m_modelMetaData.ClsVocab.ConvertIdsToString(targetIdx.ToList()); nrCLS.Output.Add(new List <List <string> >()); for (int k = 0; k < batchSize; k++) { nrCLS.Output[0].Add(new List <string>()); nrCLS.Output[0][k].Add(targetWords[k]); } } } // Reset networks decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count); // Generate output decoder sentences var tgtSnts = sntPairBatch.GetTgtTokens(1); var tgtTokensList = m_modelMetaData.TgtVocab.GetWordIndex(tgtSnts); NetworkResult nr = new NetworkResult(); if (decoder is AttentionDecoder) { nr.Cost = Decoder.DecodeAttentionLSTM(tgtTokensList, computeGraph, encOutput, decoder as AttentionDecoder, decoderFFLayer, tgtEmbedding, m_modelMetaData.TgtVocab, srcSnts.Count, isTraining); nr.Output = new List <List <List <string> > > { m_modelMetaData.TgtVocab.ConvertIdsToString(tgtTokensList) }; } else { if (isTraining) { (var c, _) = Decoder.DecodeTransformer(tgtTokensList, computeGraph, encOutput, 
decoder as TransformerDecoder, decoderFFLayer, tgtEmbedding, posEmbedding, originalSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType, m_options.DropoutRatio, null, isTraining); nr.Cost = c; nr.Output = null; } else { List <List <BeamSearchStatus> > beam2batchStatus = Decoder.InitBeamSearchStatusListList(batchSize, tgtTokensList); for (int i = 0; i < decodingOptions.MaxTgtSentLength; i++) { List <List <BeamSearchStatus> > batch2beam2seq = null; //(batch_size, beam_search_size) try { foreach (var batchStatus in beam2batchStatus) { var batch2tgtTokens = Decoder.ExtractBatchTokens(batchStatus); using var g = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{i}"); (var cost2, var bssSeqList) = Decoder.DecodeTransformer(batch2tgtTokens, g, encOutput, decoder as TransformerDecoder, decoderFFLayer, tgtEmbedding, posEmbedding, originalSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType, 0.0f, decodingOptions, isTraining, outputSentScore: decodingOptions.BeamSearchSize > 1, previousBeamSearchResults: batchStatus); bssSeqList = Decoder.SwapBeamAndBatch(bssSeqList); batch2beam2seq = Decoder.CombineBeamSearchResults(batch2beam2seq, bssSeqList); } } catch (OutOfMemoryException) { GC.Collect(); Logger.WriteLine(Logger.Level.warn, $"We have out of memory while generating '{i}th' tokens, so terminate decoding for current sequences."); break; } if (decodingOptions.BeamSearchSize > 1) { // Keep top N result and drop all others for (int k = 0; k < batchSize; k++) { batch2beam2seq[k] = BeamSearch.GetTopNBSS(batch2beam2seq[k], decodingOptions.BeamSearchSize); } } beam2batchStatus = Decoder.SwapBeamAndBatch(batch2beam2seq); if (Decoder.AreAllSentsCompleted(beam2batchStatus)) { break; } } nr.Cost = 0.0f; nr.Output = m_modelMetaData.TgtVocab.ExtractTokens(beam2batchStatus); } } nr.RemoveDuplicatedEOS(); nrs.Add(nrCLS); nrs.Add(nr); return(nrs); }
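// Illustration only (not part of the method above): the decoding loop keeps hypotheses in two layouts,
// beam-major (beam2batchStatus: beam -> batch) and batch-major (batch2beam2seq: batch -> beam), and
// Decoder.SwapBeamAndBatch converts between them. A plain transpose over List<List<T>> captures the
// idea; the method name below is hypothetical and assumes a rectangular (non-jagged) layout.
public static List<List<T>> TransposeSketch<T>(List<List<T>> rows)
{
    var cols = new List<List<T>>();
    for (int j = 0; j < rows[0].Count; j++)
    {
        var col = new List<T>();
        for (int i = 0; i < rows.Count; i++)
        {
            col.Add(rows[i][j]); // element at (row i, column j) moves to (row j, column i)
        }
        cols.Add(col);
    }
    return cols;
}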