Esempio n. 1
0
        /// <summary>
        /// Encodes one training sentence: each word is run through the forward encoder while
        /// the word at the mirrored position is run through the reverse encoder, and the two
        /// outputs are concatenated and appended to <paramref name="encoded"/>.
        /// </summary>
        /// <param name="sentIndex">Index of the sentence pair in InputSequences / OutputSequences.</param>
        /// <param name="OutputSentence">Receives the target sentence paired with the input.</param>
        /// <param name="g">Receives the new computation graph the encoding is recorded on.</param>
        /// <param name="cost">Receives the starting cost (always 0.0).</param>
        /// <param name="encoded">List that one concatenated encoder output per word is appended to.</param>
        private void Encode(int sentIndex, out List <string> OutputSentence, out ComputeGraph g, out double cost, ref List <WeightMatrix> encoded)
        {
            var forwardWords = InputSequences[sentIndex];

            // Reverse a copy so the stored training sentence is not modified.
            var backwardWords = forwardWords.ToList();
            backwardWords.Reverse();

            OutputSentence = OutputSequences[sentIndex];
            g              = new ComputeGraph();
            cost           = 0.0;

            var position = 0;
            while (position < forwardWords.Count)
            {
                int forwardIndex  = wordToIndex[forwardWords[position]].i;
                int backwardIndex = wordToIndex[backwardWords[position]].i;

                // Look up the embedding rows and step each encoder once.
                var forwardEmbedding  = g.PeekRow(Embedding, forwardIndex);
                var forwardOutput     = encoder.Encode(forwardEmbedding, g);
                var backwardEmbedding = g.PeekRow(Embedding, backwardIndex);
                var backwardOutput    = ReversEncoder.Encode(backwardEmbedding, g);

                encoded.Add(g.concatColumns(forwardOutput, backwardOutput));
                position++;
            }
        }
Esempio n. 2
0
        /// <summary>
        /// Additive attention over the rows of <paramref name="input"/> with a single query
        /// <paramref name="state"/>: scores are tanh(uh + wc) * V, normalized with softmax,
        /// and used to weight the rows of <paramref name="input"/>.
        /// </summary>
        /// <param name="input">Matrix whose rows are attended over (presumably one row per encoder step).</param>
        /// <param name="state">Query; repeated once per row of <paramref name="input"/>.</param>
        /// <param name="g">Computation graph the operations are recorded on.</param>
        /// <returns>The context produced by weighting the rows of <paramref name="input"/> and reducing with sumColumns.</returns>
        public WeightMatrix Perform(WeightMatrix input, WeightMatrix state, ComputeGraph g)
        {
            var stateRepeat = g.RepeatRows(state, input.Rows);

            // Append a constant column so Ua / Wa also carry a bias row.
            // NOTE(review): assumes the third WeightMatrix constructor argument is the fill value (1).
            var inputBias = new WeightMatrix(input.Rows, 1, 1);
            var inputb    = g.concatColumns(input, inputBias);
            var uh        = g.mul(inputb, Ua);

            var stateBias    = new WeightMatrix(stateRepeat.Rows, 1, 1);
            var stateRepeatb = g.concatColumns(stateRepeat, stateBias);
            var wc           = g.mul(stateRepeatb, Wa);

            // Unnormalized scores, then attention weights.
            var gg  = g.tanh(g.add(uh, wc));
            var aa  = g.mul(gg, V);
            var res = g.Softmax(aa);

            // Scale each input row by its weight and reduce to the context.
            var weighted = g.weightRows(input, res);
            return g.sumColumns(weighted);
        }
Esempio n. 3
0
 /// <summary>
 /// Runs <paramref name="V"/> through every encoder layer in order, feeding each
 /// layer's output into the next one.
 /// </summary>
 /// <param name="V">Input to the first encoder layer.</param>
 /// <param name="g">Computation graph the steps are recorded on.</param>
 /// <returns>The output of the last layer, or <paramref name="V"/> itself when there are no layers.</returns>
 public WeightMatrix Encode(WeightMatrix V, ComputeGraph g)
 {
     var current = V;
     foreach (var layer in encoders)
     {
         current = layer.Step(current, g);
     }
     return current;
 }
