/// <summary>
/// Batched attention: scores every row of <paramref name="input"/> against
/// <paramref name="state"/> and returns the attention-weighted combination of the rows.
/// </summary>
/// <param name="input">Matrix whose rows are attended over (presumably one row per source position — TODO confirm).</param>
/// <param name="state">State vector; it is repeated once per row of <paramref name="input"/> via <c>g.RepeatRows</c>.</param>
/// <param name="g">Compute graph that records every operation performed here.</param>
/// <returns>The result of <c>g.sumColumns</c> applied to the softmax-weighted rows of <paramref name="input"/>.</returns>
public WeightMatrix Perform(WeightMatrix input, WeightMatrix state, ComputeGraph g)
        {
            // Graph operations are kept in the original order so the recorded
            // computation (and hence backprop order) is unchanged.
            var stateRepeat = g.RepeatRows(state, input.Rows);

            // Append an extra column so the bias is folded into Ua.
            // NOTE(review): assumes the third constructor argument is the fill value (1)
            // and that Ua carries a matching bias row — confirm against WeightMatrix/Ua.
            var inputBias = new WeightMatrix(input.Rows, 1, 1);
            var inputb    = g.concatColumns(input, inputBias);
            var uh        = g.mul(inputb, Ua);

            // Same bias-column trick for the repeated state before projecting through Wa.
            var stateBias = new WeightMatrix(stateRepeat.Rows, 1, 1);
            var stateb    = g.concatColumns(stateRepeat, stateBias);
            var wc        = g.mul(stateb, Wa);

            // Additive score: tanh(uh + wc) projected through V, one score per row.
            var gg = g.tanh(g.add(uh, wc));
            var aa = g.mul(gg, V);

            // Normalise the scores and weight each input row by its attention weight.
            var res      = g.Softmax(aa);
            var weighted = g.weightRows(input, res);

            return g.sumColumns(weighted);
        }
        /// <summary>
        /// Per-position attention: scores each element of <paramref name="input"/> against
        /// <paramref name="state"/>, normalises the scores with a softmax, and returns the
        /// weighted sum of the inputs. Also records the index of the most-attended position
        /// in <c>MaxIndex</c>.
        /// </summary>
        /// <param name="input">Sequence of vectors to attend over; must be non-null and non-empty.</param>
        /// <param name="state">State the inputs are scored against.</param>
        /// <param name="g">Compute graph that records every operation performed here.</param>
        /// <returns>The context <c>sum_j res[j] * input[j]</c>, built with <c>g.scalemul</c> / <c>g.add</c>.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="input"/> is null.</exception>
        /// <exception cref="System.ArgumentException"><paramref name="input"/> is empty.</exception>
        public WeightMatrix Perform(List <WeightMatrix> input, WeightMatrix state, ComputeGraph g)
        {
            if (input == null)
            {
                throw new System.ArgumentNullException("input");
            }
            if (input.Count == 0)
            {
                throw new System.ArgumentException("At least one input position is required.", "input");
            }

            // The state projection does not depend on h_j, so build it once and share the
            // node across all positions instead of recording an identical mul/add per position.
            var wc = g.add(g.mul(state, Wa), bWa);

            var atten = new List <WeightMatrix>(input.Count);
            foreach (var h_j in input)
            {
                var uh = g.add(g.mul(h_j, Ua), bUa);
                var gg = g.tanh(g.add(uh, wc));
                atten.Add(g.mul(gg, V));
            }
            var res = g.Softmax(atten);

            // Find the position with the largest attention weight; strict '>' keeps the
            // first one on ties.
            var cmax   = res[0].Weight[0];
            int maxAtt = 0;
            for (int i = 1; i < res.Count; i++)
            {
                if (res[i].Weight[0] > cmax)
                {
                    cmax   = res[i].Weight[0];
                    maxAtt = i;
                }
            }
            this.MaxIndex = maxAtt;

            // context = sum_j res[j] * input[j]
            var context = g.scalemul(input[0], res[0]);
            for (int hj = 1; hj < input.Count; hj++)
            {
                context = g.add(context, g.scalemul(input[hj], res[hj]));
            }
            return context;
        }
        // Example #3 (score: 0)
        /// <summary>
        /// Advances the LSTM cell by one time step using the previous hidden state
        /// (<c>ht</c>) and cell state (<c>ct</c>), then stores the new states back into
        /// <c>ht</c> and <c>ct</c>.
        /// </summary>
        /// <param name="input">Input for this time step.</param>
        /// <param name="innerGraph">Compute graph that records every operation performed here.</param>
        /// <returns>The new hidden state (the value just written to <c>ht</c>).</returns>
        public WeightMatrix Step(WeightMatrix input, ComputeGraph innerGraph)
        {
            var prevHidden = ht;
            var prevCell   = ct;

            // Gates are built in the same order as before: input, forget, output, candidate.
            var inputGate  = innerGraph.sigmoid(ComputeGatePreActivation(innerGraph, input, prevHidden, Wix, Wih, bi));
            var forgetGate = innerGraph.sigmoid(ComputeGatePreActivation(innerGraph, input, prevHidden, Wfx, Wfh, bf));
            var outputGate = innerGraph.sigmoid(ComputeGatePreActivation(innerGraph, input, prevHidden, Wox, Woh, bo));
            var candidate  = innerGraph.tanh(ComputeGatePreActivation(innerGraph, input, prevHidden, Wcx, Wch, bc));

            // New cell contents: keep part of the old cell and write part of the candidate.
            var kept    = innerGraph.eltmul(forgetGate, prevCell);
            var written = innerGraph.eltmul(inputGate, candidate);
            var newCell = innerGraph.add(kept, written);

            // New hidden state: output gate applied to the tanh-squashed cell.
            var squashed  = innerGraph.tanh(newCell);
            var newHidden = innerGraph.eltmul(outputGate, squashed);

            ht = newHidden;
            ct = newCell;
            return ht;
        }

        /// <summary>
        /// Records <c>x*wx + h*wh + b</c> on <paramref name="graph"/>, in that operation order,
        /// and returns the (not yet activated) result.
        /// </summary>
        private static WeightMatrix ComputeGatePreActivation(
            ComputeGraph graph,
            WeightMatrix x,
            WeightMatrix h,
            WeightMatrix wx,
            WeightMatrix wh,
            WeightMatrix b)
        {
            var fromInput  = graph.mul(x, wx);
            var fromHidden = graph.mul(h, wh);
            var summed     = graph.add(fromInput, fromHidden);
            return graph.add(summed, b);
        }