Ejemplo n.º 1
0
 /// <summary>
 /// Resets the forward encoder, then the reverse encoder, then the attention decoder.
 /// </summary>
 /// <param name="encoder">Forward encoder to reset.</param>
 /// <param name="reversEncoder">Reverse encoder to reset.</param>
 /// <param name="decoder">Attention decoder to reset last.</param>
 private void Reset(Encoder encoder, Encoder reversEncoder, AttentionDecoder decoder)
 {
     foreach (Encoder enc in new[] { encoder, reversEncoder })
     {
         enc.Reset();
     }

     decoder.Reset();
 }
Ejemplo n.º 2
0
        /// <summary>
        /// Encodes a batch of source sentences with a forward and a reverse encoder and
        /// returns the combined outputs built by <c>g.ConcatRowColumn</c>.
        /// </summary>
        /// <param name="g">Compute graph on which every lookup, slice and encoder step is built.</param>
        /// <param name="inputSentences">Batch of tokenized sentences; passed to <c>PadSentences</c> before encoding.</param>
        /// <param name="encoder">Encoder fed the time steps in order 0 .. seqLen-1.</param>
        /// <param name="reversEncoder">Encoder fed the time steps in order seqLen-1 .. 0.</param>
        /// <param name="Embedding">Embedding matrix whose rows are fetched with <c>g.PeekRow</c> by vocabulary index.</param>
        /// <returns>
        /// <c>g.ConcatRowColumn(forwardOutputs, backwardOutputs)</c>, with the backward outputs
        /// re-ordered so that index k of both lists refers to the same time step.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="inputSentences"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="inputSentences"/> is empty.</exception>
        private IWeightMatrix Encode(IComputeGraph g, List <List <string> > inputSentences, Encoder encoder, Encoder reversEncoder, IWeightMatrix Embedding)
        {
            if (inputSentences == null)
            {
                throw new ArgumentNullException(nameof(inputSentences));
            }
            if (inputSentences.Count == 0)
            {
                throw new ArgumentException("At least one input sentence is required.", nameof(inputSentences));
            }

            // NOTE(review): presumably pads every sentence to the same length, since the
            // sequence length below is taken from sentence 0 only — confirm.
            PadSentences(inputSentences);

            int batchSize = inputSentences.Count;
            int seqLen    = inputSentences[0].Count;

            // Embedding lookups are laid out time-step-major: every batch entry for step 0,
            // then every batch entry for step 1, and so on.
            List <IWeightMatrix> forwardInput = new List <IWeightMatrix>(seqLen * batchSize);
            for (int i = 0; i < seqLen; i++)
            {
                for (int j = 0; j < batchSize; j++)
                {
                    int ix_source;
                    // Single dictionary lookup; unknown words map to the UNK tag.
                    if (!m_srcWordToIndex.TryGetValue(inputSentences[j][i], out ix_source))
                    {
                        ix_source = (int)SENTTAGS.UNK;
                    }
                    forwardInput.Add(g.PeekRow(Embedding, ix_source));
                }
            }

            var forwardInputsM = g.ConcatRows(forwardInput);

            // Slice the concatenated matrix back into one batchSize-row block per time step.
            List <IWeightMatrix> stepInputs = new List <IWeightMatrix>(seqLen);
            for (int i = 0; i < seqLen; i++)
            {
                stepInputs.Add(g.PeekRow(forwardInputsM, i * batchSize, batchSize));
            }

            // Forward and reverse steps stay interleaved exactly as before, in case the
            // graph depends on the order in which operations are added to it.
            List <IWeightMatrix> forwardOutputs  = new List <IWeightMatrix>(seqLen);
            List <IWeightMatrix> backwardOutputs = new List <IWeightMatrix>(seqLen);
            for (int i = 0; i < seqLen; i++)
            {
                forwardOutputs.Add(encoder.Encode(stepInputs[i], g));
                backwardOutputs.Add(reversEncoder.Encode(stepInputs[seqLen - i - 1], g));
            }

            // backwardOutputs[k] was produced for time step seqLen-1-k; flip it so index k means step k.
            backwardOutputs.Reverse();

            return g.ConcatRowColumn(forwardOutputs, backwardOutputs);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Encodes a single sentence with a forward and a reverse encoder on a newly created
        /// compute graph. For each token position, the forward and backward outputs are combined
        /// with <c>g.add</c>, and the results are appended to <paramref name="encoded"/>.
        /// </summary>
        /// <param name="inputSentence">Tokenized source sentence.</param>
        /// <param name="cost">Always set to 0.</param>
        /// <param name="sWM">
        /// Receives a <c>SparseWeightMatrix(1, Embedding.Columns)</c> on which
        /// <c>AddWeight(0, index, 1.0f)</c> was called once per token (UNK index for unknown words).
        /// </param>
        /// <param name="encoded">List to which one combined output per token is appended, in token order.</param>
        /// <param name="encoder">Encoder fed the tokens front to back.</param>
        /// <param name="reversEncoder">Encoder fed the tokens back to front.</param>
        /// <param name="Embedding">Embedding matrix whose rows are fetched with <c>g.PeekRow</c> by vocabulary index.</param>
        /// <returns>The compute graph on which all operations were built.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="inputSentence"/> or <paramref name="encoded"/> is null.</exception>
        private IComputeGraph Encode(List <string> inputSentence, out float cost, out SparseWeightMatrix sWM, List <WeightMatrix> encoded, Encoder encoder, Encoder reversEncoder, WeightMatrix Embedding)
        {
            if (inputSentence == null)
            {
                throw new ArgumentNullException(nameof(inputSentence));
            }
            // Fail before doing any encoding work rather than after it.
            if (encoded == null)
            {
                throw new ArgumentNullException(nameof(encoded));
            }

#if MKL
            IComputeGraph g = new ComputeGraphMKL();
#else
            IComputeGraph g = new ComputeGraph();
#endif

            cost = 0.0f;
            int count = inputSentence.Count;
            SparseWeightMatrix  tmpSWM          = new SparseWeightMatrix(1, Embedding.Columns);
            List <WeightMatrix> forwardOutputs  = new List <WeightMatrix>(count);
            List <WeightMatrix> backwardOutputs = new List <WeightMatrix>(count);

            // NOTE(review): both branches build operations on the same graph `g` concurrently.
            // This is only safe if the graph implementation and Encoder.Encode are thread-safe —
            // confirm; otherwise run the two loops sequentially.
            Parallel.Invoke(
                () =>
                {
                    for (int i = 0; i < count; i++)
                    {
                        int ix_source;
                        if (!s_wordToIndex.TryGetValue(inputSentence[i], out ix_source))
                        {
                            ix_source = (int)SENTTAGS.UNK;
                        }

                        var x = g.PeekRow(Embedding, ix_source);
                        forwardOutputs.Add(encoder.Encode(x, g));

                        tmpSWM.AddWeight(0, ix_source, 1.0f);
                    }
                },
                () =>
                {
                    // Walk the sentence from the end instead of materializing a reversed copy.
                    for (int i = count - 1; i >= 0; i--)
                    {
                        int ix_source2;
                        if (!s_wordToIndex.TryGetValue(inputSentence[i], out ix_source2))
                        {
                            ix_source2 = (int)SENTTAGS.UNK;
                        }

                        var x2 = g.PeekRow(Embedding, ix_source2);
                        backwardOutputs.Add(reversEncoder.Encode(x2, g));
                    }
                });

            // backwardOutputs[k] was produced for token count-1-k; flip it so index k means token k.
            backwardOutputs.Reverse();

            // The two directions are merged with g.add rather than g.concatColumns, so each
            // combined output keeps the shape of a single direction's output (assumes g.add is shape-preserving — confirm).
            for (int i = 0; i < count; i++)
            {
                encoded.Add(g.add(forwardOutputs[i], backwardOutputs[i]));
            }

            sWM = tmpSWM;

            return g;
        }
Ejemplo n.º 4
0
 /// <summary>
 /// Resets the forward encoder, then the reverse encoder, then the attention decoder,
 /// passing the same weight factory to each.
 /// </summary>
 /// <param name="weightFactory">Weight factory handed to every component's Reset.</param>
 /// <param name="encoder">Forward encoder to reset.</param>
 /// <param name="reversEncoder">Reverse encoder to reset.</param>
 /// <param name="decoder">Attention decoder to reset last.</param>
 private void Reset(IWeightFactory weightFactory, Encoder encoder, Encoder reversEncoder, AttentionDecoder decoder)
 {
     foreach (Encoder enc in new[] { encoder, reversEncoder })
     {
         enc.Reset(weightFactory);
     }

     decoder.Reset(weightFactory);
 }