Esempio n. 4
0
        /// <summary>
        /// Advances this LSTM cell by one time step, recording every operation on
        /// <paramref name="innerGraph"/>.
        /// </summary>
        /// <param name="input">Input for the current time step.</param>
        /// <param name="innerGraph">Computation graph the step's operations are recorded on.</param>
        /// <returns>
        /// The new hidden state. As a side effect the new hidden and cell states replace
        /// <c>ht</c> and <c>ct</c>, so the next call continues from them.
        /// </returns>
        public WeightMatrix Step(WeightMatrix input, ComputeGraph innerGraph)
        {
            // State produced by the previous step.
            var prevHidden = this.ht;
            var prevCell   = this.ct;

            // Input gate: sigmoid(x*Wix + h*Wih + bi)
            var inputFromX = innerGraph.mul(input, this.Wix);
            var inputFromH = innerGraph.mul(prevHidden, this.Wih);
            var inputSum   = innerGraph.add(inputFromX, inputFromH);
            var inputGate  = innerGraph.sigmoid(innerGraph.add(inputSum, this.bi));

            // Forget gate: sigmoid(x*Wfx + h*Wfh + bf)
            var forgetFromX = innerGraph.mul(input, this.Wfx);
            var forgetFromH = innerGraph.mul(prevHidden, this.Wfh);
            var forgetSum   = innerGraph.add(forgetFromX, forgetFromH);
            var forgetGate  = innerGraph.sigmoid(innerGraph.add(forgetSum, this.bf));

            // Output gate: sigmoid(x*Wox + h*Woh + bo)
            var outputFromX = innerGraph.mul(input, this.Wox);
            var outputFromH = innerGraph.mul(prevHidden, this.Woh);
            var outputSum   = innerGraph.add(outputFromX, outputFromH);
            var outputGate  = innerGraph.sigmoid(innerGraph.add(outputSum, this.bo));

            // Candidate cell contents: tanh(x*Wcx + h*Wch + bc)
            var candFromX = innerGraph.mul(input, this.Wcx);
            var candFromH = innerGraph.mul(prevHidden, this.Wch);
            var candSum   = innerGraph.add(candFromX, candFromH);
            var candidate = innerGraph.tanh(innerGraph.add(candSum, this.bc));

            // New cell = forget * oldCell + input * candidate
            var kept    = innerGraph.eltmul(forgetGate, prevCell);
            var written = innerGraph.eltmul(inputGate, candidate);
            var newCell = innerGraph.add(kept, written);

            // New hidden = output * tanh(newCell)
            var squashed  = innerGraph.tanh(newCell);
            var newHidden = innerGraph.eltmul(outputGate, squashed);

            this.ht = newHidden;
            this.ct = newCell;
            return this.ht;
        }
Esempio n. 5
0
        /// <summary>
        /// Runs one decoding step through the stacked decoder layers, feeding each layer's
        /// output into the next layer.
        /// </summary>
        /// <param name="input">Input to the first decoder layer.</param>
        /// <param name="g">Computation graph the steps are recorded on.</param>
        /// <returns>The output of the last decoder layer, or <paramref name="input"/> when there are no layers.</returns>
        public WeightMatrix Decode(WeightMatrix input, ComputeGraph g)
        {
            var V = input;

            foreach (var layer in decoders)
            {
                // Chain the layers. Previously every layer received the raw input, so only
                // the last layer affected the result and the stack was not actually stacked.
                V = layer.Step(V, g);
            }

            return V;
        }
Esempio n. 6
0
        /// <summary>
        /// Runs <paramref name="V"/> through every encoder layer in order and collects the
        /// output of each layer.
        /// </summary>
        /// <param name="V">Input to the first encoder layer.</param>
        /// <param name="g">Computation graph the steps are recorded on.</param>
        /// <returns>One output per encoder layer, in layer order (empty when there are no layers).</returns>
        public List <WeightMatrix> Encode2(WeightMatrix V, ComputeGraph g)
        {
            var layerOutputs = new List <WeightMatrix>();
            var current      = V;

            foreach (var layer in encoders)
            {
                current = layer.Step(current, g);
                layerOutputs.Add(current);
            }

            return layerOutputs;
        }
