/// <summary>
/// Runs the forward pass of the multi-head sequence classifier on a single device.
/// The source sentences are encoded once, the encoder output at every sentence's CLS token is
/// selected, and each feed-forward head classifies that CLS vector independently.
/// </summary>
/// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
/// <param name="sntPairBatch">Batch providing the source tokens (group 0) and, when training, one target token group per classifier head</param>
/// <param name="deviceIdIdx">The index of current device</param>
/// <param name="isTraining">True to accumulate the cost and write gradients; false to output the arg-max label of every sentence</param>
/// <param name="decodingOptions">Not referenced by this implementation</param>
/// <returns>One result per classifier head: the batch-averaged negative log-likelihood when training, otherwise the predicted label of each sentence</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
{
    var results = new List<NetworkResult>();

    (IEncoder encoder, IWeightTensor srcEmbedding, List<IFeedForwardLayer> encoderFFLayer, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

    var srcSnts = sntPairBatch.GetSrcTokens(0);
    var originalSrcLengths = BuildInTokens.PadSentences(srcSnts);
    var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts);

    IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths);

    int paddedLen = srcSnts[0].Count;
    int batchSize = srcSnts.Count;

    // For every sentence, the index (row * paddedLen + column) of its first CLS token.
    // A sentence without a CLS token keeps the default index 0.
    float[] clsOffsets = new float[batchSize];
    for (int row = 0; row < batchSize; row++)
    {
        var sentence = srcSnts[row];
        int col = 0;
        while (col < sentence.Count && sentence[col] != BuildInTokens.CLS)
        {
            col++;
        }

        if (col < sentence.Count)
        {
            clsOffsets[row] = row * paddedLen + col;
        }
    }

    IWeightTensor clsWeightTensor = computeGraph.IndexSelect(encOutput, clsOffsets);

    for (int head = 0; head < m_encoderFFLayer.Length; head++)
    {
        float totalLoss = 0.0f;
        var headResult = new NetworkResult
        {
            Output = new List<List<List<string>>>()
        };

        IWeightTensor logits = encoderFFLayer[head].Process(clsWeightTensor, batchSize, computeGraph);
        using (IWeightTensor probs = computeGraph.Softmax(logits, runGradients: false, inPlace: true))
        {
            if (!isTraining)
            {
                // Inference: emit the most probable label of this head for every sentence
                using var bestIdxTensor = computeGraph.Argmax(probs, 1);
                float[] bestIdx = bestIdxTensor.ToWeightArray();
                List<string> labels = m_modelMetaData.ClsVocabs[head].ConvertIdsToString(bestIdx.ToList());

                var perSentence = new List<List<string>>();
                headResult.Output.Add(perSentence);
                for (int k = 0; k < batchSize; k++)
                {
                    perSentence.Add(new List<string> { labels[k] });
                }
            }
            else
            {
                var goldLabels = sntPairBatch.GetTgtTokens(head);
                for (int k = 0; k < batchSize; k++)
                {
                    int goldIdx = m_modelMetaData.ClsVocabs[head].GetWordIndex(goldLabels[k][0]);
                    float goldProb = probs.GetWeightAt(new long[] { k, goldIdx });

                    totalLoss += (float)-Math.Log(goldProb);

                    // Softmax output minus one-hot target, i.e. the cross-entropy gradient w.r.t. the logits
                    probs.SetWeightAt(goldProb - 1, new long[] { k, goldIdx });
                }

                logits.CopyWeightsToGradients(probs);
                headResult.Cost = totalLoss / batchSize;
            }
        }

        results.Add(headResult);
    }

    return results;
}
/// <summary>
/// Runs the forward pass of the sequence similarity model on a single device.
/// Source token groups 0 and 1 are encoded separately. When the model's similarity type is
/// "Continuous" the two encodings are compared by cosine similarity; otherwise their element-wise
/// product is classified by a feed-forward layer followed by softmax.
/// </summary>
/// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
/// <param name="sntPairBatch">Batch providing source token groups 0 and 1 and, when training, target token group 0 (a golden score or a class label per sentence pair)</param>
/// <param name="deviceIdIdx">The index of current device</param>
/// <param name="isTraining">True to accumulate the cost and write gradients; false to output a score or label for every sentence pair</param>
/// <param name="decodingOptions">Not referenced by this implementation</param>
/// <returns>A list holding a single result: the batch-averaged cost when training, otherwise the similarity score or predicted label of each sentence pair</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
{
    int batchSize = sntPairBatch.BatchSize;
    float cost = 0.0f;
    var nrs = new List<NetworkResult>();
    var nr = new NetworkResult
    {
        Output = new List<List<List<string>>>()
    };

    (IEncoder encoder, IWeightTensor srcEmbedding, IFeedForwardLayer encoderFFLayer, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

    // We only check the cache at inference time on CPU; training and GPU runs always execute the encoder.
    bool useCache = !isTraining && (m_options.ProcessorType == ProcessorTypeEnums.CPU);

    // Encodes one source token group, going through the memory cache when it is enabled.
    IWeightTensor EncodeGroup(int groupId)
    {
        if (!useCache)
        {
            return Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding, groupId); // output shape: [batch_size, dim]
        }

        string cacheKey = GenerateCacheKey(sntPairBatch.GetSrcTokens(groupId));
        if (!m_memoryCache.TryGetValue(cacheKey, out IWeightTensor groupOutput))
        {
            groupOutput = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding, groupId); // output shape: [batch_size, dim]

            // A copy obtained from CopyWeightsRef is what gets cached, not the tensor built on this graph
            var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1);
            m_memoryCache.Set(cacheKey, groupOutput.CopyWeightsRef($"cache_{groupOutput.Name}", false), cacheEntryOptions);
        }

        return groupOutput;
    }

    IWeightTensor encOutput1 = EncodeGroup(0);
    IWeightTensor encOutput2 = EncodeGroup(1);

    if (m_modelMetaData.SimilarityType.Equals("Continuous", StringComparison.InvariantCultureIgnoreCase))
    {
        // Cosine similarity: sum(e1 * e2) * rsqrt(sum(e1 * e1) * sum(e2 * e2))
        var w12 = computeGraph.EltMul(encOutput1, encOutput2);
        w12 = computeGraph.Sum(w12, 1);

        var w1 = computeGraph.EltMul(encOutput1, encOutput1);
        w1 = computeGraph.Sum(w1, 1);

        var w2 = computeGraph.EltMul(encOutput2, encOutput2);
        w2 = computeGraph.Sum(w2, 1);

        var n12 = computeGraph.EltMul(w1, w2);
        n12 = computeGraph.Rsqrt(n12);

        var probs = computeGraph.EltMul(w12, n12);

        if (isTraining)
        {
            var tgtSnts = sntPairBatch.GetTgtTokens(0);
            for (int k = 0; k < batchSize; k++)
            {
                // Get golden similarity score from target side. Parsed with the invariant culture so that
                // "0.85" means the same value regardless of the machine's regional settings.
                float golden_score_k = float.Parse(tgtSnts[k][0], System.Globalization.CultureInfo.InvariantCulture);
                float score_k = probs.GetWeightAt(new long[] { k, 0 });

                // Replace the prediction by its signed error; the cost accumulates the absolute error
                probs.SetWeightAt(score_k - golden_score_k, new long[] { k, 0 });
                cost += (float)Math.Abs(score_k - golden_score_k);
            }

            // The signed errors written above are copied into the gradients of the same tensor
            probs.CopyWeightsToGradients(probs);
            nr.Cost = cost / batchSize;
        }
        else
        {
            var scores = new List<List<string>>();
            nr.Output.Add(scores);
            for (int k = 0; k < batchSize; k++)
            {
                float score_k = probs.GetWeightAt(new long[] { k, 0 });

                // Invariant culture keeps the emitted score machine-readable ('.' as decimal separator)
                scores.Add(new List<string> { score_k.ToString(System.Globalization.CultureInfo.InvariantCulture) });
            }
        }
    }
    else
    {
        IWeightTensor encOutput = computeGraph.EltMul(encOutput1, encOutput2);
        IWeightTensor ffLayer = encoderFFLayer.Process(encOutput, batchSize, computeGraph);
        using (IWeightTensor probs = computeGraph.Softmax(ffLayer, runGradients: false, inPlace: true))
        {
            if (isTraining)
            {
                var tgtSnts = sntPairBatch.GetTgtTokens(0);
                for (int k = 0; k < batchSize; k++)
                {
                    int ix_targets_k_j = m_modelMetaData.ClsVocab.GetWordIndex(tgtSnts[k][0]);
                    float score_k = probs.GetWeightAt(new long[] { k, ix_targets_k_j });

                    cost += (float)-Math.Log(score_k);

                    // Softmax output minus one-hot target, i.e. the cross-entropy gradient w.r.t. the logits
                    probs.SetWeightAt(score_k - 1, new long[] { k, ix_targets_k_j });
                }

                ffLayer.CopyWeightsToGradients(probs);
                nr.Cost = cost / batchSize;
            }
            else
            {
                // Output the most probable class label for every sentence pair
                using var targetIdxTensor = computeGraph.Argmax(probs, 1);
                float[] targetIdx = targetIdxTensor.ToWeightArray();
                List<string> targetWords = m_modelMetaData.ClsVocab.ConvertIdsToString(targetIdx.ToList());

                var labels = new List<List<string>>();
                nr.Output.Add(labels);
                for (int k = 0; k < batchSize; k++)
                {
                    labels.Add(new List<string> { targetWords[k] });
                }
            }
        }
    }

    nrs.Add(nr);

    return nrs;
}
/// <summary>
/// Runs the forward pass of the joint classification + sequence-to-sequence model on a single device.
/// The encoder output feeds two consumers: a classifier over the CLS token position and a decoder
/// (attention LSTM or transformer) that is trained against, or generates, the target sentences.
/// </summary>
/// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
/// <param name="sntPairBatch">Batch providing source tokens (group 0), classification labels (target group 0) and decoder target tokens (target group 1)</param>
/// <param name="deviceIdIdx">The index of current device</param>
/// <param name="isTraining">True to compute costs and gradients; false to predict labels and decode output sentences</param>
/// <param name="decodingOptions">Maximum target length and beam search size used when decoding with the transformer decoder at inference time</param>
/// <returns>Two results, in this order: the classification result and the decoder result</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
{
    (IEncoder encoder, IDecoder decoder, IFeedForwardLayer encoderFFLayer, IFeedForwardLayer decoderFFLayer, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

    var srcSnts = sntPairBatch.GetSrcTokens(0);
    var originalSrcLengths = BuildInTokens.PadSentences(srcSnts);
    var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts);

    IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths);

    int paddedLen = srcSnts[0].Count;
    int batchSize = srcSnts.Count;

    // For every sentence, the index (row * paddedLen + column) of its first CLS token.
    // A sentence without a CLS token keeps the default index 0.
    float[] clsOffsets = new float[batchSize];
    for (int row = 0; row < batchSize; row++)
    {
        var sentence = srcSnts[row];
        int col = 0;
        while (col < sentence.Count && sentence[col] != BuildInTokens.CLS)
        {
            col++;
        }

        if (col < sentence.Count)
        {
            clsOffsets[row] = row * paddedLen + col;
        }
    }

    IWeightTensor clsWeightTensor = computeGraph.IndexSelect(encOutput, clsOffsets);

    // ---- Classification over the CLS vectors ----
    float clsLoss = 0.0f;
    var nrCLS = new NetworkResult
    {
        Output = new List<List<List<string>>>()
    };

    IWeightTensor logits = encoderFFLayer.Process(clsWeightTensor, batchSize, computeGraph);
    using (IWeightTensor probs = computeGraph.Softmax(logits, runGradients: false, inPlace: true))
    {
        if (!isTraining)
        {
            // Inference: emit the most probable class label for every sentence
            using var bestIdxTensor = computeGraph.Argmax(probs, 1);
            float[] bestIdx = bestIdxTensor.ToWeightArray();
            List<string> labels = m_modelMetaData.ClsVocab.ConvertIdsToString(bestIdx.ToList());

            var perSentence = new List<List<string>>();
            nrCLS.Output.Add(perSentence);
            for (int k = 0; k < batchSize; k++)
            {
                perSentence.Add(new List<string> { labels[k] });
            }
        }
        else
        {
            var clsSnts = sntPairBatch.GetTgtTokens(0);
            for (int k = 0; k < batchSize; k++)
            {
                int goldIdx = m_modelMetaData.ClsVocab.GetWordIndex(clsSnts[k][0]);
                float goldProb = probs.GetWeightAt(new long[] { k, goldIdx });

                clsLoss += (float)-Math.Log(goldProb);

                // Softmax output minus one-hot target, i.e. the cross-entropy gradient w.r.t. the logits
                probs.SetWeightAt(goldProb - 1, new long[] { k, goldIdx });
            }

            logits.CopyWeightsToGradients(probs);
            nrCLS.Cost = clsLoss / batchSize;
        }
    }

    // ---- Sequence decoding ----
    // Reset networks
    decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);

    // Generate output decoder sentences
    var tgtSnts = sntPairBatch.GetTgtTokens(1);
    var tgtTokensList = m_modelMetaData.TgtVocab.GetWordIndex(tgtSnts);

    var nr = new NetworkResult();
    if (decoder is AttentionDecoder attentionDecoder)
    {
        nr.Cost = Decoder.DecodeAttentionLSTM(tgtTokensList, computeGraph, encOutput, attentionDecoder, decoderFFLayer, tgtEmbedding, m_modelMetaData.TgtVocab, srcSnts.Count, isTraining);
        nr.Output = new List<List<List<string>>>
        {
            m_modelMetaData.TgtVocab.ConvertIdsToString(tgtTokensList)
        };
    }
    else if (isTraining)
    {
        (var decodeCost, _) = Decoder.DecodeTransformer(tgtTokensList, computeGraph, encOutput, decoder as TransformerDecoder, decoderFFLayer, tgtEmbedding, posEmbedding, originalSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType, m_options.DropoutRatio, null, isTraining);
        nr.Cost = decodeCost;
        nr.Output = null;
    }
    else
    {
        var transformerDecoder = decoder as TransformerDecoder;

        List<List<BeamSearchStatus>> beam2batchStatus = Decoder.InitBeamSearchStatusListList(batchSize, tgtTokensList);
        for (int i = 0; i < decodingOptions.MaxTgtSentLength; i++)
        {
            // Candidates produced at this step, shaped (batch_size, beam_search_size)
            List<List<BeamSearchStatus>> stepCandidates = null;

            try
            {
                foreach (var batchStatus in beam2batchStatus)
                {
                    var batchTokens = Decoder.ExtractBatchTokens(batchStatus);

                    // Each step of each beam runs on its own sub graph, disposed at the end of the iteration
                    using var stepGraph = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{i}");
                    (_, var beamResults) = Decoder.DecodeTransformer(batchTokens, stepGraph, encOutput, transformerDecoder, decoderFFLayer, tgtEmbedding, posEmbedding,
                                                                    originalSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType, 0.0f, decodingOptions, isTraining,
                                                                    outputSentScore: decodingOptions.BeamSearchSize > 1, previousBeamSearchResults: batchStatus);

                    beamResults = Decoder.SwapBeamAndBatch(beamResults);
                    stepCandidates = Decoder.CombineBeamSearchResults(stepCandidates, beamResults);
                }
            }
            catch (OutOfMemoryException)
            {
                // Stop decoding here and keep what the previous steps produced
                GC.Collect();
                Logger.WriteLine(Logger.Level.warn, $"We have out of memory while generating '{i}th' tokens, so terminate decoding for current sequences.");
                break;
            }

            if (decodingOptions.BeamSearchSize > 1)
            {
                // Keep top N result and drop all others
                for (int k = 0; k < batchSize; k++)
                {
                    stepCandidates[k] = BeamSearch.GetTopNBSS(stepCandidates[k], decodingOptions.BeamSearchSize);
                }
            }

            beam2batchStatus = Decoder.SwapBeamAndBatch(stepCandidates);
            if (Decoder.AreAllSentsCompleted(beam2batchStatus))
            {
                break;
            }
        }

        nr.Cost = 0.0f;
        nr.Output = m_modelMetaData.TgtVocab.ExtractTokens(beam2batchStatus);
    }

    nr.RemoveDuplicatedEOS();

    return new List<NetworkResult> { nrCLS, nr };
}