コード例 #1
0
        /// <summary>
        /// Builds a single forward pass of the network (encoder -> attention decoder -> softmax)
        /// on a visualization-enabled compute graph and dumps the resulting graph to a file.
        /// </summary>
        /// <param name="visNNFilePath">Path of the file the network visualization is written to.</param>
        /// <exception cref="InvalidOperationException">Thrown when the model's decoder is not an <see cref="AttentionDecoder"/>.</exception>
        public void VisualizeNeuralNetwork(string visNNFilePath)
        {
            (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(-1);

            // Fail fast with a clear message instead of letting the "as" cast produce null
            // and crash later with a NullReferenceException at rnnDecoder.Reset(...).
            if (!(decoder is AttentionDecoder rnnDecoder))
            {
                throw new InvalidOperationException($"{nameof(VisualizeNeuralNetwork)} requires an {nameof(AttentionDecoder)}.");
            }

            // Build input sentence (null input -> default tokens supplied by the corpus helper)
            List<List<string>> inputSeqs = ParallelCorpus.ConstructInputTokens(null);
            int batchSize = inputSeqs.Count;

            // No back-propagation is needed; request the graph in visualization mode.
            IComputeGraph g = CreateComputGraph(m_defaultDeviceId, needBack: false, visNetwork: true);

            encoder.Reset(g.GetWeightFactory(), batchSize);
            rnnDecoder.Reset(g.GetWeightFactory(), batchSize);

            // Run encoder
            IWeightTensor encodedWeightMatrix = Encode(g, inputSeqs, encoder, srcEmbedding, null, null);

            // Prepare for attention over encoder-decoder
            AttentionPreProcessResult attPreProcessResult = rnnDecoder.PreProcess(encodedWeightMatrix, batchSize, g);

            // Run one decoder step from the <START> tag so the complete pipeline
            // (embedding lookup -> decoder -> softmax) appears in the rendered graph.
            IWeightTensor x = g.PeekRow(tgtEmbedding, (int)SENTTAGS.START);
            IWeightTensor eOutput = rnnDecoder.Decode(x, attPreProcessResult, batchSize, g);
            g.Softmax(eOutput); // return value unused; the call only adds the softmax node to the graph

            g.VisualizeNeuralNetToFile(visNNFilePath);
        }
コード例 #2
0
        /// <summary>
        /// Given input sentence and generate output sentence by seq2seq model with beam search
        /// </summary>
        /// <param name="input"></param>
        /// <param name="beamSearchSize"></param>
        /// <param name="maxOutputLength"></param>
        /// <returns></returns>
        /// <summary>
        /// Given an input sentence, generates the output sentence by the seq2seq model with beam search.
        /// </summary>
        /// <param name="input">Source-side tokens of a single sentence.</param>
        /// <param name="beamSearchSize">Number of hypotheses kept per decoding step (beam width).</param>
        /// <param name="maxOutputLength">Hard cap on the generated target sequence length.</param>
        /// <returns>The top-N target token sequences (as strings), best hypotheses first.</returns>
        /// <exception cref="InvalidOperationException">Thrown when the model's decoder is not an <see cref="AttentionDecoder"/>.</exception>
        public List<List<string>> Predict(List<string> input, int beamSearchSize = 1, int maxOutputLength = 100)
        {
            (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding) = GetNetworksOnDeviceAt(-1);

            // Beam search below depends on the attention decoder's CT/HT state snapshots;
            // fail fast instead of NullReferenceException-ing on a null "as" cast result.
            if (!(decoder is AttentionDecoder rnnDecoder))
            {
                throw new InvalidOperationException($"{nameof(Predict)} with beam search requires an {nameof(AttentionDecoder)}.");
            }

            List<List<string>> inputSeqs = ParallelCorpus.ConstructInputTokens(input);
            int batchSize = 1; // For predict with beam search, we currently only support one sentence per call

            IComputeGraph g = CreateComputGraph(m_defaultDeviceId, needBack: false);

            encoder.Reset(g.GetWeightFactory(), batchSize);
            rnnDecoder.Reset(g.GetWeightFactory(), batchSize);

            // Seed the beam with a single hypothesis that contains only the <START> tag
            // plus the decoder's freshly reset cell (CT) and hidden (HT) states.
            List<BeamSearchStatus> bssList = new List<BeamSearchStatus>();
            BeamSearchStatus bss = new BeamSearchStatus();
            bss.OutputIds.Add((int)SENTTAGS.START);
            bss.CTs = rnnDecoder.GetCTs();
            bss.HTs = rnnDecoder.GetHTs();
            bssList.Add(bss);

            // Encode once; the attention pre-processing result is reused for every decoding step.
            IWeightTensor encodedWeightMatrix = Encode(g, inputSeqs, encoder, srcEmbedding, null, null);
            AttentionPreProcessResult attPreProcessResult = rnnDecoder.PreProcess(encodedWeightMatrix, batchSize, g);

            List<BeamSearchStatus> newBSSList = new List<BeamSearchStatus>();
            bool finished = false;
            int outputLength = 0;

            while (finished == false && outputLength < maxOutputLength)
            {
                finished = true;
                for (int i = 0; i < bssList.Count; i++)
                {
                    bss = bssList[i];
                    if (bss.OutputIds[bss.OutputIds.Count - 1] == (int)SENTTAGS.END)
                    {
                        // Hypothesis already ended with <END>; carry it forward unchanged.
                        newBSSList.Add(bss);
                    }
                    else if (bss.OutputIds.Count > maxOutputLength)
                    {
                        // Length cap reached; stop expanding this hypothesis.
                        newBSSList.Add(bss);
                    }
                    else
                    {
                        finished = false;
                        int ix_input = bss.OutputIds[bss.OutputIds.Count - 1];

                        // Restore the decoder state this hypothesis was suspended with.
                        rnnDecoder.SetCTs(bss.CTs);
                        rnnDecoder.SetHTs(bss.HTs);

                        IWeightTensor x = g.PeekRow(tgtEmbedding, ix_input);
                        IWeightTensor eOutput = rnnDecoder.Decode(x, attPreProcessResult, batchSize, g);
                        using (IWeightTensor probs = g.Softmax(eOutput))
                        {
                            // Expand this hypothesis with the top-N candidate tokens.
                            List<int> preds = probs.GetTopNMaxWeightIdx(beamSearchSize);
                            for (int j = 0; j < preds.Count; j++)
                            {
                                BeamSearchStatus newBSS = new BeamSearchStatus();
                                newBSS.OutputIds.AddRange(bss.OutputIds);
                                newBSS.OutputIds.Add(preds[j]);

                                newBSS.CTs = rnnDecoder.GetCTs();
                                newBSS.HTs = rnnDecoder.GetHTs();

                                // Accumulate negative log-probability: lower score = better hypothesis.
                                float score = probs.GetWeightAt(preds[j]);
                                newBSS.Score = bss.Score;
                                newBSS.Score += (float)(-Math.Log(score));

                                //var lengthPenalty = Math.Pow((5.0f + newBSS.OutputIds.Count) / 6, 0.6);
                                //newBSS.Score /= (float)lengthPenalty;

                                newBSSList.Add(newBSS);
                            }
                        }
                    }
                }

                // Keep only the best beamSearchSize hypotheses for the next step.
                bssList = BeamSearch.GetTopNBSS(newBSSList, beamSearchSize);
                newBSSList.Clear();

                outputLength++;
            }

            // Convert output target word ids to real strings
            List<List<string>> results = new List<List<string>>();
            for (int i = 0; i < bssList.Count; i++)
            {
                results.Add(m_modelMetaData.Vocab.ConvertTargetIdsToString(bssList[i].OutputIds));
            }

            return results;
        }
コード例 #3
0
 /// <summary>
 /// Resets the internal state of the forward encoder, the reverse (right-to-left)
 /// encoder, and the attention decoder before processing a new sequence.
 /// </summary>
 /// <param name="encoder">Forward-direction encoder to reset.</param>
 /// <param name="reversEncoder">Reverse-direction encoder to reset.</param>
 /// <param name="decoder">Attention decoder to reset.</param>
 private void Reset(Encoder encoder, Encoder reversEncoder, AttentionDecoder decoder)
 {
     encoder.Reset();
     reversEncoder.Reset();
     decoder.Reset();
 }
コード例 #4
0
        /// <summary>
        /// Greedy (beam size 1) prediction: encodes the input with a bidirectional RNN
        /// (forward + reverse encoders, outputs summed element-wise), then decodes one
        /// token at a time, always taking the arg-max of the softmax output, until the
        /// <END> tag is predicted or the max_word length cap is reached.
        /// </summary>
        /// <param name="input">Source-side tokens of a single sentence.</param>
        /// <returns>Predicted target tokens (excludes the START/END tags).</returns>
        public List <string> Predict(List <string> input)
        {
            // Clear recurrent state left over from any previous sentence.
            reversEncoder.Reset();
            encoder.Reset();
            decoder.Reset();

            List <string> result = new List <string>();

            // Inference-only graph: "false" disables gradient/backward bookkeeping.
#if MKL
            var G2 = new ComputeGraphMKL(false);
#else
            var G2 = new ComputeGraph(false);
#endif

            List <string> inputSeq = new List <string>();
            //    inputSeq.Add(m_START);
            inputSeq.AddRange(input);
            //    inputSeq.Add(m_END);

            // Reversed copy of the input for the right-to-left encoder pass.
            List <string> revseq = inputSeq.ToList();
            revseq.Reverse();

            List <WeightMatrix> forwardEncoded  = new List <WeightMatrix>();
            List <WeightMatrix> backwardEncoded = new List <WeightMatrix>();
            List <WeightMatrix> encoded         = new List <WeightMatrix>();
            // Sparse bag-of-words over the source vocabulary; fed to the decoder each step.
            SparseWeightMatrix  sparseInput     = new SparseWeightMatrix(1, s_Embedding.Columns);

            // Run the forward and reverse encoder passes concurrently.
            // NOTE(review): both lambdas call PeekRow on the shared graph G2 — this is only
            // safe if the compute graph's node bookkeeping is thread-safe; verify.
            Parallel.Invoke(
                () =>
            {
                // Forward pass: left-to-right over the original token order.
                for (int i = 0; i < inputSeq.Count; i++)
                {
                    // Out-of-vocabulary tokens map to the UNK id.
                    int ix = (int)SENTTAGS.UNK;
                    if (s_wordToIndex.ContainsKey(inputSeq[i]) == false)
                    {
                        Logger.WriteLine($"Unknow input word: {inputSeq[i]}");
                    }
                    else
                    {
                        ix = s_wordToIndex[inputSeq[i]];
                    }

                    var x2 = G2.PeekRow(s_Embedding, ix);
                    var o  = encoder.Encode(x2, G2);
                    forwardEncoded.Add(o);

                    // Record token occurrence in the sparse source representation.
                    sparseInput.AddWeight(0, ix, 1.0f);
                }
            },
                () =>
            {
                // Reverse pass: same loop over the reversed token order.
                for (int i = 0; i < inputSeq.Count; i++)
                {
                    int ix = (int)SENTTAGS.UNK;
                    if (s_wordToIndex.ContainsKey(revseq[i]) == false)
                    {
                        Logger.WriteLine($"Unknow input word: {revseq[i]}");
                    }
                    else
                    {
                        ix = s_wordToIndex[revseq[i]];
                    }

                    var x2 = G2.PeekRow(s_Embedding, ix);
                    var o  = reversEncoder.Encode(x2, G2);
                    backwardEncoded.Add(o);
                }
            });

            // Re-reverse the backward outputs so index i of both lists refers to token i,
            // then combine the two directions by element-wise addition.
            backwardEncoded.Reverse();
            for (int i = 0; i < inputSeq.Count; i++)
            {
                //encoded.Add(G2.concatColumns(forwardEncoded[i], backwardEncoded[i]));
                encoded.Add(G2.add(forwardEncoded[i], backwardEncoded[i]));
            }

            //if (UseDropout)
            //{
            //    for (int i = 0; i < encoded.Weight.Length; i++)
            //    {
            //        encoded.Weight[i] *= 0.2;
            //    }
            //}
            // Greedy decoding loop, starting from the <START> tag.
            var ix_input = (int)SENTTAGS.START;
            while (true)
            {
                var x       = G2.PeekRow(t_Embedding, ix_input);
                var eOutput = decoder.Decode(sparseInput, x, encoded, G2);
                // NOTE(review): inference-time scaling by the fixed factor 0.2 — presumably
                // compensating for a dropout keep-probability used in training; confirm the
                // factor matches the training configuration.
                if (UseDropout)
                {
                    for (int i = 0; i < eOutput.Weight.Length; i++)
                    {
                        eOutput.Weight[i] *= 0.2f;
                    }
                }
                // Project decoder output to vocabulary logits: o = eOutput * Whd + bd.
                var o = G2.muladd(eOutput, this.Whd, this.bd);
                if (UseDropout)
                {
                    for (int i = 0; i < o.Weight.Length; i++)
                    {
                        o.Weight[i] *= 0.2f;
                    }
                }
                var probs = G2.SoftmaxWithCrossEntropy(o);
                // Arg-max over the probability vector = greedy next-token choice.
                var maxv  = probs.Weight[0];
                var maxi  = 0;
                for (int i = 1; i < probs.Weight.Length; i++)
                {
                    if (probs.Weight[i] > maxv)
                    {
                        maxv = probs.Weight[i];
                        maxi = i;
                    }
                }
                var pred = maxi;

                if (pred == (int)SENTTAGS.END)
                {
                    break;                            // END token predicted, break out
                }
                if (result.Count > max_word)
                {
                    break;
                }                                       // something is wrong

                // Map the predicted id back to a word; unknown ids become the UNK token.
                var letter2 = m_UNK;
                if (t_indexToWord.ContainsKey(pred))
                {
                    letter2 = t_indexToWord[pred];
                }

                result.Add(letter2);
                // Feed the prediction back as the next decoder input.
                ix_input = pred;
            }

            return(result);
        }
コード例 #5
0
 /// <summary>
 /// Resets the forward encoder, the reverse encoder, and the attention decoder,
 /// letting each re-acquire its working weights from the given factory.
 /// </summary>
 /// <param name="weightFactory">Factory the networks use to allocate their reset state.</param>
 /// <param name="encoder">Forward-direction encoder to reset.</param>
 /// <param name="reversEncoder">Reverse-direction encoder to reset.</param>
 /// <param name="decoder">Attention decoder to reset.</param>
 private void Reset(IWeightFactory weightFactory, Encoder encoder, Encoder reversEncoder, AttentionDecoder decoder)
 {
     encoder.Reset(weightFactory);
     reversEncoder.Reset(weightFactory);
     decoder.Reset(weightFactory);
 }