Example #1
            // Returns an Expression for the total loss over the sentence
            public Expression BuildLMGraph(List<int> sent, bool fApplyDropout)
            {
                // Renew the computation graph
                dy.RenewCG();

                // hidden -> word rep parameter
                Expression R = dy.parameter(p_R);
                // word bias
                Expression bias = dy.parameter(p_bias);

                // Build the collection of losses
                List<Expression> errs = new List<Expression>();

                // Start the initial state with the <s> tag
                RNNState state = builder.GetInitialState().AddInput(lp[d["<s>"]]);

                // Go through all the inputs
                for (int t = 0; t < sent.Count; t++)
                {
                    // Project the hidden state to vocabulary scores (softmax is applied inside pickneglogsoftmax)
                    Expression u_t = dy.affine_transform(bias, R, state.Output());
                    errs.Add(dy.pickneglogsoftmax(u_t, sent[t]));
                    // Add the next item in
                    state = state.AddInput(dy.lookup(lp, sent[t]));
                }// next t
                // Add the last </s> tag
                Expression u_last = dy.affine_transform(bias, R, state.Output());

                errs.Add(dy.pickneglogsoftmax(u_last, d["</s>"]));

                // Return the summed loss
                return dy.esum(errs);
            }
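
A minimal sketch (not part of the original example) of the training loop that would drive BuildLMGraph. The Trainer parameter, the corpus shape, and the Backward()/Update()/ScalarValue() calls are assumptions that mirror the standard DyNet training pattern and may differ from the actual dynetsharp API.

            // Hypothetical training loop: one pass over the corpus, updating after each sentence.
            // "trainer" is assumed to be a dynetsharp trainer (e.g. SimpleSGDTrainer);
            // Backward()/Update()/ScalarValue() mirror the usual DyNet pattern and are assumed.
            public void TrainEpoch(List<List<int>> corpus, Trainer trainer)
            {
                double totalLoss = 0;
                foreach (List<int> sent in corpus)
                {
                    // Build the graph and get the summed loss for this sentence
                    Expression loss = BuildLMGraph(sent, fApplyDropout: true);
                    totalLoss += loss.ScalarValue();
                    // Backprop through the graph and take an optimizer step
                    loss.Backward();
                    trainer.Update();
                }
                Console.WriteLine("avg loss per sentence: " + totalLoss / corpus.Count);
            }
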
Example #2
        private static string GenerateSentence(string inputSentence, ParameterGroup pg)
        {
            dy.RenewCG();

            List<Expression> embeds    = EmbedSentence(inputSentence, pg);
            List<Expression> encodings = EncodeSentence(embeds, pg);

            // Create the matrix of all the encoder context vectors
            Expression inputMat = dy.concatenate_cols(encodings);
            // Each attention score is an activation on top of w1*inputMat + w2*state;
            // since w1*inputMat is static, compute it once here
            Expression w1dt = pg.attention_w1 * inputMat;

            // Create the initial state of the decoder
            RNNState decState = pg.dec_lstm.GetInitialState();

            // Feed the EOS embedding through (the initial attention context is all zeros)
            decState = decState.AddInput(dy.concatenate(dy.zeros(new[] { STATE_SIZE * 2 }), pg.output_lookup[pg.c2i[pg.EOS]]));

            List<string> output = new List<string>();
            Expression   prev   = pg.output_lookup[pg.c2i[pg.EOS]];

            // Go through and decode
            for (int i = 0; i < inputSentence.Length * 2; i++)
            {
                // Create the input
                Expression inputVec = dy.concatenate(Attend(inputMat, w1dt, decState, pg), prev);
                // Run through LSTM + linear layer
                decState = decState.AddInput(inputVec);
                Expression outputVec = dy.softmax(pg.decoder_W * decState.Output() + pg.decoder_b);
                // Greedily pick the most likely token; stop at EOS
                int max = Argmax(outputVec.VectorValue());
                if (max == pg.c2i[pg.EOS])
                {
                    break;
                }
                output.Add(pg.i2c[max]);
                prev = pg.output_lookup[max];
            }// next output

            return string.Join("", output);
        }
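
The example above calls two helpers that are not shown. Below are minimal sketches of what they plausibly look like: Attend implements the attention the comments describe (scores each encoder column with v * tanh(w1*inputMat + w2*state), then returns the weighted sum of the columns), and Argmax picks the index of the largest entry in the output distribution. pg.attention_w2, pg.attention_v, dy.colwise_add, decState.S(), and the float[] element type are assumptions modeled on the DyNet attention example, not confirmed parts of this project.

        // Hypothetical sketch of Attend, modeled on the DyNet attention example.
        // pg.attention_w2 / pg.attention_v and dy.colwise_add / decState.S() are assumed.
        private static Expression Attend(Expression inputMat, Expression w1dt, RNNState decState, ParameterGroup pg)
        {
            // Dynamic half of the score: w2 * (concatenated decoder state)
            Expression w2dt = pg.attention_w2 * dy.concatenate(decState.S());
            // One unnormalized score per encoder column: v * tanh(w1dt + w2dt)
            Expression unnormalized = dy.transpose(pg.attention_v * dy.tanh(dy.colwise_add(w1dt, w2dt)));
            // Normalize the scores and return the attention-weighted sum of the columns
            Expression attWeights = dy.softmax(unnormalized);
            return inputMat * attWeights;
        }

        // Hypothetical sketch of Argmax: index of the largest entry in the
        // probability vector (assumes VectorValue() yields a float[]).
        private static int Argmax(float[] values)
        {
            int best = 0;
            for (int i = 1; i < values.Length; i++)
            {
                if (values[i] > values[best])
                {
                    best = i;
                }
            }
            return best;
        }
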
Example #3
        private static Expression DecodeSentence(List<Expression> encodings, string outputSentence, ParameterGroup pg)
        {
            // Pad the output with the EOS tag (at the end only)
            List <string> output = outputSentence.Select(c => c.ToString()).ToList();

            output.Add(pg.EOS);

            // Create the matrix of all the encoder context vectors
            Expression inputMat = dy.concatenate_cols(encodings);
            // Each attention score is an activation on top of w1*inputMat + w2*state;
            // since w1*inputMat is static, compute it once here
            Expression w1dt = pg.attention_w1 * inputMat;

            // Create the initial state of the decoder
            RNNState decState = pg.dec_lstm.GetInitialState();

            // Feed the EOS embedding through (the initial attention context is all zeros)
            decState = decState.AddInput(dy.concatenate(dy.zeros(new[] { STATE_SIZE * 2 }), pg.output_lookup[pg.c2i[pg.EOS]]));

            List<Expression> losses = new List<Expression>();
            Expression prev = pg.output_lookup[pg.c2i[pg.EOS]];

            // Go through and decode
            foreach (string outS in output)
            {
                // Create the input
                Expression inputVec = dy.concatenate(Attend(inputMat, w1dt, decState, pg), prev);
                // Run through LSTM + linear layer
                decState = decState.AddInput(inputVec);
                Expression outputVec = dy.softmax(pg.decoder_W * decState.Output() + pg.decoder_b);
                // Negative log-likelihood of the gold token; feed it as the next input
                losses.Add(-dy.log(dy.pick(outputVec, pg.c2i[outS])));
                prev = pg.output_lookup[pg.c2i[outS]];
            }// next output

            return dy.sum(losses);
        }
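
A minimal sketch (not part of the original example) of how DecodeSentence would be driven during training: embed and encode the input, decode against the gold output, then backprop. EmbedSentence and EncodeSentence are the project's own helpers seen in example #2; the Trainer parameter and the Backward()/Update()/ScalarValue() calls are assumptions mirroring the standard DyNet pattern.

        // Hypothetical single training step for the attention model.
        // Backward()/Update()/ScalarValue() mirror the usual DyNet pattern and are assumed.
        private static float TrainExample(string input, string gold, ParameterGroup pg, Trainer trainer)
        {
            dy.RenewCG();

            List<Expression> embeds    = EmbedSentence(input, pg);
            List<Expression> encodings = EncodeSentence(embeds, pg);

            // Total negative log-likelihood of the gold output under the model
            Expression loss      = DecodeSentence(encodings, gold, pg);
            float      lossValue = loss.ScalarValue();

            // Backprop and take an optimizer step
            loss.Backward();
            trainer.Update();
            return lossValue;
        }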