/// <summary>
/// Runs the seq2seq model on one input sentence and generates output sentences with beam search.
/// </summary>
/// <param name="input">
/// Source-side input handed to <c>ParallelCorpus.ConstructInputTokens</c>. Presumably the tokens of a
/// single sentence, since only one sentence per call is supported (batch size is fixed to 1).
/// </param>
/// <param name="beamSearchSize">
/// Number of hypotheses kept after every decoding step; also the number of candidate next tokens
/// expanded for each unfinished hypothesis.
/// </param>
/// <param name="maxOutputLength">Maximum number of decoding steps performed before the search stops.</param>
/// <returns>
/// One list of target tokens per hypothesis left in the beam when the search stops, in the order
/// returned by <c>BeamSearch.GetTopNBSS</c>.
/// </returns>
/// <exception cref="InvalidOperationException">
/// The decoder returned by <c>GetNetworksOnDeviceAt</c> is not an <c>AttentionDecoder</c>.
/// </exception>
public List<List<string>> Predict(List<string> input, int beamSearchSize = 1, int maxOutputLength = 100)
{
    (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(-1);

    // Beam search below drives the decoder through AttentionDecoder-specific members
    // (GetCTs/SetCTs, GetHTs/SetHTs, PreProcess). Fail with a clear message here instead of a
    // NullReferenceException further down when the decoder is missing or of another type.
    if (!(decoder is AttentionDecoder rnnDecoder))
    {
        throw new InvalidOperationException("Predict with beam search requires the decoder to be an AttentionDecoder.");
    }

    List<List<string>> inputSeqs = ParallelCorpus.ConstructInputTokens(input);
    int batchSize = 1; // For predict with beam search, we currently only support one sentence per call

    IComputeGraph g = CreateComputGraph(m_defaultDeviceId, needBack: false);

    encoder.Reset(g.GetWeightFactory(), batchSize);
    rnnDecoder.Reset(g.GetWeightFactory(), batchSize);

    // Seed the beam with a single hypothesis: the START tag plus the decoder state right after Reset.
    BeamSearchStatus initialStatus = new BeamSearchStatus();
    initialStatus.OutputIds.Add((int)SENTTAGS.START);
    initialStatus.CTs = rnnDecoder.GetCTs();
    initialStatus.HTs = rnnDecoder.GetHTs();

    List<BeamSearchStatus> bssList = new List<BeamSearchStatus>();
    bssList.Add(initialStatus);

    // The source sentence is encoded once; every decoding step reuses the pre-processed result.
    IWeightTensor encodedWeightMatrix = Encode(g, inputSeqs, encoder, srcEmbedding, null, null);
    AttentionPreProcessResult attPreProcessResult = rnnDecoder.PreProcess(encodedWeightMatrix, batchSize, g);

    List<BeamSearchStatus> newBSSList = new List<BeamSearchStatus>();
    bool finished = false;
    int outputLength = 0;
    while (!finished && outputLength < maxOutputLength)
    {
        // Stays true only if no hypothesis in the beam was expanded during this step.
        finished = true;

        for (int i = 0; i < bssList.Count; i++)
        {
            BeamSearchStatus current = bssList[i];
            int lastId = current.OutputIds[current.OutputIds.Count - 1];

            // Hypotheses that already emitted END are carried forward unchanged. The length check
            // is a safeguard; given the loop bound above it is not expected to trigger.
            if (lastId == (int)SENTTAGS.END || current.OutputIds.Count > maxOutputLength)
            {
                newBSSList.Add(current);
                continue;
            }

            finished = false;

            // Restore the decoder state that belongs to this hypothesis before stepping it.
            rnnDecoder.SetCTs(current.CTs);
            rnnDecoder.SetHTs(current.HTs);

            IWeightTensor x = g.PeekRow(tgtEmbedding, lastId);
            IWeightTensor eOutput = rnnDecoder.Decode(x, attPreProcessResult, batchSize, g);
            using (IWeightTensor probs = g.Softmax(eOutput))
            {
                List<int> preds = probs.GetTopNMaxWeightIdx(beamSearchSize);
                for (int j = 0; j < preds.Count; j++)
                {
                    BeamSearchStatus newBSS = new BeamSearchStatus();
                    newBSS.OutputIds.AddRange(current.OutputIds);
                    newBSS.OutputIds.Add(preds[j]);

                    newBSS.CTs = rnnDecoder.GetCTs();
                    newBSS.HTs = rnnDecoder.GetHTs();

                    // Score is the accumulated negative log-probability of the chosen tokens.
                    // No length normalization is applied.
                    float score = probs.GetWeightAt(preds[j]);
                    newBSS.Score = current.Score;
                    newBSS.Score += (float)(-Math.Log(score));

                    newBSSList.Add(newBSS);
                }
            }
        }

        // Prune the candidates down to the beam size, then reuse the scratch list for the next step.
        bssList = BeamSearch.GetTopNBSS(newBSSList, beamSearchSize);
        newBSSList.Clear();

        outputLength++;
    }

    // Convert output target word ids to real strings
    List<List<string>> results = new List<List<string>>();
    foreach (BeamSearchStatus status in bssList)
    {
        results.Add(m_modelMetaData.Vocab.ConvertTargetIdsToString(status.OutputIds));
    }

    return results;
}