Example #1
        public void Train(string checkpoint, string run, int? counter, CancellationToken cancellation)
        {
            new Session().UseSelf(session => {
                var context   = tf.placeholder(tf.int32, new TensorShape(this.batchSize, null));
                var output    = Gpt2Model.Model(this.hParams, input: context);
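                // Next-token prediction: labels are the input tokens shifted left by one position,
                // and the final logit position is dropped so logits[t] lines up with token t + 1.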
                Tensor labels = context[Range.All, Range.StartAt(1)];
                Tensor logits = output["logits"][Range.All, Range.EndAt(new Index(1, fromEnd: true))];
                var loss      = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits_dyn(
                        labels: labels,
                        logits: logits));

                var sample = Gpt2Sampler.SampleSequence(
                    this.hParams,
                    length: this.sampleLength,
                    context: context,
                    batchSize: this.batchSize,
                    temperature: 1.0f,
                    topK: 40);

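                // Fine-tune only the variables under the "model" scope, using Adam with a 2e-4 learning rate.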
                var trainVars = tf.trainable_variables().Where((dynamic var) => var.name.Contains("model"));
                var optimizer = new AdamOptimizer(learning_rate: 0.0002).minimize(loss, var_list: trainVars);

                var saver = new Saver(
                    var_list: trainVars,
                    max_to_keep: 5,
                    keep_checkpoint_every_n_hours: 1);

                session.run(tf.global_variables_initializer());

                Console.WriteLine("Loading checkpoint " + checkpoint);
                saver.restore(session, checkpoint);

                Console.WriteLine("Loading dataset...");
                var sampler = new TrainingSampler(this.dataset, this.random);
                Console.WriteLine($"Dataset has {sampler.TokenCount} tokens");

                string counterFile = Path.Combine(Gpt2Checkpoints.CheckpointDir, run, "counter");
                if (counter == null && File.Exists(counterFile))
                {
                    counter = int.Parse(File.ReadAllText(counterFile), CultureInfo.InvariantCulture) + 1;
                }
                counter = counter ?? 1;

                string runCheckpointDir = Path.Combine(Gpt2Checkpoints.CheckpointDir, run);
                string runSampleDir     = Path.Combine(SampleDir, run);

                void Save()
                {
                    Directory.CreateDirectory(runCheckpointDir);
                    Console.WriteLine("Saving " + Path.Combine(runCheckpointDir, Invariant($"model-{counter}")));
                    saver.save(session,
                               Path.Combine(runCheckpointDir, "model"),
                               global_step: counter.Value);
                    File.WriteAllText(path: counterFile, contents: Invariant($"{counter}"));
                }

                void GenerateSamples()
                {
                    var contextTokens = np.array(new[] { this.encoder.EncodedEndOfText });
                    var allText       = new List <string>();
                    int index         = 0;
                    string text       = null;
                    while (index < this.SampleNum)
                    {
                        var @out = session.run(sample, feed_dict: new PythonDict <object, object> {
                            [context] = Enumerable.Repeat(contextTokens, this.batchSize),
                        });
                        foreach (int i in Enumerable.Range(0, Math.Min(this.SampleNum - index, this.batchSize)))
                        {
                            text = this.encoder.Decode(@out[i]);
                            text = Invariant($"======== SAMPLE {index + 1} ========\n{text}\n");
                            allText.Add(text);
                            index++;
                        }
                    }
                    Console.WriteLine(text);
                    Directory.CreateDirectory(runSampleDir);
                    File.WriteAllLines(
                        path: Path.Combine(runSampleDir, Invariant($"samples-{counter}")),
                        contents: allText);
                }

                var avgLoss   = (0.0, 0.0);
                var startTime = DateTime.Now;

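                // Training loop: checkpoint and sample at the configured intervals, then run one
                // optimizer step per batch of 1024-token windows drawn from the dataset sampler.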
                while (!cancellation.IsCancellationRequested)
                {
                    if (counter % this.SaveEvery == 0)
                    {
                        Save();
                    }
                    if (counter % this.SampleEvery == 0)
                    {
                        GenerateSamples();
                    }

                    var batch = Enumerable.Range(0, this.batchSize)
                                .Select(_ => sampler.Sample(1024))
                                .ToArray();

                    var placeholderValues = new PythonDict <object, object> {
                        [context] = batch.ToPythonList(),
                    };
                    var tuple = session.run_dyn((optimizer, loss), feed_dict: placeholderValues);

                    var lv = tuple.Item2;

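                    // Exponentially weighted moving average of the loss (decay 0.99);
                    // Item1 / Item2 is the debiased running average.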
                    avgLoss = (avgLoss.Item1 * 0.99 + lv, avgLoss.Item2 * 0.99 + 1);

                    Console.WriteLine($"[{counter} | {DateTime.Now-startTime}] loss={lv} avg={avgLoss.Item1/avgLoss.Item2}");

                    counter++;
                }

                Console.WriteLine("Interrupted");
                Save();
            });
        }
Example #2
        internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List <IMetric> metrics, IModelMetaData modelMetaData,
                                    Func <IComputeGraph, List <List <string> >, List <List <string> >, int, bool, float> ForwardOnSingleDevice)
        {
            int      processedLineInTotal   = 0;
            DateTime startDateTime          = DateTime.Now;
            DateTime lastCheckPointDateTime = DateTime.Now;
            double   costInTotal            = 0.0;
            long     srcWordCnts            = 0;
            long     tgtWordCnts            = 0;
            double   avgCostPerWordInTotal  = 0.0;

            TensorAllocator.FreeMemoryAllDevices();

            Logger.WriteLine($"Start to process training corpus.");
            List <SntPairBatch> sntPairBatchs = new List <SntPairBatch>();

            foreach (SntPairBatch sntPairBatch in trainCorpus)
            {
                sntPairBatchs.Add(sntPairBatch);
                if (sntPairBatchs.Count == m_deviceIds.Length)
                {
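                    // One batch has been collected per device: copy the current weights to every device,
                    // run forward/backward in parallel, then sum the gradients and update the weights
                    // on the default device.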
                    float cost          = 0.0f;
                    int   tlen          = 0;
                    int   processedLine = 0;

                    // Copy the weights kept on the default device to all other devices
                    CopyWeightsFromDefaultDeviceToAllOtherDevices();

                    // Run forward and backward on all available processors
                    Parallel.For(0, m_deviceIds.Length, i =>
                    {
                        SntPairBatch sntPairBatch_i = sntPairBatchs[i];
                        // Construct sentences for encoding and decoding
                        List <List <string> > srcTkns = new List <List <string> >();
                        List <List <string> > tgtTkns = new List <List <string> >();
                        int sLenInBatch = 0;
                        int tLenInBatch = 0;
                        for (int j = 0; j < sntPairBatch_i.BatchSize; j++)
                        {
                            srcTkns.Add(sntPairBatch_i.SntPairs[j].SrcSnt.ToList());
                            sLenInBatch += sntPairBatch_i.SntPairs[j].SrcSnt.Length;

                            tgtTkns.Add(sntPairBatch_i.SntPairs[j].TgtSnt.ToList());
                            tLenInBatch += sntPairBatch_i.SntPairs[j].TgtSnt.Length;
                        }

                        float lcost = 0.0f;
                        // Create a new computing graph instance
                        using (IComputeGraph computeGraph_i = CreateComputGraph(i))
                        {
                            // Run forward part
                            lcost = ForwardOnSingleDevice(computeGraph_i, srcTkns, tgtTkns, i, true);
                            // Run backward part and compute gradients
                            computeGraph_i.Backward();
                        }

                        lock (locker)
                        {
                            cost                 += lcost;
                            srcWordCnts          += sLenInBatch;
                            tgtWordCnts          += tLenInBatch;
                            tlen                 += tLenInBatch;
                            processedLineInTotal += sntPairBatch_i.BatchSize;
                            processedLine        += sntPairBatch_i.BatchSize;
                        }
                    });

                    // Sum up the gradients from all devices and keep them on the default device for parameter optimization
                    SumGradientsToTensorsInDefaultDevice();

                    // Optimize parameters
                    float lr = learningRate.GetCurrentLearningRate();
                    List <IWeightTensor> models = GetParametersFromDefaultDevice();
                    solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);

                    // Clear gradients on all devices
                    ZeroGradientOnAllDevices();

                    costInTotal          += cost;
                    avgCostPerWordInTotal = costInTotal / tgtWordCnts;
                    m_weightsUpdateCount++;
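                    // Report progress through the IterationDone event every 100 weight updates.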
                    if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
                    {
                        IterationDone(this, new CostEventArg()
                        {
                            LearningRate              = lr,
                            CostPerWord               = cost / tlen,
                            AvgCostInTotal            = avgCostPerWordInTotal,
                            Epoch                     = ep,
                            Update                    = m_weightsUpdateCount,
                            ProcessedSentencesInTotal = processedLineInTotal,
                            ProcessedWordsInTotal     = srcWordCnts + tgtWordCnts,
                            StartDateTime             = startDateTime
                        });
                    }

                    // Evaluate the model every hour and save it if it has improved.
                    TimeSpan ts = DateTime.Now - lastCheckPointDateTime;
                    if (ts.TotalHours > 1.0)
                    {
                        CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                        lastCheckPointDateTime = DateTime.Now;
                    }

                    sntPairBatchs.Clear();
                }
            }

            Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

            CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
            m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
        }
Example #3
        internal void TrainOneEpoch(int ep, IEnumerable <SntPairBatch> trainCorpus, IEnumerable <SntPairBatch> validCorpus, ILearningRate learningRate, AdamOptimizer solver, List <IMetric> metrics, IModelMetaData modelMetaData,
                                    Func <IComputeGraph, List <List <string> >, List <List <string> >, int, bool, float> ForwardOnSingleDevice)
        {
            int      processedLineInTotal  = 0;
            DateTime startDateTime         = DateTime.Now;
            double   costInTotal           = 0.0;
            long     srcWordCntsInTotal    = 0;
            long     tgtWordCntsInTotal    = 0;
            double   avgCostPerWordInTotal = 0.0;

            Logger.WriteLine($"Start to process training corpus.");
            List <SntPairBatch> sntPairBatchs = new List <SntPairBatch>();

            foreach (SntPairBatch sntPairBatch in trainCorpus)
            {
                sntPairBatchs.Add(sntPairBatch);
                if (sntPairBatchs.Count == m_deviceIds.Length)
                {
                    // Copy the weights kept on the default device to all other devices
                    CopyWeightsFromDefaultDeviceToAllOtherDevices();

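                    // If a device runs out of memory, the batch split factor is doubled and the step
                    // is retried with smaller sub-batches (see the catch blocks below).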
                    int  batchSplitFactor    = 1;
                    bool runNetworkSucceeded = false;

                    while (runNetworkSucceeded == false)
                    {
                        try
                        {
                            (float cost, int sWordCnt, int tWordCnt, int processedLine) = RunNetwork(ForwardOnSingleDevice, sntPairBatchs, batchSplitFactor);
                            processedLineInTotal += processedLine;
                            srcWordCntsInTotal   += sWordCnt;
                            tgtWordCntsInTotal   += tWordCnt;

                            // Sum up the gradients from all devices and keep them on the default device for parameter optimization
                            SumGradientsToTensorsInDefaultDevice();

                            // Optimize parameters
                            float lr = learningRate.GetCurrentLearningRate();
                            List <IWeightTensor> models = GetParametersFromDefaultDevice();
                            solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);


                            costInTotal          += cost;
                            avgCostPerWordInTotal = costInTotal / tgtWordCntsInTotal;
                            m_weightsUpdateCount++;
                            if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
                            {
                                IterationDone(this, new CostEventArg()
                                {
                                    LearningRate              = lr,
                                    CostPerWord               = cost / tWordCnt,
                                    AvgCostInTotal            = avgCostPerWordInTotal,
                                    Epoch                     = ep,
                                    Update                    = m_weightsUpdateCount,
                                    ProcessedSentencesInTotal = processedLineInTotal,
                                    ProcessedWordsInTotal     = srcWordCntsInTotal + tgtWordCntsInTotal,
                                    StartDateTime             = startDateTime
                                });
                            }

                            runNetworkSucceeded = true;
                        }
                        catch (AggregateException err)
                        {
                            if (err.InnerExceptions != null)
                            {
                                string oomMessage            = String.Empty;
                                bool   isOutOfMemException   = false;
                                bool   isArithmeticException = false;
                                foreach (var excep in err.InnerExceptions)
                                {
                                    if (excep is OutOfMemoryException)
                                    {
                                        isOutOfMemException = true;
                                        oomMessage          = excep.Message;
                                        break;
                                    }
                                    else if (excep is ArithmeticException)
                                    {
                                        isArithmeticException = true;
                                        oomMessage            = excep.Message;
                                        break;
                                    }
                                }

                                if (isOutOfMemException)
                                {
                                    batchSplitFactor *= 2;
                                    Logger.WriteLine($"Got an exception ('{oomMessage}'), so we increase batch split factor to {batchSplitFactor}, and retry it.");

                                    if (batchSplitFactor >= sntPairBatchs[0].BatchSize)
                                    {
                                        Logger.WriteLine($"Batch split factor is larger than batch size, so ignore current mini-batch.");
                                        break;
                                    }
                                }
                                else if (isArithmeticException)
                                {
                                    Logger.WriteLine($"Arithmetic exception: '{err.Message}'");
                                    break;
                                }
                                else
                                {
                                    Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Exception: {err.Message}, Call stack: {err.StackTrace}");
                                    throw;
                                }
                            }
                            else
                            {
                                Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Exception: {err.Message}, Call stack: {err.StackTrace}");
                                throw;
                            }
                        }
                        catch (OutOfMemoryException err)
                        {
                            batchSplitFactor *= 2;
                            Logger.WriteLine($"Got an exception ('{err.Message}'), so we increase batch split factor to {batchSplitFactor}, and retry it.");

                            if (batchSplitFactor >= sntPairBatchs[0].BatchSize)
                            {
                                Logger.WriteLine($"Batch split factor is larger than batch size, so ignore current mini-batch.");
                                break;
                            }
                        }
                        catch (ArithmeticException err)
                        {
                            Logger.WriteLine($"Arithmetic exception: '{err.Message}'");
                            break;
                        }
                        catch (Exception err)
                        {
                            Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Exception: {err.Message}, Call stack: {err.StackTrace}");
                            throw;
                        }
                    }

                    // Evaluate the model every hour and save it if it has improved.
                    TimeSpan ts = DateTime.Now - m_lastCheckPointDateTime;
                    if (ts.TotalHours > 1.0)
                    {
                        CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                        m_lastCheckPointDateTime = DateTime.Now;
                    }

                    sntPairBatchs.Clear();
                }
            }

            Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

            //  CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
            m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
        }
Example #4
        static void Main(string[] args)
        {
            Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";
            ShowOptions(args);

            //Parse command line
            Options   opts      = new Options();
            ArgParser argParser = new ArgParser(args, opts);

            if (String.IsNullOrEmpty(opts.ConfigFilePath) == false)
            {
                Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
                opts = JsonConvert.DeserializeObject <Options>(File.ReadAllText(opts.ConfigFilePath));
            }

            AttentionSeq2Seq   ss            = null;
            ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
            EncoderTypeEnums   encoderType   = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
            ModeEnums          mode          = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);
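            // TaskName selects the mode: Train, Valid, Test or VisualizeNetwork.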

            //Parse device ids from options
            int[] deviceIds = opts.DeviceIds.Split(',').Select(x => int.Parse(x)).ToArray();
            if (mode == ModeEnums.Train)
            {
                // Load train corpus
                ParallelCorpus trainCorpus = new ParallelCorpus(opts.TrainCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);
                // Load valid corpus
                ParallelCorpus validCorpus = String.IsNullOrEmpty(opts.ValidCorpusPath) ? null : new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

                // Load or build vocabulary
                Vocab vocab = null;
                if (!String.IsNullOrEmpty(opts.SrcVocab) && !String.IsNullOrEmpty(opts.TgtVocab))
                {
                    // Vocabulary files are specified, so we load them
                    vocab = new Vocab(opts.SrcVocab, opts.TgtVocab);
                }
                else
                {
                    // No vocabulary files were specified, so build them from the training corpus
                    vocab = new Vocab(trainCorpus);
                }

                // Create learning rate
                ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

                // Create optimizer
                AdamOptimizer optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

                // Create metrics
                List <IMetric> metrics = new List <IMetric>();
                metrics.Add(new BleuMetric());
                metrics.Add(new LengthRatioMetric());

                if (File.Exists(opts.ModelFilePath) == false)
                {
                    //New training
                    ss = new AttentionSeq2Seq(embeddingDim: opts.WordVectorSize, hiddenDim: opts.HiddenSize, encoderLayerDepth: opts.EncoderLayerDepth, decoderLayerDepth: opts.DecoderLayerDepth,
                                              srcEmbeddingFilePath: opts.SrcEmbeddingModelFilePath, tgtEmbeddingFilePath: opts.TgtEmbeddingModelFilePath, vocab: vocab, modelFilePath: opts.ModelFilePath,
                                              dropoutRatio: opts.DropoutRatio, processorType: processorType, deviceIds: deviceIds, multiHeadNum: opts.MultiHeadNum, encoderType: encoderType);
                }
                else
                {
                    //Incremental training
                    Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                    ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, dropoutRatio: opts.DropoutRatio, deviceIds: deviceIds);
                }

                // Add event handler for monitoring
                ss.IterationDone += ss_IterationDone;

                // Kick off training
                ss.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
            }
            else if (mode == ModeEnums.Valid)
            {
                Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

                // Create metrics
                List <IMetric> metrics = new List <IMetric>();
                metrics.Add(new BleuMetric());
                metrics.Add(new LengthRatioMetric());

                // Load valid corpus
                ParallelCorpus validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

                ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);
                ss.Valid(validCorpus: validCorpus, metrics: metrics);
            }
            else if (mode == ModeEnums.Test)
            {
                Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

                //Test trained model
                ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);

                List <string> outputLines     = new List <string>();
                var           data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
                foreach (string line in data_sents_raw1)
                {
                    //// Below support beam search
                    //List<List<string>> outputWordsList = ss.Predict(line.ToLower().Trim().Split(' ').ToList(), opts.BeamSearch);
                    //outputLines.AddRange(outputWordsList.Select(x => String.Join(" ", x)));

                    var outputTokensBatch = ss.Test(ParallelCorpus.ConstructInputTokens(line.ToLower().Trim().Split(' ').ToList()));
                    outputLines.AddRange(outputTokensBatch.Select(x => String.Join(" ", x)));
                }

                File.WriteAllLines(opts.OutputTestFile, outputLines);
            }
            else if (mode == ModeEnums.VisualizeNetwork)
            {
                ss = new AttentionSeq2Seq(embeddingDim: opts.WordVectorSize, hiddenDim: opts.HiddenSize, encoderLayerDepth: opts.EncoderLayerDepth, decoderLayerDepth: opts.DecoderLayerDepth,
                                          vocab: new Vocab(), srcEmbeddingFilePath: null, tgtEmbeddingFilePath: null, modelFilePath: opts.ModelFilePath, dropoutRatio: opts.DropoutRatio,
                                          processorType: processorType, deviceIds: new int[1] { 0 }, multiHeadNum: opts.MultiHeadNum, encoderType: encoderType);

                ss.VisualizeNeuralNetwork(opts.VisualizeNNFilePath);
            }
            else
            {
                argParser.Usage();
            }
        }
Example #5
        /// <summary>
        /// Fits y = x * W + b (single-precision CPU version) so that y = 1 when x = -2.
        ///
        /// This also demonstrates how to save and load a graph.
        /// </summary>
        public static void Example1()
        {
            var cns = new ConvNetSharp <float>();

            // Graph creation
            Op <float> cost;
            Op <float> fun;

            if (File.Exists("test.graphml"))
            {
                Console.WriteLine("Loading graph from disk.");
                var ops = SerializationExtensions.Load <float>("test", true);

                fun  = ops[0];
                cost = ops[1];
            }
            else
            {
                var x = cns.PlaceHolder("x");
                var y = cns.PlaceHolder("y");

                var W = cns.Variable(1.0f, "W", true);
                var b = cns.Variable(2.0f, "b", true);

                fun = x * W + b;

                cost = (fun - y) * (fun - y);
            }


            var optimizer = new AdamOptimizer <float>(cns, 0.01f, 0.9f, 0.999f, 1e-08f);

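            // Differentiate the cost once, then loop: evaluate the cost and fun, and apply an Adam
            // update until the cost drops below 1e-5 (so that -2 * W + b converges to 1).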
            using (var session = new Session <float>())
            {
                session.Differentiate(cost); // computes dCost/dW at every node of the graph

                float currentCost;
                do
                {
                    var dico = new Dictionary <string, Volume <float> > {
                        { "x", -2.0f }, { "y", 1.0f }
                    };

                    currentCost = session.Run(cost, dico);
                    Console.WriteLine($"cost: {currentCost}");

                    var result = session.Run(fun, dico);
                    session.Run(optimizer, dico);
                } while (currentCost > 1e-5);

                float finalW = session.GetVariableByName(fun, "W").Result;
                float finalb = session.GetVariableByName(fun, "b").Result;
                Console.WriteLine($"fun = x * {finalW} + {finalb}");

                fun.Save("test", cost);

                //// Display graph
                //var vm = new ViewModel<float>(cost);
                //var app = new Application();
                //app.Run(new GraphControl { DataContext = vm });
            }

            Console.ReadKey();
        }
Example #6
        public CharRNNModel(CharRNNModelParameters parameters, bool training = true)
        {
            this.parameters = parameters ?? throw new ArgumentNullException(nameof(parameters));
            if (!training)
            {
                this.parameters.BatchSize = 1;
                this.parameters.SeqLength = 1;
            }

            if (!ModelTypeToCellFunction.TryGetValue(parameters.ModelType, out this.cellFactory))
            {
                throw new NotSupportedException(parameters.ModelType.ToString());
            }

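            // Build LayerCount RNN cells of the configured type, wrapping each in dropout while training.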
            for (int i = 0; i < parameters.LayerCount; i++)
            {
                RNNCell cell = this.cellFactory(parameters.RNNSize);
                if (training && (parameters.KeepOutputProbability < 1 || parameters.KeepInputProbability < 1))
                {
                    cell = new DropoutWrapper(cell,
                                              input_keep_prob: parameters.KeepInputProbability,
                                              output_keep_prob: parameters.KeepOutputProbability);
                }
                this.cells.Add(cell);
            }
            this.rnn          = new MultiRNNCell(this.cells, state_is_tuple: true);
            this.inputData    = tf.placeholder(tf.int32, new TensorShape(parameters.BatchSize, parameters.SeqLength));
            this.targets      = tf.placeholder(tf.int32, new TensorShape(parameters.BatchSize, parameters.SeqLength));
            this.initialState = this.rnn.zero_state(parameters.BatchSize, tf.float32);

            Variable softmax_W = null, softmax_b = null;

            new variable_scope("rnnlm").UseSelf(_ => {
                softmax_W = tf.get_variable("softmax_w", new TensorShape(parameters.RNNSize, parameters.VocabularySize));
                softmax_b = tf.get_variable("softmax_b", new TensorShape(parameters.VocabularySize));
            });

            Variable embedding = tf.get_variable("embedding", new TensorShape(parameters.VocabularySize, parameters.RNNSize));
            Tensor   input     = tf.nn.embedding_lookup(embedding, this.inputData);

            // dropout beta testing: double check which one should affect next line
            if (training && parameters.KeepOutputProbability < 1)
            {
                input = tf.nn.dropout(input, parameters.KeepOutputProbability);
            }

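            // Split the embedded batch into SeqLength time-step tensors of shape [BatchSize, RNNSize].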
            PythonList <Tensor> inputs = tf.split(input, parameters.SeqLength, axis: 1);

            inputs = inputs.Select(i => (Tensor)tf.squeeze(i, axis: 1)).ToPythonList();

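            // During sampling (training == false) the decoder feeds the arg-max of the previous
            // step's logits back in as the next input embedding.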
            dynamic Loop(dynamic prev, dynamic _)
            {
                prev = tf.matmul(prev, softmax_W) + softmax_b;
                var prevSymbol = tf.stop_gradient(tf.argmax(prev, 1));

                return tf.nn.embedding_lookup(embedding, prevSymbol);
            }

            var decoder = tensorflow.contrib.legacy_seq2seq.legacy_seq2seq.rnn_decoder(inputs, initialState.Items().Cast <object>(), this.rnn,
                                                                                       loop_function: training ? null : PythonFunctionContainer.Of(new Func <dynamic, dynamic, dynamic>(Loop)), scope: "rnnlm");
            var     outputs             = decoder.Item1;
            var     lastState           = (seq2seqState)decoder.Item2;
            dynamic concatenatedOutputs = tf.concat(outputs, 1);
            var     output = tensorflow.tf.reshape(concatenatedOutputs, new[] { -1, parameters.RNNSize });

            this.logits = tf.matmul(output, softmax_W) + softmax_b;
            this.probs  = tf.nn.softmax(new[] { this.logits });
            this.loss   = tensorflow.contrib.legacy_seq2seq.legacy_seq2seq.sequence_loss_by_example(
                new[] { this.logits },
                new[] { tf.reshape(targets, new[] { -1 }) },
                new[] { tf.ones(new[] { parameters.BatchSize * parameters.SeqLength }) });

            Tensor cost = null;

            new name_scope("cost").UseSelf(_ => {
                cost = tf.reduce_sum(this.loss) / parameters.BatchSize / parameters.SeqLength;
            });
            this.cost         = cost;
            this.finalState   = lastState;
            this.learningRate = new Variable(0.0, trainable: false);
            var tvars = tf.trainable_variables();

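            // Clip the gradients by global norm, then apply them with Adam using the learningRate variable.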
            IEnumerable <object> grads     = tf.clip_by_global_norm(tf.gradients(this.cost, tvars), parameters.GradientClip).Item1;
            AdamOptimizer        optimizer = null;

            new name_scope("optimizer").UseSelf(_ => optimizer = new AdamOptimizer(this.learningRate));
            this.trainOp = optimizer.apply_gradients(grads.Zip(tvars, (grad, @var) => (dynamic)(grad, @var)));

            tf.summary.histogram("logits", new[] { this.logits });
            tf.summary.histogram("loss", new[] { this.loss });
            tf.summary.histogram("train_loss", new[] { this.cost });
        }