/// <summary>
/// Runs the sequence-labeling forward pass on a single device: encodes the source tokens, projects every
/// position onto the label vocabulary and either computes the cross-entropy cost (training) or writes the
/// predicted labels back into the target token lists (inference).
/// </summary>
/// <param name="g">The computing graph for current device. It gets created and passed by the framework</param>
/// <param name="sntPairBatch">The batch. Source group 0 holds the input tokens and target group 0 holds the labels.
/// In inference mode the entries of the list returned by <c>GetTgtTokens(0)</c> are overwritten with the predicted labels.</param>
/// <param name="deviceIdIdx">The index of current device</param>
/// <param name="isTraining">True to compute the training cost; false to predict labels.</param>
/// <param name="decodingOptions">Not used by this method.</param>
/// <returns>A single <see cref="NetworkResult"/> whose Output holds the target token lists and whose Cost is the loss (0 at inference).</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph g, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
{
    var results = new List<NetworkResult>();

    var srcSnts = sntPairBatch.GetSrcTokens(0);
    var tgtSnts = sntPairBatch.GetTgtTokens(0);

    (IEncoder encoder, IWeightTensor srcEmbedding, IWeightTensor posEmbedding, FeedForwardLayer decoderFFLayer) = GetNetworksOnDeviceAt(deviceIdIdx);

    // Reset networks
    encoder.Reset(g.GetWeightFactory(), srcSnts.Count);

    var originalSrcLengths = BuildInTokens.PadSentences(srcSnts);
    var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts);

    BuildInTokens.PadSentences(tgtSnts);
    var labelIdsList = m_modelMetaData.ClsVocab.GetWordIndex(tgtSnts);

    int batchSize = srcSnts.Count;
    int seqLen = srcSnts[0].Count;

    // Encode the source tokens (no segment embedding is passed here).
    IWeightTensor encOutput = Encoder.Run(g, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, null, srcTokensList, originalSrcLengths);
    IWeightTensor logits = decoderFFLayer.Process(encOutput, batchSize, g);

    float cost = 0.0f;
    IWeightTensor probs = g.Softmax(logits, inPlace: true);

    if (isTraining)
    {
        cost = g.CrossEntropyLoss(probs, g.CreateTokensTensor(labelIdsList));
    }
    else
    {
        // One predicted label per position; sentence k owns the slice [k * seqLen, (k + 1) * seqLen).
        using var predictedIds = g.Argmax(probs, 1);
        List<string> predictedLabels = m_modelMetaData.ClsVocab.ConvertIdsToString(predictedIds.ToWeightArray().ToList());
        for (int sentIdx = 0; sentIdx < batchSize; sentIdx++)
        {
            tgtSnts[sentIdx] = predictedLabels.GetRange(sentIdx * seqLen, seqLen);
        }
    }

    results.Add(new NetworkResult
    {
        Cost = cost,
        Output = new List<List<List<string>>> { tgtSnts }
    });
    return results;
}
/// <summary>
/// Returns the tokens of this vocabulary.
/// </summary>
/// <param name="keepBuildInTokens">When true, the backing <c>Items</c> list itself is returned (not a copy);
/// when false, a new list containing only the tokens that are not pre-defined (built-in) tokens is returned.</param>
/// <returns>The token list, in <c>Items</c> order.</returns>
public List<string> GetAllTokens(bool keepBuildInTokens = true)
{
    return keepBuildInTokens
        ? Items
        : Items.FindAll(token => !BuildInTokens.IsPreDefinedToken(token));
}
/// <summary>
/// Encodes source token group <paramref name="groupId"/> (after inserting a CLS token via <c>InsertCLSToken</c>)
/// and selects, for every sentence in the batch, the encoder output row at position 0 of that sentence.
/// </summary>
/// <param name="computeGraph">Graph the tensors are created on.</param>
/// <param name="sntPairBatch">Batch providing the source tokens of the requested group.</param>
/// <param name="shuffleType">Shuffle mode forwarded to the encoder runner.</param>
/// <param name="encoder">Encoder network.</param>
/// <param name="modelMetaData">Model metadata; its source vocabulary maps tokens to ids.</param>
/// <param name="srcEmbedding">Source token embedding table.</param>
/// <param name="posEmbedding">Positional embedding table.</param>
/// <param name="segmentEmbedding">Segment embedding table (may be null if the runner accepts it).</param>
/// <param name="groupId">Index of the source token group to encode.</param>
/// <returns>A tensor with one selected row per sentence in the batch.</returns>
public static IWeightTensor BuildTensorForSourceTokenGroupAt(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, ShuffleEnums shuffleType, IEncoder encoder, IModel modelMetaData, IWeightTensor srcEmbedding, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding, int groupId)
{
    var contextTokens = InsertCLSToken(sntPairBatch.GetSrcTokens(groupId));
    var originalSrcContextLength = BuildInTokens.PadSentences(contextTokens);
    var contextTokenIds = modelMetaData.SrcVocab.GetWordIndex(contextTokens);

    IWeightTensor encContextOutput = InnerRunner(computeGraph, contextTokenIds, originalSrcContextLength, shuffleType, encoder, modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding);

    // The encoder output is flattened as (sentence, position); row b * paddedLen is position 0 of sentence b.
    // NOTE(review): this presumes InsertCLSToken puts CLS at position 0 of each sentence — confirm.
    int paddedLen = contextTokens[0].Count;
    int batchSize = sntPairBatch.BatchSize;
    float[] clsRowIdxs = new float[batchSize];
    for (int row = 0; row < batchSize; row++)
    {
        clsRowIdxs[row] = row * paddedLen;
    }

    return computeGraph.IndexSelect(encContextOutput, clsRowIdxs);
}
/// <summary>
/// Load vocabulary from the given file.
/// </summary>
/// <param name="vocabFilePath">Path of the vocabulary file. Each line holds one token; only the text before the
/// first tab character is used as the token (anything after it is ignored).</param>
/// <remarks>
/// Pre-defined (built-in) tokens found in the file are skipped, and file tokens receive consecutive ids starting at 3.
/// Empty lines and duplicated tokens are skipped as well, so they neither create an empty-string token nor waste ids.
/// </remarks>
public Vocab(string vocabFilePath)
{
    Logger.WriteLine("Loading vocabulary files...");
    string[] vocab = File.ReadAllLines(vocabFilePath);

    CreateIndex();

    //Build word index for both source and target sides
    // NOTE(review): ids below 3 are presumably reserved for the built-in tokens set up by CreateIndex() — confirm.
    int q = 3;
    foreach (string line in vocab)
    {
        string word = line.Split('\t')[0];

        // Blank lines would otherwise register "" as a real token.
        if (string.IsNullOrEmpty(word) || BuildInTokens.IsPreDefinedToken(word))
        {
            continue;
        }

        // A duplicated entry would otherwise add a second Items entry and an extra id for the same word.
        if (WordToIndex.ContainsKey(word))
        {
            Logger.WriteLine(Logger.Level.warn, $"Duplicated token '{word}' in vocabulary file '{vocabFilePath}' is ignored.");
            continue;
        }

        Items.Add(word);
        WordToIndex[word] = q;
        IndexToWord[q] = word;
        q++;
    }
}
/// <summary>
/// Create input embeddings from token embeddings, optionally adding segment embeddings and tag embeddings.
/// </summary>
/// <param name="seqs">Token ids of the batch. Every sequence is read with the length of <c>seqs[0]</c>, so the batch is expected to be padded.</param>
/// <param name="g">The compute graph the tensors are created on.</param>
/// <param name="embeddingsTensor">Token embedding table; tag embeddings are also looked up in it.</param>
/// <param name="segmentEmbedding">Optional segment embedding table. Null disables segment embeddings. The segment index starts at 0
/// and is incremented after each SEP token.</param>
/// <param name="vocab">Vocabulary used to map ids back to token strings (to detect SEP and tag tokens).</param>
/// <param name="scaleFactor">Multiplier applied to the token embeddings; 1.0 leaves them unscaled.</param>
/// <param name="enableTagEmbedding">When true, markup-like tokens such as &lt;b&gt; ... &lt;/b&gt; (that are not built-in tokens) open and close tag
/// scopes, and every non-tag token inside a scope additionally receives the embedding of each enclosing opening tag.</param>
/// <returns>The embedding tensor. shape: (batchsize * seqLen, embedding_dim) </returns>
public static IWeightTensor CreateTokensEmbeddings(List<List<int>> seqs, IComputeGraph g, IWeightTensor embeddingsTensor, IWeightTensor segmentEmbedding, Vocab vocab, float scaleFactor = 1.0f, bool enableTagEmbedding = false)
{
    int batchSize = seqs.Count;
    int seqLen = seqs[0].Count;

    float[] idxs = new float[batchSize * seqLen];
    float[] segIdxs = new float[batchSize * seqLen];

    // One index array per tag nesting level. Entries stay -1 where no tag of that level encloses the token.
    List<float[]> tagIdxsList = new List<float[]>();

    for (int i = 0; i < batchSize; i++)
    {
        int segIdx = 0;
        List<int> currTagIdxs = new List<int>();
        int currTagLevel = 0;

        for (int j = 0; j < seqLen; j++)
        {
            int pos = i * seqLen + j;
            idxs[pos] = seqs[i][j];
            segIdxs[pos] = segIdx;

            string token = vocab.GetString(seqs[i][j]);
            if (token == BuildInTokens.SEP)
            {
                //A new segment
                segIdx++;
            }

            if (!enableTagEmbedding)
            {
                continue;
            }

            bool isTag = token.StartsWith("<", StringComparison.Ordinal)
                && token.EndsWith(">", StringComparison.Ordinal)
                && BuildInTokens.IsPreDefinedToken(token) == false;

            if (isTag)
            {
                if (token[1] == '/')
                {
                    // Closing tag. An unmatched closing tag is ignored; decrementing at level 0 would index currTagIdxs[-1] and throw.
                    if (currTagLevel > 0)
                    {
                        currTagLevel--;
                        currTagIdxs[currTagLevel] = -1;
                    }
                }
                else
                {
                    // Opening tag: make sure this nesting level has an index array, then remember the tag's id at this level.
                    while (tagIdxsList.Count <= currTagLevel)
                    {
                        float[] tagIdxs = new float[batchSize * seqLen];
                        Array.Fill(tagIdxs, -1.0f);
                        tagIdxsList.Add(tagIdxs);
                    }

                    while (currTagIdxs.Count <= currTagLevel)
                    {
                        currTagIdxs.Add(-1);
                    }

                    currTagIdxs[currTagLevel] = seqs[i][j];
                    currTagLevel++;
                }
            }
            else
            {
                // Regular token: attach every currently open tag.
                for (int k = 0; k < currTagLevel; k++)
                {
                    tagIdxsList[k][pos] = currTagIdxs[k];
                }
            }
        }
    }

    IWeightTensor tagEmbeddings = null;
    if (enableTagEmbedding)
    {
        for (int k = 0; k < tagIdxsList.Count; k++)
        {
            var tagEmbeddings_k = g.IndexSelect(embeddingsTensor, tagIdxsList[k], clearWeights: true);
            tagEmbeddings = tagEmbeddings == null ? tagEmbeddings_k : g.Add(tagEmbeddings, tagEmbeddings_k);
        }
    }

    IWeightTensor embeddingRst = g.IndexSelect(embeddingsTensor, idxs);
    if (scaleFactor != 1.0f)
    {
        embeddingRst = g.Mul(embeddingRst, scaleFactor, inPlace: true);
    }

    // Apply segment embeddings to the input sequence embeddings
    if (segmentEmbedding != null)
    {
        embeddingRst = g.Add(embeddingRst, g.IndexSelect(segmentEmbedding, segIdxs));
    }

    if (tagEmbeddings != null)
    {
        embeddingRst = g.Add(embeddingRst, tagEmbeddings);
    }

    return embeddingRst;
}
/// <summary>
/// Entry point of the sequence labeling console. Depending on <c>opts.Task</c> it trains, validates or tests a
/// <see cref="SeqLabel"/> model; any other task prints the usage.
/// </summary>
/// <param name="args">Command line arguments, parsed into <see cref="SeqLabelOptions"/>.</param>
private static void Main(string[] args)
{
    ShowOptions(args);

    Logger.LogFile = $"{nameof(SeqLabelConsole)}_{Utils.GetTimeStamp(DateTime.Now)}.log";

    //Parse command line
    SeqLabelOptions opts = new SeqLabelOptions();
    ArgParser argParser = new ArgParser(args, opts);
    if (!opts.ConfigFilePath.IsNullOrEmpty())
    {
        // Options deserialized from the config file replace the ones parsed from the command line.
        Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
        opts = JsonConvert.DeserializeObject<SeqLabelOptions>(File.ReadAllText(opts.ConfigFilePath));
    }

    DecodingOptions decodingOptions = opts.CreateDecodingOptions();
    SeqLabel sl = null;

    //Parse device ids from options
    // NOTE(review): deviceIds is not read below; parsing still fails fast on a malformed DeviceIds value.
    int[] deviceIds = opts.DeviceIds.Split(',').Select(x => int.Parse(x)).ToArray();

    if (opts.Task == ModeEnums.Train)
    {
        // Load train corpus
        SeqLabelingCorpus trainCorpus = new SeqLabelingCorpus(opts.TrainCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, maxSentLength: opts.MaxTrainSentLength);

        // Load valid corpus: ValidCorpusPaths may list several corpora separated by ';', one corpus per path.
        List<SeqLabelingCorpus> validCorpusList = new List<SeqLabelingCorpus>();
        if (!opts.ValidCorpusPaths.IsNullOrEmpty())
        {
            string[] validCorpusPathList = opts.ValidCorpusPaths.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
            foreach (var validCorpusPath in validCorpusPathList)
            {
                // Each corpus must be built from its own path, not from the whole ';'-joined string.
                validCorpusList.Add(new SeqLabelingCorpus(validCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, maxSentLength: opts.MaxTestSentLength));
            }
        }

        // Load or build vocabulary
        Vocab srcVocab = null;
        Vocab tgtVocab = null;
        if (!opts.SrcVocab.IsNullOrEmpty() && !opts.TgtVocab.IsNullOrEmpty())
        {
            // Vocabulary files are specified, so we load them
            srcVocab = new Vocab(opts.SrcVocab);
            tgtVocab = new Vocab(opts.TgtVocab);
        }
        else
        {
            // We don't specify vocabulary, so we build it from train corpus
            (srcVocab, tgtVocab) = trainCorpus.BuildVocabs(opts.SrcVocabSize);
        }

        // Create learning rate
        ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

        // Create optimizer
        IOptimizer optimizer = Misc.CreateOptimizer(opts);

        // Create metrics
        List<IMetric> metrics = CreateLabelMetrics(tgtVocab);

        if (File.Exists(opts.ModelFilePath) == false)
        {
            //New training
            sl = new SeqLabel(opts, srcVocab: srcVocab, clsVocab: tgtVocab);
        }
        else
        {
            //Incremental training
            Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
            sl = new SeqLabel(opts);
        }

        // Add event handler for monitoring
        sl.StatusUpdateWatcher += Misc.Ss_StatusUpdateWatcher;

        // Kick off training
        sl.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpusList: validCorpusList.ToArray(), learningRate: learningRate, optimizer: optimizer, metrics: metrics, decodingOptions: decodingOptions);
    }
    else if (opts.Task == ModeEnums.Valid)
    {
        Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPaths}'");

        // Load valid corpus
        SeqLabelingCorpus validCorpus = new SeqLabelingCorpus(opts.ValidCorpusPaths, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxTestSentLength);
        (Vocab srcVocab, Vocab tgtVocab) = validCorpus.BuildVocabs();

        // Create metrics
        List<IMetric> metrics = CreateLabelMetrics(tgtVocab);

        sl = new SeqLabel(opts);
        sl.Valid(validCorpus: validCorpus, metrics: metrics, decodingOptions: decodingOptions);
    }
    else if (opts.Task == ModeEnums.Test)
    {
        Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

        //Test trained model
        sl = new SeqLabel(opts);

        List<string> outputLines = new List<string>();
        string[] data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
        foreach (string line in data_sents_raw1)
        {
            var nrs = sl.Test<SeqLabelingCorpusBatch>(ConstructInputTokens(line.Trim().Split(' ').ToList()), null, decodingOptions: decodingOptions);
            outputLines.AddRange(nrs[0].Output[0].Select(x => string.Join(" ", x)));
        }

        File.WriteAllLines(opts.OutputFile, outputLines);
    }
    else
    {
        argParser.Usage();
    }

    // Creates one F-score metric per label of labelVocab, skipping the pre-defined (built-in) tokens.
    List<IMetric> CreateLabelMetrics(Vocab labelVocab)
    {
        List<IMetric> labelMetrics = new List<IMetric>();
        foreach (string word in labelVocab.Items)
        {
            if (BuildInTokens.IsPreDefinedToken(word) == false)
            {
                labelMetrics.Add(new SequenceLabelFscoreMetric(word));
            }
        }

        return labelMetrics;
    }
}
/// <summary>
/// Run forward part on given single device: encodes the source batch, takes the encoder output at each sentence's
/// CLS token and feeds it to every classification head.
/// </summary>
/// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
/// <param name="sntPairBatch">The batch. Source group 0 is the input. In training mode, target group i holds the
/// label (first token of each target sentence) for classification head i.</param>
/// <param name="deviceIdIdx">The index of current device</param>
/// <param name="isTraining">True to compute cost and gradients; false to predict labels.</param>
/// <param name="decodingOptions">Not used by this method.</param>
/// <returns>One <see cref="NetworkResult"/> per classification head: Cost is the mean negative log-likelihood in training;
/// Output[0][k] holds the predicted label of sentence k at inference.</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
{
    List<NetworkResult> nrs = new List<NetworkResult>();

    (IEncoder encoder, IWeightTensor srcEmbedding, List<IFeedForwardLayer> encoderFFLayer, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

    var srcSnts = sntPairBatch.GetSrcTokens(0);
    var originalSrcLengths = BuildInTokens.PadSentences(srcSnts);
    var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts);

    IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths);

    int srcSeqPaddedLen = srcSnts[0].Count;
    int batchSize = srcSnts.Count;

    // Flattened (sentence, position) row of the first CLS token of each sentence.
    // NOTE(review): a sentence without CLS keeps row 0 — confirm every input carries a CLS token.
    float[] clsIdxs = new float[batchSize];
    for (int i = 0; i < batchSize; i++)
    {
        int clsPos = srcSnts[i].IndexOf(BuildInTokens.CLS);
        if (clsPos >= 0)
        {
            clsIdxs[i] = i * srcSeqPaddedLen + clsPos;
        }
    }

    IWeightTensor clsWeightTensor = computeGraph.IndexSelect(encOutput, clsIdxs);

    // Iterate the device-local heads we index into, so the loop bound always matches encoderFFLayer.
    for (int i = 0; i < encoderFFLayer.Count; i++)
    {
        float cost = 0.0f;
        NetworkResult nr = new NetworkResult
        {
            Output = new List<List<List<string>>>()
        };

        IWeightTensor ffLayer = encoderFFLayer[i].Process(clsWeightTensor, batchSize, computeGraph);
        using (IWeightTensor probs = computeGraph.Softmax(ffLayer, runGradients: false, inPlace: true))
        {
            if (isTraining)
            {
                var tgtSnts = sntPairBatch.GetTgtTokens(i);
                for (int k = 0; k < batchSize; k++)
                {
                    int ix_targets_k_j = m_modelMetaData.ClsVocabs[i].GetWordIndex(tgtSnts[k][0]);
                    float score_k = probs.GetWeightAt(new long[] { k, ix_targets_k_j });
                    cost += (float)-Math.Log(score_k);

                    // Turn probs into (p - onehot(label)), which is then copied into ffLayer's gradients.
                    probs.SetWeightAt(score_k - 1, new long[] { k, ix_targets_k_j });
                }

                ffLayer.CopyWeightsToGradients(probs);
                nr.Cost = cost / batchSize;
            }
            else
            {
                // Output "i"th target word
                using var targetIdxTensor = computeGraph.Argmax(probs, 1);
                float[] targetIdx = targetIdxTensor.ToWeightArray();
                List<string> targetWords = m_modelMetaData.ClsVocabs[i].ConvertIdsToString(targetIdx.ToList());
                nr.Output.Add(new List<List<string>>());

                for (int k = 0; k < batchSize; k++)
                {
                    nr.Output[0].Add(new List<string>());
                    nr.Output[0][k].Add(targetWords[k]);
                }
            }
        }

        nrs.Add(nr);
    }

    return nrs;
}
/// <summary>
/// Run forward part on given single device. It encodes the source batch, then either computes the decoder
/// training cost or generates target tokens step by step (greedy or beam search).
/// </summary>
/// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
/// <param name="sntPairBatch">The batch. Source group 0 is the input. Target group 0 holds the reference tokens
/// in training; at inference, Transformer decoding starts at position <c>tgtTokensList[0].Count</c>.</param>
/// <param name="deviceIdIdx">The index of current device</param>
/// <param name="isTraining">True to compute the training cost; false to decode.</param>
/// <param name="decodingOptions">Beam size and maximum target length for Transformer decoding; not used by the training path.</param>
/// <returns>A single <see cref="NetworkResult"/> with the cost (0 for Transformer inference) and the output tokens (null for Transformer training).</returns>
/// <exception cref="InvalidDataException">In training, when the padded source length exceeds MaxTrainSrcSentLength + 2.</exception>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
{
    (var encoder, var decoder, var decoderFFLayer, var srcEmbedding, var tgtEmbedding, var posEmbedding, var segmentEmbedding, var pointerGenerator) = GetNetworksOnDeviceAt(deviceIdIdx);

    var srcSnts = sntPairBatch.GetSrcTokens(0);
    var originalSrcLengths = BuildInTokens.PadSentences(srcSnts);
    var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts);

    // NOTE(review): the "+ 2" presumably allows for tokens added by padding (e.g. BOS/EOS) — confirm.
    if (isTraining && srcSnts[0].Count > m_options.MaxTrainSrcSentLength + 2)
    {
        throw new InvalidDataException($"The source sentence is too long. Its length = '{srcSnts[0].Count}', but MaxTrainSrcSentLength is '{m_options.MaxTrainSrcSentLength}'. The sentence is '{string.Join(" ", srcSnts[0])}'");
    }

    IWeightTensor encOutput;
    if (!isTraining && (m_options.ProcessorType == ProcessorTypeEnums.CPU))
    {
        // Try to get src tensor from cache
        string cacheKey = GenerateCacheKey(srcSnts);
        if (!m_memoryCache.TryGetValue(cacheKey, out encOutput))
        {
            encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths);

            var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1);
            m_memoryCache.Set(cacheKey, encOutput.CopyWeightsRef($"cache_{encOutput.Name}", false), cacheEntryOptions);
        }
    }
    else
    {
        // Compute src tensor
        encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths);
    }

    List<NetworkResult> nrs = new List<NetworkResult>();

    // Generate output decoder sentences
    int batchSize = srcSnts.Count;
    var tgtSnts = sntPairBatch.GetTgtTokens(0);
    var tgtTokensList = m_modelMetaData.TgtVocab.GetWordIndex(tgtSnts);

    NetworkResult nr = new NetworkResult();
    decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);

    if (decoder is AttentionDecoder)
    {
        nr.Cost = Decoder.DecodeAttentionLSTM(tgtTokensList, computeGraph, encOutput, decoder as AttentionDecoder, decoderFFLayer, tgtEmbedding, m_modelMetaData.TgtVocab, srcSnts.Count, isTraining);
        nr.Output = new List<List<List<string>>> { m_modelMetaData.TgtVocab.ConvertIdsToString(tgtTokensList) };
    }
    else if (isTraining)
    {
        (var c, _) = Decoder.DecodeTransformer(tgtTokensList, computeGraph, encOutput, decoder as TransformerDecoder, decoderFFLayer, tgtEmbedding, posEmbedding, originalSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType, m_options.DropoutRatio, null, isTraining, pointerGenerator: pointerGenerator, srcSeqs: srcTokensList);
        nr.Cost = c;
        nr.Output = null;
    }
    else
    {
        // Tensors cached across decoding steps; released in finally so an exception mid-decoding does not leak them.
        Dictionary<string, IWeightTensor> cachedTensors = new Dictionary<string, IWeightTensor>();
        try
        {
            List<List<BeamSearchStatus>> beam2batchStatus = Decoder.InitBeamSearchStatusListList(batchSize, tgtTokensList);
            for (int i = tgtTokensList[0].Count; i < decodingOptions.MaxTgtSentLength; i++)
            {
                List<List<BeamSearchStatus>> batch2beam2seq = null; //(batch_size, beam_search_size)
                try
                {
                    foreach (var batchStatus in beam2batchStatus)
                    {
                        var batch2tgtTokens = Decoder.ExtractBatchTokens(batchStatus);
                        using var g = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{i}");
                        (_, var bssSeqList) = Decoder.DecodeTransformer(batch2tgtTokens, g, encOutput, decoder as TransformerDecoder, decoderFFLayer, tgtEmbedding, posEmbedding, originalSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType, 0.0f, decodingOptions, isTraining, outputSentScore: decodingOptions.BeamSearchSize > 1, previousBeamSearchResults: batchStatus, pointerGenerator: pointerGenerator, srcSeqs: srcTokensList, cachedTensors: cachedTensors);

                        bssSeqList = Decoder.SwapBeamAndBatch(bssSeqList); // Swap shape: (beam_search_size, batch_size) -> (batch_size, beam_search_size)
                        batch2beam2seq = Decoder.CombineBeamSearchResults(batch2beam2seq, bssSeqList);
                    }
                }
                catch (OutOfMemoryException)
                {
                    GC.Collect();
                    Logger.WriteLine(Logger.Level.warn, $"We have out of memory while generating '{i}th' tokens, so terminate decoding for current sequences.");
                    break;
                }

                if (decodingOptions.BeamSearchSize > 1)
                {
                    // Keep top N result and drop all others
                    for (int k = 0; k < batchSize; k++)
                    {
                        batch2beam2seq[k] = BeamSearch.GetTopNBSS(batch2beam2seq[k], decodingOptions.BeamSearchSize);
                    }
                }

                beam2batchStatus = Decoder.SwapBeamAndBatch(batch2beam2seq);
                if (Decoder.AreAllSentsCompleted(beam2batchStatus))
                {
                    break;
                }
            }

            nr.Cost = 0.0f;
            nr.Output = m_modelMetaData.TgtVocab.ExtractTokens(beam2batchStatus);
        }
        finally
        {
            foreach (var pair in cachedTensors)
            {
                pair.Value.Dispose();
            }
        }
    }

    nr.RemoveDuplicatedEOS();
    nrs.Add(nr);

    return nrs;
}
/// <summary>
/// Runs the joint classification + sequence-generation forward pass on a single device. The encoder output at each
/// sentence's CLS token feeds the classifier head, and the full encoder output feeds the decoder.
/// </summary>
/// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
/// <param name="sntPairBatch">The batch. Source group 0 is the input, target group 0 holds the class label (first token)
/// and target group 1 holds the target sentences.</param>
/// <param name="deviceIdIdx">The index of current device</param>
/// <param name="isTraining">True to compute costs (and classifier gradients); false to predict.</param>
/// <param name="decodingOptions">Beam size and maximum target length for Transformer decoding at inference.</param>
/// <returns>Two results: the classification result first, then the sequence-generation result.</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
{
    (IEncoder encoder, IDecoder decoder, IFeedForwardLayer encoderFFLayer, IFeedForwardLayer decoderFFLayer, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

    var srcSnts = sntPairBatch.GetSrcTokens(0);
    var originalSrcLengths = BuildInTokens.PadSentences(srcSnts);
    var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts);

    IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths);

    int batchSize = srcSnts.Count;
    int paddedSrcLen = srcSnts[0].Count;

    // Flattened (sentence, position) row of the first CLS token of every sentence; 0 when a sentence has none.
    float[] clsRows = new float[batchSize];
    for (int b = 0; b < batchSize; b++)
    {
        int clsPos = srcSnts[b].IndexOf(BuildInTokens.CLS);
        if (clsPos >= 0)
        {
            clsRows[b] = b * paddedSrcLen + clsPos;
        }
    }

    IWeightTensor clsWeightTensor = computeGraph.IndexSelect(encOutput, clsRows);

    // ---- Classification head ----
    NetworkResult clsResult = new NetworkResult { Output = new List<List<List<string>>>() };
    IWeightTensor clsLogits = encoderFFLayer.Process(clsWeightTensor, batchSize, computeGraph);
    using (IWeightTensor probs = computeGraph.Softmax(clsLogits, runGradients: false, inPlace: true))
    {
        if (isTraining)
        {
            var clsSnts = sntPairBatch.GetTgtTokens(0);
            float loss = 0.0f;
            for (int b = 0; b < batchSize; b++)
            {
                int goldIdx = m_modelMetaData.ClsVocab.GetWordIndex(clsSnts[b][0]);
                float goldProb = probs.GetWeightAt(new long[] { b, goldIdx });
                loss += (float)-Math.Log(goldProb);

                // probs becomes (p - onehot(label)), which is then copied into the logits' gradients.
                probs.SetWeightAt(goldProb - 1, new long[] { b, goldIdx });
            }

            clsLogits.CopyWeightsToGradients(probs);
            clsResult.Cost = loss / batchSize;
        }
        else
        {
            using var predictedIds = computeGraph.Argmax(probs, 1);
            List<string> predictedLabels = m_modelMetaData.ClsVocab.ConvertIdsToString(predictedIds.ToWeightArray().ToList());

            var labelsPerSentence = new List<List<string>>();
            clsResult.Output.Add(labelsPerSentence);
            for (int b = 0; b < batchSize; b++)
            {
                labelsPerSentence.Add(new List<string> { predictedLabels[b] });
            }
        }
    }

    // ---- Sequence generation ----
    decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);

    var tgtSnts = sntPairBatch.GetTgtTokens(1);
    var tgtTokensList = m_modelMetaData.TgtVocab.GetWordIndex(tgtSnts);

    NetworkResult genResult = new NetworkResult();
    if (decoder is AttentionDecoder attentionDecoder)
    {
        genResult.Cost = Decoder.DecodeAttentionLSTM(tgtTokensList, computeGraph, encOutput, attentionDecoder, decoderFFLayer, tgtEmbedding, m_modelMetaData.TgtVocab, srcSnts.Count, isTraining);
        genResult.Output = new List<List<List<string>>> { m_modelMetaData.TgtVocab.ConvertIdsToString(tgtTokensList) };
    }
    else if (isTraining)
    {
        (var trainCost, _) = Decoder.DecodeTransformer(tgtTokensList, computeGraph, encOutput, decoder as TransformerDecoder, decoderFFLayer, tgtEmbedding, posEmbedding, originalSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType, m_options.DropoutRatio, null, isTraining);
        genResult.Cost = trainCost;
        genResult.Output = null;
    }
    else
    {
        List<List<BeamSearchStatus>> beam2batchStatus = Decoder.InitBeamSearchStatusListList(batchSize, tgtTokensList);
        for (int step = 0; step < decodingOptions.MaxTgtSentLength; step++)
        {
            List<List<BeamSearchStatus>> batch2beam2seq = null; //(batch_size, beam_search_size)
            try
            {
                foreach (var batchStatus in beam2batchStatus)
                {
                    var batch2tgtTokens = Decoder.ExtractBatchTokens(batchStatus);
                    using var stepGraph = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{step}");
                    (_, var bssSeqList) = Decoder.DecodeTransformer(batch2tgtTokens, stepGraph, encOutput, decoder as TransformerDecoder, decoderFFLayer, tgtEmbedding, posEmbedding, originalSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType, 0.0f, decodingOptions, isTraining, outputSentScore: decodingOptions.BeamSearchSize > 1, previousBeamSearchResults: batchStatus);

                    // (beam_search_size, batch_size) -> (batch_size, beam_search_size)
                    bssSeqList = Decoder.SwapBeamAndBatch(bssSeqList);
                    batch2beam2seq = Decoder.CombineBeamSearchResults(batch2beam2seq, bssSeqList);
                }
            }
            catch (OutOfMemoryException)
            {
                GC.Collect();
                Logger.WriteLine(Logger.Level.warn, $"We have out of memory while generating '{step}th' tokens, so terminate decoding for current sequences.");
                break;
            }

            if (decodingOptions.BeamSearchSize > 1)
            {
                // Keep only the best BeamSearchSize hypotheses per sentence.
                for (int b = 0; b < batchSize; b++)
                {
                    batch2beam2seq[b] = BeamSearch.GetTopNBSS(batch2beam2seq[b], decodingOptions.BeamSearchSize);
                }
            }

            beam2batchStatus = Decoder.SwapBeamAndBatch(batch2beam2seq);
            if (Decoder.AreAllSentsCompleted(beam2batchStatus))
            {
                break;
            }
        }

        genResult.Cost = 0.0f;
        genResult.Output = m_modelMetaData.TgtVocab.ExtractTokens(beam2batchStatus);
    }

    genResult.RemoveDuplicatedEOS();

    return new List<NetworkResult> { clsResult, genResult };
}