/// <summary>
/// Builds a single encoder/decoder step over a dummy input sentence and dumps
/// the resulting compute graph to a visualization file.
/// </summary>
/// <param name="visNNFilePath">Destination path for the generated network visualization.</param>
/// <exception cref="InvalidOperationException">
/// Thrown when the configured decoder is not an <see cref="AttentionDecoder"/>;
/// visualization is only implemented for the attention-based RNN decoder.
/// </exception>
public void VisualizeNeuralNetwork(string visNNFilePath)
{
    (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(-1);

    // Build input sentence (null input => corpus builds a minimal token sequence).
    List<List<string>> inputSeqs = ParallelCorpus.ConstructInputTokens(null);
    int batchSize = inputSeqs.Count;

    // No backward pass needed; request graph capture for visualization.
    IComputeGraph g = CreateComputGraph(m_defaultDeviceId, needBack: false, visNetwork: true);

    // Guard the downcast: an unchecked 'as' would surface later as a NullReferenceException.
    AttentionDecoder rnnDecoder = decoder as AttentionDecoder;
    if (rnnDecoder == null)
    {
        throw new InvalidOperationException($"Network visualization requires an AttentionDecoder, but the decoder is '{decoder?.GetType().Name}'.");
    }

    encoder.Reset(g.GetWeightFactory(), batchSize);
    rnnDecoder.Reset(g.GetWeightFactory(), batchSize);

    // Run encoder
    IWeightTensor encodedWeightMatrix = Encode(g, inputSeqs, encoder, srcEmbedding, null, null);

    // Prepare for attention over encoder-decoder
    AttentionPreProcessResult attPreProcessResult = rnnDecoder.PreProcess(encodedWeightMatrix, batchSize, g);

    // Run a single decoder step from the START tag so the decode path appears in the graph.
    IWeightTensor x = g.PeekRow(tgtEmbedding, (int)SENTTAGS.START);
    IWeightTensor eOutput = rnnDecoder.Decode(x, attPreProcessResult, batchSize, g);

    // The softmax result is not consumed; it exists only to extend the captured graph.
    IWeightTensor probs = g.Softmax(eOutput);

    g.VisualizeNeuralNetToFile(visNNFilePath);
}
/// <summary>
/// Given an input sentence, generates an output sentence with the seq2seq model using beam search.
/// </summary>
/// <param name="input">Source-side tokens of a single sentence.</param>
/// <param name="beamSearchSize">Number of hypotheses kept alive at each decoding step.</param>
/// <param name="maxOutputLength">Hard cap on the number of generated tokens per hypothesis.</param>
/// <returns>The surviving hypotheses (up to <paramref name="beamSearchSize"/>), each as a token list.</returns>
/// <exception cref="InvalidOperationException">
/// Thrown when the configured decoder is not an <see cref="AttentionDecoder"/>.
/// </exception>
public List<List<string>> Predict(List<string> input, int beamSearchSize = 1, int maxOutputLength = 100)
{
    (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(-1);

    List<List<string>> inputSeqs = ParallelCorpus.ConstructInputTokens(input);
    int batchSize = 1; // For predict with beam search, we currently only support one sentence per call.
    IComputeGraph g = CreateComputGraph(m_defaultDeviceId, needBack: false);

    // Guard the downcast: an unchecked 'as' would surface later as a NullReferenceException.
    AttentionDecoder rnnDecoder = decoder as AttentionDecoder;
    if (rnnDecoder == null)
    {
        throw new InvalidOperationException($"Beam-search prediction requires an AttentionDecoder, but the decoder is '{decoder?.GetType().Name}'.");
    }

    encoder.Reset(g.GetWeightFactory(), batchSize);
    rnnDecoder.Reset(g.GetWeightFactory(), batchSize);

    // Seed the beam with a single hypothesis: the START tag plus the decoder's
    // initial cell (CT) and hidden (HT) states.
    List<BeamSearchStatus> bssList = new List<BeamSearchStatus>();
    BeamSearchStatus bss = new BeamSearchStatus();
    bss.OutputIds.Add((int)SENTTAGS.START);
    bss.CTs = rnnDecoder.GetCTs();
    bss.HTs = rnnDecoder.GetHTs();
    bssList.Add(bss);

    IWeightTensor encodedWeightMatrix = Encode(g, inputSeqs, encoder, srcEmbedding, null, null);
    AttentionPreProcessResult attPreProcessResult = rnnDecoder.PreProcess(encodedWeightMatrix, batchSize, g);

    List<BeamSearchStatus> newBSSList = new List<BeamSearchStatus>();
    bool finished = false;
    int outputLength = 0;
    while (finished == false && outputLength < maxOutputLength)
    {
        finished = true;
        for (int i = 0; i < bssList.Count; i++)
        {
            bss = bssList[i];
            if (bss.OutputIds[bss.OutputIds.Count - 1] == (int)SENTTAGS.END)
            {
                // Hypothesis already emitted END: carry it forward unchanged.
                newBSSList.Add(bss);
            }
            else if (bss.OutputIds.Count > maxOutputLength)
            {
                // Hypothesis hit the length cap: freeze it as-is.
                newBSSList.Add(bss);
            }
            else
            {
                finished = false;
                int ix_input = bss.OutputIds[bss.OutputIds.Count - 1];

                // Restore this hypothesis' decoder state before stepping.
                rnnDecoder.SetCTs(bss.CTs);
                rnnDecoder.SetHTs(bss.HTs);

                IWeightTensor x = g.PeekRow(tgtEmbedding, ix_input);
                IWeightTensor eOutput = rnnDecoder.Decode(x, attPreProcessResult, batchSize, g);
                using (IWeightTensor probs = g.Softmax(eOutput))
                {
                    // Expand this hypothesis with the top-N candidate next tokens.
                    List<int> preds = probs.GetTopNMaxWeightIdx(beamSearchSize);
                    for (int j = 0; j < preds.Count; j++)
                    {
                        BeamSearchStatus newBSS = new BeamSearchStatus();
                        newBSS.OutputIds.AddRange(bss.OutputIds);
                        newBSS.OutputIds.Add(preds[j]);
                        newBSS.CTs = rnnDecoder.GetCTs();
                        newBSS.HTs = rnnDecoder.GetHTs();

                        // Accumulate negative log-likelihood: lower score is better.
                        float score = probs.GetWeightAt(preds[j]);
                        newBSS.Score = bss.Score;
                        newBSS.Score += (float)(-Math.Log(score));

                        //var lengthPenalty = Math.Pow((5.0f + newBSS.OutputIds.Count) / 6, 0.6);
                        //newBSS.Score /= (float)lengthPenalty;

                        newBSSList.Add(newBSS);
                    }
                }
            }
        }

        // Prune the expanded pool back down to the beam width.
        bssList = BeamSearch.GetTopNBSS(newBSSList, beamSearchSize);
        newBSSList.Clear();
        outputLength++;
    }

    // Convert output target word ids back to token strings.
    List<List<string>> results = new List<List<string>>();
    for (int i = 0; i < bssList.Count; i++)
    {
        results.Add(m_modelMetaData.Vocab.ConvertTargetIdsToString(bssList[i].OutputIds));
    }

    return results;
}
/// <summary>
/// Clears the internal state of both directional encoders and the attention
/// decoder so a new sequence can be processed from scratch.
/// </summary>
/// <param name="encoder">Forward-direction encoder to reset.</param>
/// <param name="reversEncoder">Reverse-direction encoder to reset.</param>
/// <param name="decoder">Attention decoder to reset.</param>
private void Reset(Encoder encoder, Encoder reversEncoder, AttentionDecoder decoder)
{
    encoder.Reset();
    reversEncoder.Reset();
    decoder.Reset();
}
/// <summary>
/// Greedily decodes an output sentence for the given input sentence.
/// The bi-directional encoder runs its forward and reversed passes in parallel,
/// their outputs are summed per position, and the decoder then emits one token
/// at a time (always the highest-probability word) until the END tag appears or
/// the length cap is exceeded.
/// </summary>
/// <param name="input">Source-side tokens of a single sentence.</param>
/// <returns>The decoded target-side tokens (no START/END tags included).</returns>
public List<string> Predict(List<string> input)
{
    reversEncoder.Reset();
    encoder.Reset();
    decoder.Reset();

    List<string> result = new List<string>();

#if MKL
    var G2 = new ComputeGraphMKL(false);
#else
    var G2 = new ComputeGraph(false);
#endif

    List<string> inputSeq = new List<string>();
    // inputSeq.Add(m_START);
    inputSeq.AddRange(input);
    // inputSeq.Add(m_END);

    List<string> revseq = inputSeq.ToList();
    revseq.Reverse();

    List<WeightMatrix> forwardEncoded = new List<WeightMatrix>();
    List<WeightMatrix> backwardEncoded = new List<WeightMatrix>();
    List<WeightMatrix> encoded = new List<WeightMatrix>();
    SparseWeightMatrix sparseInput = new SparseWeightMatrix(1, s_Embedding.Columns);

    // Encode the forward and reversed sequences concurrently.
    // NOTE(review): both lambdas share the same compute graph G2 — this assumes
    // its operations are safe for concurrent use; confirm against ComputeGraph.
    Parallel.Invoke(
        () =>
        {
            for (int i = 0; i < inputSeq.Count; i++)
            {
                // Unknown words fall back to the UNK embedding row.
                int ix = (int)SENTTAGS.UNK;
                if (s_wordToIndex.TryGetValue(inputSeq[i], out int knownIx))
                {
                    ix = knownIx;
                }
                else
                {
                    Logger.WriteLine($"Unknown input word: {inputSeq[i]}");
                }

                var x2 = G2.PeekRow(s_Embedding, ix);
                var o = encoder.Encode(x2, G2);
                forwardEncoded.Add(o);

                // Bag-of-words sparse feature for the decoder (built once, forward pass only).
                sparseInput.AddWeight(0, ix, 1.0f);
            }
        },
        () =>
        {
            for (int i = 0; i < inputSeq.Count; i++)
            {
                int ix = (int)SENTTAGS.UNK;
                if (s_wordToIndex.TryGetValue(revseq[i], out int knownIx))
                {
                    ix = knownIx;
                }
                else
                {
                    Logger.WriteLine($"Unknown input word: {revseq[i]}");
                }

                var x2 = G2.PeekRow(s_Embedding, ix);
                var o = reversEncoder.Encode(x2, G2);
                backwardEncoded.Add(o);
            }
        });

    // Re-align the backward pass with the original token order, then combine
    // the two directions element-wise per position.
    backwardEncoded.Reverse();
    for (int i = 0; i < inputSeq.Count; i++)
    {
        //encoded.Add(G2.concatColumns(forwardEncoded[i], backwardEncoded[i]));
        encoded.Add(G2.add(forwardEncoded[i], backwardEncoded[i]));
    }

    //if (UseDropout)
    //{
    //    for (int i = 0; i < encoded.Weight.Length; i++)
    //    {
    //        encoded.Weight[i] *= 0.2;
    //    }
    //}

    // Greedy decoding: start from the START tag and feed each predicted word
    // back in as the next input.
    var ix_input = (int)SENTTAGS.START;
    while (true)
    {
        var x = G2.PeekRow(t_Embedding, ix_input);
        var eOutput = decoder.Decode(sparseInput, x, encoded, G2);
        if (UseDropout)
        {
            // Inference-time scaling to compensate for training-time dropout.
            for (int i = 0; i < eOutput.Weight.Length; i++)
            {
                eOutput.Weight[i] *= 0.2f;
            }
        }

        // Project the decoder output to vocabulary logits.
        var o = G2.muladd(eOutput, this.Whd, this.bd);
        if (UseDropout)
        {
            for (int i = 0; i < o.Weight.Length; i++)
            {
                o.Weight[i] *= 0.2f;
            }
        }

        var probs = G2.SoftmaxWithCrossEntropy(o);

        // Argmax over the vocabulary distribution.
        var maxv = probs.Weight[0];
        var maxi = 0;
        for (int i = 1; i < probs.Weight.Length; i++)
        {
            if (probs.Weight[i] > maxv)
            {
                maxv = probs.Weight[i];
                maxi = i;
            }
        }
        var pred = maxi;

        if (pred == (int)SENTTAGS.END)
        {
            break; // END token predicted — decoding complete.
        }

        if (result.Count > max_word)
        {
            // Output is suspiciously long; something is wrong — bail out to
            // avoid an infinite loop.
            break;
        }

        // Map the predicted id back to a word, falling back to UNK.
        var letter2 = m_UNK;
        if (t_indexToWord.TryGetValue(pred, out string word))
        {
            letter2 = word;
        }
        result.Add(letter2);

        ix_input = pred;
    }

    return result;
}
/// <summary>
/// Clears the internal state of both directional encoders and the attention
/// decoder, handing each one the weight factory to rebuild its state tensors.
/// </summary>
/// <param name="weightFactory">Factory used by each network to allocate fresh state.</param>
/// <param name="encoder">Forward-direction encoder to reset.</param>
/// <param name="reversEncoder">Reverse-direction encoder to reset.</param>
/// <param name="decoder">Attention decoder to reset.</param>
private void Reset(IWeightFactory weightFactory, Encoder encoder, Encoder reversEncoder, AttentionDecoder decoder)
{
    encoder.Reset(weightFactory);
    reversEncoder.Reset(weightFactory);
    decoder.Reset(weightFactory);
}