Code Example #1
File: train_rnnlm.cs  Project: zhangfanTJU/dynet
 public RNNLanguageModel(RNNBuilder builder, ParameterCollection model, Dictionary <string, int> vocab, int INPUT_DIM, int HIDDEN_DIM, float dropout)
 {
     this.builder = builder;
     this.dropout = dropout;
     // Word -> index vocabulary, plus the inverse (index -> word) map used when decoding predictions
     this.d       = vocab;
     this.di2W    = d.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
     // Word embeddings, and the output projection (softmax weights and bias) over the vocabulary
     lp           = model.AddLookupParameters(vocab.Count, new[] { INPUT_DIM });
     p_R          = model.AddParameters(new[] { vocab.Count, HIDDEN_DIM });
     p_bias       = model.AddParameters(new[] { vocab.Count });
 }
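For reference, a minimal usage sketch of this constructor, assuming LSTMBuilder derives from RNNBuilder (as in DyNet); the toy vocabulary and the dimensions below are illustrative only and not taken from the project:

 // Hypothetical usage sketch; vocabulary and dimensions are illustrative only.
 ParameterCollection model = new ParameterCollection();
 Dictionary<string, int> vocab = new Dictionary<string, int> { { "<EOS>", 0 }, { "hello", 1 }, { "world", 2 } };
 RNNBuilder builder = new LSTMBuilder(/*layers*/ 2, /*input dim*/ 64, /*hidden dim*/ 128, model);
 RNNLanguageModel lm = new RNNLanguageModel(builder, model, vocab, 64, 128, 0.3f);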
Code Example #2
        static void Main(string[] args)
        {
            DynetParams.FromArgs(args).Initialize();
            // Alternatively, initialization can be configured directly, e.g.:
            // DynetParams dp = new DynetParams();
            // dp.AutoBatch = true;
            // dp.MemDescriptor = "768";
            // dp.Initialize();

            const string  EOS        = "<EOS>";
            List <string> characters = "abcdefghijklmnopqrstuvwxyz ".Select(c => c.ToString()).ToList();

            characters.Add(EOS);

            // Lookup dictionary: character -> index
            Dictionary <string, int> c2i = Enumerable.Range(0, characters.Count).ToDictionary(i => characters[i], i => i);

            // Define the variables
            VOCAB_SIZE         = characters.Count;
            LSTM_NUM_OF_LAYERS = 2;
            EMBEDDINGS_SIZE    = 32;
            STATE_SIZE         = 32;
            ATTENTION_SIZE     = 32;

            // ParameterCollection (all the model parameters).
            ParameterCollection m = new ParameterCollection();
            // A locally defined class used to hold all the parameters, so they can be
            // passed between functions without resorting to global variables
            ParameterGroup pg = new ParameterGroup();

            pg.c2i = c2i;
            pg.i2c = characters;
            pg.EOS = EOS;

            // LSTMs
            pg.enc_fwd_lstm = new LSTMBuilder(LSTM_NUM_OF_LAYERS, EMBEDDINGS_SIZE, STATE_SIZE, m);
            pg.enc_bwd_lstm = new LSTMBuilder(LSTM_NUM_OF_LAYERS, EMBEDDINGS_SIZE, STATE_SIZE, m);

            pg.dec_lstm = new LSTMBuilder(LSTM_NUM_OF_LAYERS, STATE_SIZE * 2 + EMBEDDINGS_SIZE, STATE_SIZE, m);

            // Create the parameters
            pg.input_lookup  = m.AddLookupParameters(VOCAB_SIZE, new[] { EMBEDDINGS_SIZE });
            pg.attention_w1  = m.AddParameters(new[] { ATTENTION_SIZE, STATE_SIZE * 2 });
            pg.attention_w2  = m.AddParameters(new[] { ATTENTION_SIZE, STATE_SIZE * 2 * LSTM_NUM_OF_LAYERS });
            pg.attention_v   = m.AddParameters(new[] { 1, ATTENTION_SIZE });
            pg.decoder_W     = m.AddParameters(new[] { VOCAB_SIZE, STATE_SIZE });
            pg.decoder_b     = m.AddParameters(new[] { VOCAB_SIZE });
            pg.output_lookup = m.AddLookupParameters(VOCAB_SIZE, new[] { EMBEDDINGS_SIZE });

            Trainer trainer = new SimpleSGDTrainer(m);

            // For good practice, renew the computation graph
            dy.RenewCG();

            // Train
            string trainSentence = "it is working";

            // Run 600 epochs
            for (int iEpoch = 0; iEpoch < 600; iEpoch++)
            {
                // Loss
                Expression loss = CalculateLoss(trainSentence, trainSentence, pg);
                // Forward, backward, update
                float lossValue = loss.ScalarValue();
                loss.Backward();
                trainer.Update();
                if (iEpoch % 20 == 0)
                {
                    Console.WriteLine(lossValue);
                    Console.WriteLine(GenerateSentence(trainSentence, pg));
                }
            } // next epoch
        }