Esempio n. 7
0
        /// <summary>
        /// Runs one attention-based decoding step. The attention query is the cell state of
        /// the first decoder layer. The resulting context is passed to every layer together
        /// with the previous layer's output.
        /// </summary>
        /// <param name="input">Input to the first decoder layer.</param>
        /// <param name="encoderOutput">Encoder output attended over.</param>
        /// <param name="g">Computation graph the operations are recorded on.</param>
        /// <returns>The output of the last decoder layer.</returns>
        /// <exception cref="InvalidOperationException">The decoder has no layers.</exception>
        public WeightMatrix Decode(WeightMatrix input, WeightMatrix encoderOutput, ComputeGraph g)
        {
            var firstLayer = this.decoders.FirstOrDefault();
            if (firstLayer == null)
            {
                // Previously this dereferenced null and threw NullReferenceException.
                throw new InvalidOperationException("The decoder has no layers; cannot compute attention.");
            }

            var V = input;
            // NOTE(review): uses the cell state (ct), not the hidden state (ht), as the query — confirm intent.
            var lastStatus = firstLayer.ct;
            var context    = Attention.Perform(encoderOutput, lastStatus, g);

            foreach (var layer in decoders)
            {
                V = layer.Step(context, V, g);
            }

            return V;
        }
Esempio n. 8
0
        /// <summary>
        /// Decodes the target sentence with teacher forcing. The decoder input for each step
        /// is the previous target word, starting from embedding row 1. The softmax
        /// cross-entropy loss is added to <paramref name="cost"/>, and its gradient is written
        /// onto the logits.
        /// </summary>
        /// <param name="OutputSentence">Target words. An extra final step targets index 0, the end token.</param>
        /// <param name="g">Computation graph the decoding is recorded on.</param>
        /// <param name="cost">Cost accumulated so far.</param>
        /// <param name="encoded">Per-word encoder outputs passed to the decoder.</param>
        /// <returns><paramref name="cost"/> plus the negative log-likelihood of every target index.</returns>
        private double DecodeOutput(List <string> OutputSentence, ComputeGraph g, double cost, List <WeightMatrix> encoded)
        {
            int previousIndex = 1;

            for (int step = 0; step <= OutputSentence.Count; step++)
            {
                // One step past the last word targets index 0 (end of sentence).
                int target = step == OutputSentence.Count
                    ? 0
                    : wordToIndex[OutputSentence[step]].i;

                var embedded = g.PeekRow(Embedding, previousIndex);
                var hidden   = decoder.Decode(embedded, encoded, g);
                if (UseDropout)
                {
                    hidden = g.Dropout(hidden, 0.2);
                }

                var logits = g.add(g.mul(hidden, this.Whd), this.bd);
                if (UseDropout)
                {
                    logits = g.Dropout(logits, 0.2);
                }

                var probs = g.SoftmaxWithCrossEntropy(logits);
                cost -= Math.Log(probs.Weight[target]);

                // d(loss)/d(logits) = probs - onehot(target). The array is shared with
                // probs.Weight rather than copied.
                logits.Gradient          = probs.Weight;
                logits.Gradient[target] -= 1;

                // Teacher forcing: the true word becomes the next input.
                previousIndex = target;
            }

            return cost;
        }
Esempio n. 9
0
        /// <summary>
        /// Additive attention over a list of encoder outputs. Each output h_j is scored with
        /// tanh(h_j*Ua + bUa + state*Wa + bWa) * V. The scores go through softmax, and the
        /// context is the weighted sum of the inputs. The index of the highest weight is
        /// stored in MaxIndex.
        /// </summary>
        /// <param name="input">Encoder outputs to attend over; must contain at least one element.</param>
        /// <param name="state">Decoder-side query.</param>
        /// <param name="g">Computation graph the operations are recorded on.</param>
        /// <returns>The attention-weighted sum of <paramref name="input"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="input"/> is empty.</exception>
        public WeightMatrix Perform(List <WeightMatrix> input, WeightMatrix state, ComputeGraph g)
        {
            if (input == null)
            {
                throw new ArgumentNullException(nameof(input));
            }
            if (input.Count == 0)
            {
                // Previously crashed later with an index error on res[0] / input[0].
                throw new ArgumentException("At least one encoder output is required.", nameof(input));
            }

            // The query projection does not depend on h_j, so build it once instead of
            // adding an identical node to the graph for every encoder output.
            var wc = g.add(g.mul(state, Wa), bWa);

            List <WeightMatrix> atten = new List <WeightMatrix>(input.Count);
            foreach (var h_j in input)
            {
                var uh = g.add(g.mul(h_j, Ua), bUa);
                var gg = g.tanh(g.add(uh, wc));
                atten.Add(g.mul(gg, V));
            }
            var res = g.Softmax(atten);

            // Record which encoder output received the largest weight (first one wins ties).
            var cmax   = res[0].Weight[0];
            int maxAtt = 0;
            for (int i = 1; i < res.Count; i++)
            {
                if (res[i].Weight[0] > cmax)
                {
                    cmax   = res[i].Weight[0];
                    maxAtt = i;
                }
            }
            this.MaxIndex = maxAtt;

            // context = sum_j weight_j * h_j
            var context = g.scalemul(input[0], res[0]);
            for (int hj = 1; hj < input.Count; hj++)
            {
                context = g.add(context, g.scalemul(input[hj], res[hj]));
            }
            return context;
        }
Esempio n. 10
0
        /// <summary>
        /// Predicts a reply for a tokenized input sentence by greedy decoding.
        /// </summary>
        /// <remarks>
        /// Each input word goes through the forward encoder, and the word at the mirrored
        /// position goes through the reverse encoder. The two outputs are concatenated per
        /// position. Decoding starts from embedding row 1 and feeds back the most probable
        /// index until index 0 is predicted or the length limit is exceeded.
        /// </remarks>
        /// <param name="inputSeq">The tokenized (word-segmented) input sentence.</param>
        /// <returns>
        /// A successful result holding the predicted words, or a failed result when the input
        /// is empty or contains a word that is not in the vocabulary.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="inputSeq"/> is null.</exception>
        public ExecuteResult <List <string> > Predict(List <string> inputSeq)
        {
            ExecuteResult <List <string> > eresult = new ExecuteResult <List <string> >();

            Reset();
            if (inputSeq == null)
            {
                throw new ArgumentNullException(nameof(inputSeq));
            }

            // An empty sentence would give the decoder an empty encoder-output list.
            // NOTE(review): the list-based attention indexes element 0, so this previously crashed.
            if (inputSeq.Count == 0)
            {
                return(eresult.SetFail("抱歉,输入内容为空, 请重新输入!"));
            }

            List <string> result = new List <string>();
            var           G2     = new ComputeGraph(false);
            // Reversed copy of the input, fed to the reverse encoder.
            List <string> revseq = inputSeq.ToList();

            revseq.Reverse();
            List <WeightMatrix> encoded = new List <WeightMatrix>();

            for (int i = 0; i < inputSeq.Count; i++)
            {
                // Reject words that are not in the vocabulary.
                if (!wordToIndex.ContainsKey(inputSeq[i]))
                {
                    return(eresult.SetFail($"抱歉,未能理解 \"{inputSeq[i]}\"  的含义, 请重新训练我吧!"));
                }

                if (!wordToIndex.ContainsKey(revseq[i]))
                {
                    return(eresult.SetFail($"抱歉,未能理解 \"{revseq[i]}\"  的含义, 请重新训练我吧!"));
                }

                int ix  = wordToIndex[inputSeq[i]].i;
                int ix2 = wordToIndex[revseq[i]].i;

                var x2       = G2.PeekRow(Embedding, ix);
                var o        = encoder.Encode(x2, G2);
                var x3       = G2.PeekRow(Embedding, ix2);
                var eOutput2 = ReversEncoder.Encode(x3, G2);
                var d        = G2.concatColumns(o, eOutput2);
                encoded.Add(d);
            }

            var ix_input = 1;

            while (true)
            {
                var x       = G2.PeekRow(Embedding, ix_input);
                var eOutput = decoder.Decode(x, encoded, G2);
                // NOTE(review): training calls Dropout(..., 0.2). If 0.2 is the drop probability,
                // non-inverted dropout would need scaling by 0.8 here (inverted dropout by nothing),
                // not by 0.2. Confirm against ComputeGraph.Dropout before changing.
                if (UseDropout)
                {
                    for (int i = 0; i < eOutput.Weight.Length; i++)
                    {
                        eOutput.Weight[i] *= 0.2;
                    }
                }
                var o = G2.add(
                    G2.mul(eOutput, this.Whd), this.bd);
                if (UseDropout)
                {
                    for (int i = 0; i < o.Weight.Length; i++)
                    {
                        o.Weight[i] *= 0.2;
                    }
                }

                // Greedy choice: index of the highest probability (first one wins ties).
                var probs = G2.SoftmaxWithCrossEntropy(o);
                var maxv  = probs.Weight[0];
                var maxi  = 0;
                for (int i = 1; i < probs.Weight.Length; i++)
                {
                    if (probs.Weight[i] > maxv)
                    {
                        maxv = probs.Weight[i];
                        maxi = i;
                    }
                }
                var pred = maxi;

                if (pred == 0)
                {
                    break;            // END token predicted, break out
                }
                // Safety limit. The check runs before Add and uses '>', so up to
                // max_word + 1 words can be returned.
                if (result.Count > max_word)
                {
                    break;
                }
                var letter2 = indexToWord[pred].w;
                result.Add(letter2);
                ix_input = pred;
            }

            return(eresult.SetData(result).SetOk());
        }
Esempio n. 11
0
 /// <summary>
 /// Adapter: casts the arguments to the Seq2SeqLearn base types, delegates to the base
 /// class Decode, and casts the result back.
 /// </summary>
 public WeightMatrix Decode(WeightMatrix input, WeightMatrix encoderOutput, ComputeGraph g)
 {
     var baseInput   = (Seq2SeqLearn.WeightMatrix)input;
     var baseEncoded = (Seq2SeqLearn.WeightMatrix)encoderOutput;
     var baseGraph   = (Seq2SeqLearn.ComputeGraph)g;
     var decoded     = base.Decode(baseInput, baseEncoded, baseGraph);
     return (WeightMatrix)decoded;
 }
Esempio n. 12
0
 /// <summary>
 /// Adapter: casts the input, each encoder output and the graph to the Seq2SeqLearn base
 /// types, delegates to the base class Decode, and casts the result back.
 /// </summary>
 public WeightMatrix Decode(WeightMatrix input, List <WeightMatrix> encoderOutput, ComputeGraph g)
 {
     var baseInput   = (Seq2SeqLearn.WeightMatrix)input;
     var baseEncoded = encoderOutput.Select(p => (Seq2SeqLearn.WeightMatrix)p).ToList();
     var baseGraph   = (Seq2SeqLearn.ComputeGraph)g;
     var decoded     = base.Decode(baseInput, baseEncoded, baseGraph);
     return (WeightMatrix)decoded;
 }
Esempio n. 13
0
 /// <summary>
 /// Adapter: delegates to the base class Encode2 with the arguments cast to the
 /// Seq2SeqLearn base types, and casts every returned layer output back.
 /// </summary>
 public List <WeightMatrix> Encode2(WeightMatrix V, ComputeGraph g)
 {
     var baseInput   = (Seq2SeqLearn.WeightMatrix)V;
     var baseGraph   = (Seq2SeqLearn.ComputeGraph)g;
     var baseOutputs = base.Encode2(baseInput, baseGraph);
     return baseOutputs.Select(p => (WeightMatrix)p).ToList();
 }
Esempio n. 14
0
 /// <summary>
 /// Adapter: delegates to the base class Encode with the arguments cast to the
 /// Seq2SeqLearn base types, and casts the result back.
 /// </summary>
 public WeightMatrix Encode(WeightMatrix V, ComputeGraph g)
 {
     var baseInput = (Seq2SeqLearn.WeightMatrix)V;
     var baseGraph = (Seq2SeqLearn.ComputeGraph)g;
     var encoded   = base.Encode(baseInput, baseGraph);
     return (WeightMatrix)encoded;
 }