public void Train(string checkpoint, string run, int? counter, CancellationToken cancellation) {
    new Session().UseSelf(session => {
        var context = tf.placeholder(tf.int32, new TensorShape(this.batchSize, null));
        var output = Gpt2Model.Model(this.hParams, input: context);

        // Next-token prediction: compare logits at positions [0..N-2] with tokens at positions [1..N-1]
        Tensor labels = context[Range.All, Range.StartAt(1)];
        Tensor logits = output["logits"][Range.All, Range.EndAt(new Index(1, fromEnd: true))];
        var loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits_dyn(
                labels: labels, logits: logits));

        var sample = Gpt2Sampler.SampleSequence(
            this.hParams,
            length: this.sampleLength,
            context: context,
            batchSize: this.batchSize,
            temperature: 1.0f,
            topK: 40);

        var trainVars = tf.trainable_variables()
            .Where((dynamic var) => var.name.Contains("model"));
        var optimizer = new AdamOptimizer(learning_rate: 0.0002)
            .minimize(loss, var_list: trainVars);
        var saver = new Saver(
            var_list: trainVars,
            max_to_keep: 5,
            keep_checkpoint_every_n_hours: 1);

        session.run(tf.global_variables_initializer());

        Console.WriteLine("Loading checkpoint " + checkpoint);
        saver.restore(session, checkpoint);

        Console.WriteLine("Loading dataset...");
        var sampler = new TrainingSampler(this.dataset, this.random);
        Console.WriteLine($"Dataset has {sampler.TokenCount} tokens");

        string counterFile = Path.Combine(Gpt2Checkpoints.CheckpointDir, run, "counter");
        if (counter == null && File.Exists(counterFile)) {
            counter = int.Parse(File.ReadAllText(counterFile), CultureInfo.InvariantCulture) + 1;
        }
        counter = counter ?? 1;

        string runCheckpointDir = Path.Combine(Gpt2Checkpoints.CheckpointDir, run);
        string runSampleDir = Path.Combine(SampleDir, run);

        void Save() {
            Directory.CreateDirectory(runCheckpointDir);
            Console.WriteLine("Saving " + Path.Combine(runCheckpointDir, Invariant($"model-{counter}")));
            saver.save(session, Path.Combine(runCheckpointDir, "model"), global_step: counter.Value);
            File.WriteAllText(path: counterFile, contents: Invariant($"{counter}"));
        }

        void GenerateSamples() {
            var contextTokens = np.array(new[] { this.encoder.EncodedEndOfText });
            var allText = new List<string>();
            int index = 0;
            string text = null;
            while (index < this.SampleNum) {
                var @out = session.run(sample, feed_dict: new PythonDict<object, object> {
                    [context] = Enumerable.Repeat(contextTokens, this.batchSize),
                });
                foreach (int i in Enumerable.Range(0, Math.Min(this.SampleNum - index, this.batchSize))) {
                    text = this.encoder.Decode(@out[i]);
                    text = Invariant($"======== SAMPLE {index + 1} ========\n{text}\n");
                    allText.Add(text);
                    index++;
                }
            }
            Console.WriteLine(text);
            Directory.CreateDirectory(runSampleDir);
            File.WriteAllLines(
                path: Path.Combine(runSampleDir, Invariant($"samples-{counter}")),
                contents: allText);
        }

        var avgLoss = (0.0, 0.0);
        var startTime = DateTime.Now;

        while (!cancellation.IsCancellationRequested) {
            if (counter % this.SaveEvery == 0) {
                Save();
            }
            if (counter % this.SampleEvery == 0) {
                GenerateSamples();
            }

            var batch = Enumerable.Range(0, this.batchSize)
                .Select(_ => sampler.Sample(1024))
                .ToArray();

            var placeholderValues = new PythonDict<object, object> {
                [context] = batch.ToPythonList(),
            };
            var tuple = session.run_dyn((optimizer, loss), feed_dict: placeholderValues);

            var lv = tuple.Item2;

            // exponentially decayed running average of the loss
            avgLoss = (avgLoss.Item1 * 0.99 + lv, avgLoss.Item2 * 0.99 + 1);

            Console.WriteLine($"[{counter} | {DateTime.Now - startTime}] loss={lv} avg={avgLoss.Item1 / avgLoss.Item2}");

            counter++;
        }

        Console.WriteLine("Interrupted");
        Save();
    });
}
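// Hedged usage sketch (not part of the original sample): Train() stops when the CancellationToken
// fires, so a console host can map Ctrl+C onto it and still let the final Save() run. "trainer"
// stands for whatever object exposes the Train method above; its type, the checkpoint path and the
// run name are assumptions made only for illustration.
static void RunTraining(dynamic trainer) {
    using var cts = new CancellationTokenSource();
    Console.CancelKeyPress += (sender, args) => {
        args.Cancel = true;   // keep the process alive so Train() can save a final checkpoint
        cts.Cancel();         // Train() leaves its loop, prints "Interrupted" and calls Save()
    };
    trainer.Train(checkpoint: "checkpoint/run1/model-1000", run: "run1",
                  counter: (int?)null, cancellation: cts.Token);
}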
internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver,
    List<IMetric> metrics, IModelMetaData modelMetaData,
    Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice) {
    int processedLineInTotal = 0;
    DateTime startDateTime = DateTime.Now;
    DateTime lastCheckPointDateTime = DateTime.Now;
    double costInTotal = 0.0;
    long srcWordCnts = 0;
    long tgtWordCnts = 0;
    double avgCostPerWordInTotal = 0.0;

    TensorAllocator.FreeMemoryAllDevices();

    Logger.WriteLine($"Start to process training corpus.");

    List<SntPairBatch> sntPairBatchs = new List<SntPairBatch>();
    foreach (SntPairBatch sntPairBatch in trainCorpus) {
        sntPairBatchs.Add(sntPairBatch);
        if (sntPairBatchs.Count == m_deviceIds.Length) {
            float cost = 0.0f;
            int tlen = 0;
            int processedLine = 0;

            // Copy weights from the default device to all other devices
            CopyWeightsFromDefaultDeviceToAllOtherDevices();

            // Run forward and backward on all available processors
            Parallel.For(0, m_deviceIds.Length, i => {
                SntPairBatch sntPairBatch_i = sntPairBatchs[i];

                // Construct sentences for encoding and decoding
                List<List<string>> srcTkns = new List<List<string>>();
                List<List<string>> tgtTkns = new List<List<string>>();
                int sLenInBatch = 0;
                int tLenInBatch = 0;
                for (int j = 0; j < sntPairBatch_i.BatchSize; j++) {
                    srcTkns.Add(sntPairBatch_i.SntPairs[j].SrcSnt.ToList());
                    sLenInBatch += sntPairBatch_i.SntPairs[j].SrcSnt.Length;

                    tgtTkns.Add(sntPairBatch_i.SntPairs[j].TgtSnt.ToList());
                    tLenInBatch += sntPairBatch_i.SntPairs[j].TgtSnt.Length;
                }

                float lcost = 0.0f;
                // Create a new computing graph instance
                using (IComputeGraph computeGraph_i = CreateComputGraph(i)) {
                    // Run the forward pass
                    lcost = ForwardOnSingleDevice(computeGraph_i, srcTkns, tgtTkns, i, true);
                    // Run the backward pass and compute gradients
                    computeGraph_i.Backward();
                }

                lock (locker) {
                    cost += lcost;
                    srcWordCnts += sLenInBatch;
                    tgtWordCnts += tLenInBatch;
                    tlen += tLenInBatch;
                    processedLineInTotal += sntPairBatch_i.BatchSize;
                    processedLine += sntPairBatch_i.BatchSize;
                }
            });

            // Sum up gradients from all devices and keep them in the default device for parameter optimization
            SumGradientsToTensorsInDefaultDevice();

            // Optimize parameters
            float lr = learningRate.GetCurrentLearningRate();
            List<IWeightTensor> models = GetParametersFromDefaultDevice();
            solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);

            // Clear gradients on all devices
            ZeroGradientOnAllDevices();

            costInTotal += cost;
            avgCostPerWordInTotal = costInTotal / tgtWordCnts;
            m_weightsUpdateCount++;
            if (IterationDone != null && m_weightsUpdateCount % 100 == 0) {
                IterationDone(this, new CostEventArg() {
                    LearningRate = lr,
                    CostPerWord = cost / tlen,
                    AvgCostInTotal = avgCostPerWordInTotal,
                    Epoch = ep,
                    Update = m_weightsUpdateCount,
                    ProcessedSentencesInTotal = processedLineInTotal,
                    ProcessedWordsInTotal = srcWordCnts + tgtWordCnts,
                    StartDateTime = startDateTime
                });
            }

            // Evaluate the model every hour and save it if we get a better one
            TimeSpan ts = DateTime.Now - lastCheckPointDateTime;
            if (ts.TotalHours > 1.0) {
                CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                lastCheckPointDateTime = DateTime.Now;
            }

            sntPairBatchs.Clear();
        }
    }

    Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

    CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);

    m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
}
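// Minimal self-contained sketch of the data-parallel pattern used above: copy the weights out to
// each worker, run forward/backward per device in parallel, sum the gradients under a lock, then
// apply a single update on the "default" copy. Nothing below comes from Seq2SeqSharp itself; all
// names are illustrative, and plain SGD stands in for the Adam update.
static void DataParallelStep(double[] weights, int deviceCount,
    Func<int, double[], double[]> forwardBackward, double learningRate) {
    object gate = new object();
    var totalGradient = new double[weights.Length];

    Parallel.For(0, deviceCount, device => {
        // each "device" gets its own copy of the weights and returns a gradient for its mini-batch
        double[] grad = forwardBackward(device, (double[])weights.Clone());
        lock (gate) {
            for (int k = 0; k < totalGradient.Length; k++) totalGradient[k] += grad[k];
        }
    });

    // one optimizer step on the accumulated gradient, averaged over devices
    for (int k = 0; k < weights.Length; k++)
        weights[k] -= learningRate * totalGradient[k] / deviceCount;
}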
internal void TrainOneEpoch(int ep, IEnumerable<SntPairBatch> trainCorpus, IEnumerable<SntPairBatch> validCorpus, ILearningRate learningRate, AdamOptimizer solver,
    List<IMetric> metrics, IModelMetaData modelMetaData,
    Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice) {
    int processedLineInTotal = 0;
    DateTime startDateTime = DateTime.Now;
    double costInTotal = 0.0;
    long srcWordCntsInTotal = 0;
    long tgtWordCntsInTotal = 0;
    double avgCostPerWordInTotal = 0.0;

    Logger.WriteLine($"Start to process training corpus.");

    List<SntPairBatch> sntPairBatchs = new List<SntPairBatch>();
    foreach (SntPairBatch sntPairBatch in trainCorpus) {
        sntPairBatchs.Add(sntPairBatch);
        if (sntPairBatchs.Count == m_deviceIds.Length) {
            // Copy weights from the default device to all other devices
            CopyWeightsFromDefaultDeviceToAllOtherDevices();

            int batchSplitFactor = 1;
            bool runNetwordSuccssed = false;

            while (runNetwordSuccssed == false) {
                try {
                    (float cost, int sWordCnt, int tWordCnt, int processedLine) = RunNetwork(ForwardOnSingleDevice, sntPairBatchs, batchSplitFactor);
                    processedLineInTotal += processedLine;
                    srcWordCntsInTotal += sWordCnt;
                    tgtWordCntsInTotal += tWordCnt;

                    // Sum up gradients from all devices and keep them in the default device for parameter optimization
                    SumGradientsToTensorsInDefaultDevice();

                    // Optimize parameters
                    float lr = learningRate.GetCurrentLearningRate();
                    List<IWeightTensor> models = GetParametersFromDefaultDevice();
                    solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);

                    costInTotal += cost;
                    avgCostPerWordInTotal = costInTotal / tgtWordCntsInTotal;
                    m_weightsUpdateCount++;
                    if (IterationDone != null && m_weightsUpdateCount % 100 == 0) {
                        IterationDone(this, new CostEventArg() {
                            LearningRate = lr,
                            CostPerWord = cost / tWordCnt,
                            AvgCostInTotal = avgCostPerWordInTotal,
                            Epoch = ep,
                            Update = m_weightsUpdateCount,
                            ProcessedSentencesInTotal = processedLineInTotal,
                            ProcessedWordsInTotal = srcWordCntsInTotal + tgtWordCntsInTotal,
                            StartDateTime = startDateTime
                        });
                    }

                    runNetwordSuccssed = true;
                }
                catch (AggregateException err) {
                    if (err.InnerExceptions != null) {
                        string oomMessage = String.Empty;
                        bool isOutOfMemException = false;
                        bool isArithmeticException = false;
                        foreach (var excep in err.InnerExceptions) {
                            if (excep is OutOfMemoryException) {
                                isOutOfMemException = true;
                                oomMessage = excep.Message;
                                break;
                            }
                            else if (excep is ArithmeticException) {
                                isArithmeticException = true;
                                oomMessage = excep.Message;
                                break;
                            }
                        }

                        if (isOutOfMemException) {
                            batchSplitFactor *= 2;
                            Logger.WriteLine($"Got an exception ('{oomMessage}'), so we increase batch split factor to {batchSplitFactor}, and retry it.");

                            if (batchSplitFactor >= sntPairBatchs[0].BatchSize) {
                                Logger.WriteLine($"Batch split factor is larger than batch size, so ignore current mini-batch.");
                                break;
                            }
                        }
                        else if (isArithmeticException) {
                            Logger.WriteLine($"Arithmetic exception: '{err.Message}'");
                            break;
                        }
                        else {
                            Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Exception: {err.Message}, Call stack: {err.StackTrace}");
                            throw err;
                        }
                    }
                    else {
                        Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Exception: {err.Message}, Call stack: {err.StackTrace}");
                        throw err;
                    }
                }
                catch (OutOfMemoryException err) {
                    batchSplitFactor *= 2;
                    Logger.WriteLine($"Got an exception ('{err.Message}'), so we increase batch split factor to {batchSplitFactor}, and retry it.");

                    if (batchSplitFactor >= sntPairBatchs[0].BatchSize) {
                        Logger.WriteLine($"Batch split factor is larger than batch size, so ignore current mini-batch.");
                        break;
                    }
                }
                catch (ArithmeticException err) {
                    Logger.WriteLine($"Arithmetic exception: '{err.Message}'");
                    break;
                }
                catch (Exception err) {
                    Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Exception: {err.Message}, Call stack: {err.StackTrace}");
                    throw err;
                }
            }

            // Evaluate the model every hour and save it if we get a better one
            TimeSpan ts = DateTime.Now - m_lastCheckPointDateTime;
            if (ts.TotalHours > 1.0) {
                CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                m_lastCheckPointDateTime = DateTime.Now;
            }

            sntPairBatchs.Clear();
        }
    }

    Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

    // CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
    m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
}
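// Self-contained sketch (an illustration, not Seq2SeqSharp code) of the retry policy above:
// on an out-of-memory failure, double the batch split factor and try the same mini-batch again
// with smaller sub-batches; give up on the mini-batch once the factor reaches the batch size.
static bool TryRunWithBatchSplitting(int batchSize, Action<int> runNetwork) {
    int splitFactor = 1;
    while (true) {
        try {
            runNetwork(splitFactor);   // e.g. processes the batch in `splitFactor` chunks
            return true;
        }
        catch (OutOfMemoryException) {
            splitFactor *= 2;
            if (splitFactor >= batchSize)
                return false;          // splitting further is pointless; skip this mini-batch
        }
    }
}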
static void Main(string[] args) {
    Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";
    ShowOptions(args);

    //Parse command line
    Options opts = new Options();
    ArgParser argParser = new ArgParser(args, opts);

    if (String.IsNullOrEmpty(opts.ConfigFilePath) == false) {
        Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
        opts = JsonConvert.DeserializeObject<Options>(File.ReadAllText(opts.ConfigFilePath));
    }

    AttentionSeq2Seq ss = null;
    ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
    EncoderTypeEnums encoderType = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
    ModeEnums mode = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);

    //Parse device ids from options
    int[] deviceIds = opts.DeviceIds.Split(',').Select(x => int.Parse(x)).ToArray();

    if (mode == ModeEnums.Train) {
        // Load train corpus
        ParallelCorpus trainCorpus = new ParallelCorpus(opts.TrainCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

        // Load valid corpus
        ParallelCorpus validCorpus = String.IsNullOrEmpty(opts.ValidCorpusPath) ? null
            : new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

        // Load or build vocabulary
        Vocab vocab = null;
        if (!String.IsNullOrEmpty(opts.SrcVocab) && !String.IsNullOrEmpty(opts.TgtVocab)) {
            // Vocabulary files are specified, so we load them
            vocab = new Vocab(opts.SrcVocab, opts.TgtVocab);
        }
        else {
            // We don't specify vocabulary, so we build it from train corpus
            vocab = new Vocab(trainCorpus);
        }

        // Create learning rate
        ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

        // Create optimizer
        AdamOptimizer optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

        // Create metrics
        List<IMetric> metrics = new List<IMetric>();
        metrics.Add(new BleuMetric());
        metrics.Add(new LengthRatioMetric());

        if (File.Exists(opts.ModelFilePath) == false) {
            //New training
            ss = new AttentionSeq2Seq(embeddingDim: opts.WordVectorSize, hiddenDim: opts.HiddenSize, encoderLayerDepth: opts.EncoderLayerDepth, decoderLayerDepth: opts.DecoderLayerDepth,
                srcEmbeddingFilePath: opts.SrcEmbeddingModelFilePath, tgtEmbeddingFilePath: opts.TgtEmbeddingModelFilePath, vocab: vocab, modelFilePath: opts.ModelFilePath,
                dropoutRatio: opts.DropoutRatio, processorType: processorType, deviceIds: deviceIds, multiHeadNum: opts.MultiHeadNum, encoderType: encoderType);
        }
        else {
            //Incremental training
            Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
            ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, dropoutRatio: opts.DropoutRatio, deviceIds: deviceIds);
        }

        // Add event handler for monitoring
        ss.IterationDone += ss_IterationDone;

        // Kick off training
        ss.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
    }
    else if (mode == ModeEnums.Valid) {
        Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

        // Create metrics
        List<IMetric> metrics = new List<IMetric>();
        metrics.Add(new BleuMetric());
        metrics.Add(new LengthRatioMetric());

        // Load valid corpus
        ParallelCorpus validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

        ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);
        ss.Valid(validCorpus: validCorpus, metrics: metrics);
    }
    else if (mode == ModeEnums.Test) {
        Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

        //Test trained model
        ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);

        List<string> outputLines = new List<string>();
        var data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
        foreach (string line in data_sents_raw1) {
            //// Below support beam search
            //List<List<string>> outputWordsList = ss.Predict(line.ToLower().Trim().Split(' ').ToList(), opts.BeamSearch);
            //outputLines.AddRange(outputWordsList.Select(x => String.Join(" ", x)));

            var outputTokensBatch = ss.Test(ParallelCorpus.ConstructInputTokens(line.ToLower().Trim().Split(' ').ToList()));
            outputLines.AddRange(outputTokensBatch.Select(x => String.Join(" ", x)));
        }

        File.WriteAllLines(opts.OutputTestFile, outputLines);
    }
    else if (mode == ModeEnums.VisualizeNetwork) {
        ss = new AttentionSeq2Seq(embeddingDim: opts.WordVectorSize, hiddenDim: opts.HiddenSize, encoderLayerDepth: opts.EncoderLayerDepth, decoderLayerDepth: opts.DecoderLayerDepth,
            vocab: new Vocab(), srcEmbeddingFilePath: null, tgtEmbeddingFilePath: null, modelFilePath: opts.ModelFilePath, dropoutRatio: opts.DropoutRatio,
            processorType: processorType, deviceIds: new int[1] { 0 }, multiHeadNum: opts.MultiHeadNum, encoderType: encoderType);

        ss.VisualizeNeuralNetwork(opts.VisualizeNNFilePath);
    }
    else {
        argParser.Usage();
    }
}
/// <summary>
/// Solves y = x * W + b (CPU single version)
/// for y = 1 and x = -2
///
/// This also demonstrates how to save and load a graph
/// </summary>
public static void Example1() {
    var cns = new ConvNetSharp<float>();

    // Graph creation
    Op<float> cost;
    Op<float> fun;
    if (File.Exists("test.graphml")) {
        Console.WriteLine("Loading graph from disk.");
        var ops = SerializationExtensions.Load<float>("test", true);

        fun = ops[0];
        cost = ops[1];
    }
    else {
        var x = cns.PlaceHolder("x");
        var y = cns.PlaceHolder("y");

        var W = cns.Variable(1.0f, "W", true);
        var b = cns.Variable(2.0f, "b", true);

        fun = x * W + b;
        cost = (fun - y) * (fun - y);
    }

    var optimizer = new AdamOptimizer<float>(cns, 0.01f, 0.9f, 0.999f, 1e-08f);

    using (var session = new Session<float>()) {
        session.Differentiate(cost); // computes dCost/dW at every node of the graph

        float currentCost;
        do {
            var dico = new Dictionary<string, Volume<float>> { { "x", -2.0f }, { "y", 1.0f } };

            currentCost = session.Run(cost, dico);
            Console.WriteLine($"cost: {currentCost}");

            var result = session.Run(fun, dico);
            session.Run(optimizer, dico);
        } while (currentCost > 1e-5);

        float finalW = session.GetVariableByName(fun, "W").Result;
        float finalb = session.GetVariableByName(fun, "b").Result;
        Console.WriteLine($"fun = x * {finalW} + {finalb}");

        fun.Save("test", cost);

        //// Display graph
        //var vm = new ViewModel<float>(cost);
        //var app = new Application();
        //app.Run(new GraphControl { DataContext = vm });
    }

    Console.ReadKey();
}
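// The same one-point fit without any library, as a sanity check of what Example1 computes:
// for x = -2 and y = 1 the cost (x*W + b - y)^2 is zero on the whole line -2*W + b = 1, so
// the descent converges to *some* point on that line (which one depends on the start values
// and the optimizer). Plain gradient descent stands in for Adam here; all names are illustrative.
static void FitOnePoint() {
    float x = -2f, y = 1f, W = 1f, b = 2f, lr = 0.01f;
    float cost;
    do {
        float residual = x * W + b - y;        // fun - y
        cost = residual * residual;            // (fun - y)^2
        W -= lr * 2f * residual * x;           // dCost/dW = 2 * residual * x
        b -= lr * 2f * residual;               // dCost/db = 2 * residual
    } while (cost > 1e-5f);
    Console.WriteLine($"fun = x * {W} + {b}"); // -2*W + b ends up close to 1
}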
public CharRNNModel(CharRNNModelParameters parameters, bool training = true) {
    this.parameters = parameters ?? throw new ArgumentNullException(nameof(parameters));
    if (!training) {
        this.parameters.BatchSize = 1;
        this.parameters.SeqLength = 1;
    }

    if (!ModelTypeToCellFunction.TryGetValue(parameters.ModelType, out this.cellFactory)) {
        throw new NotSupportedException(parameters.ModelType.ToString());
    }

    for (int i = 0; i < parameters.LayerCount; i++) {
        RNNCell cell = this.cellFactory(parameters.RNNSize);
        if (training && (parameters.KeepOutputProbability < 1 || parameters.KeepInputProbability < 1)) {
            cell = new DropoutWrapper(cell,
                input_keep_prob: parameters.KeepInputProbability,
                output_keep_prob: parameters.KeepOutputProbability);
        }
        this.cells.Add(cell);
    }
    this.rnn = new MultiRNNCell(this.cells, state_is_tuple: true);
    this.inputData = tf.placeholder(tf.int32, new TensorShape(parameters.BatchSize, parameters.SeqLength));
    this.targets = tf.placeholder(tf.int32, new TensorShape(parameters.BatchSize, parameters.SeqLength));
    this.initialState = this.rnn.zero_state(parameters.BatchSize, tf.float32);

    Variable softmax_W = null, softmax_b = null;
    new variable_scope("rnnlm").UseSelf(_ => {
        softmax_W = tf.get_variable("softmax_w", new TensorShape(parameters.RNNSize, parameters.VocabularySize));
        softmax_b = tf.get_variable("softmax_b", new TensorShape(parameters.VocabularySize));
    });

    Variable embedding = tf.get_variable("embedding", new TensorShape(parameters.VocabularySize, parameters.RNNSize));
    Tensor input = tf.nn.embedding_lookup(embedding, this.inputData);

    // dropout beta testing: double check which one should affect next line
    if (training && parameters.KeepOutputProbability < 1) {
        input = tf.nn.dropout(input, parameters.KeepOutputProbability);
    }

    PythonList<Tensor> inputs = tf.split(input, parameters.SeqLength, axis: 1);
    inputs = inputs.Select(i => (Tensor)tf.squeeze(i, axis: 1)).ToPythonList();

    // During sampling, feed the previous step's most likely symbol back in as the next input
    dynamic Loop(dynamic prev, dynamic _) {
        prev = tf.matmul(prev, softmax_W) + softmax_b;
        var prevSymbol = tf.stop_gradient(tf.argmax(prev, 1));
        return tf.nn.embedding_lookup(embedding, prevSymbol);
    }

    var decoder = tensorflow.contrib.legacy_seq2seq.legacy_seq2seq.rnn_decoder(inputs, initialState.Items().Cast<object>(), this.rnn,
        loop_function: training ? null : PythonFunctionContainer.Of(new Func<dynamic, dynamic, dynamic>(Loop)), scope: "rnnlm");
    var outputs = decoder.Item1;
    var lastState = (seq2seqState)decoder.Item2;
    dynamic concatenatedOutputs = tf.concat(outputs, 1);
    var output = tensorflow.tf.reshape(concatenatedOutputs, new[] { -1, parameters.RNNSize });

    this.logits = tf.matmul(output, softmax_W) + softmax_b;
    this.probs = tf.nn.softmax(new[] { this.logits });
    this.loss = tensorflow.contrib.legacy_seq2seq.legacy_seq2seq.sequence_loss_by_example(
        new[] { this.logits },
        new[] { tf.reshape(targets, new[] { -1 }) },
        new[] { tf.ones(new[] { parameters.BatchSize * parameters.SeqLength }) });

    Tensor cost = null;
    new name_scope("cost").UseSelf(_ => {
        cost = tf.reduce_sum(this.loss) / parameters.BatchSize / parameters.SeqLength;
    });
    this.cost = cost;
    this.finalState = lastState;
    this.learningRate = new Variable(0.0, trainable: false);
    var tvars = tf.trainable_variables();

    IEnumerable<object> grads = tf.clip_by_global_norm(tf.gradients(this.cost, tvars), parameters.GradientClip).Item1;
    AdamOptimizer optimizer = null;
    new name_scope("optimizer").UseSelf(_ => optimizer = new AdamOptimizer(this.learningRate));
    this.trainOp = optimizer.apply_gradients(grads.Zip(tvars, (grad, @var) => (dynamic)(grad, @var)));

    tf.summary.histogram("logits", new[] { this.logits });
    tf.summary.histogram("loss", new[] { this.loss });
    tf.summary.histogram("train_loss", new[] { this.cost });
}
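// Hedged construction sketch (an assumption, not from the original listing): every property name
// below appears in the constructor above, but the concrete values, the property types and the
// CharRNNModelParameters defaults are illustrative. ModelType must be one of the keys of
// ModelTypeToCellFunction, whose values are not shown here.
var trainParameters = new CharRNNModelParameters {
    BatchSize = 64,
    SeqLength = 50,
    LayerCount = 2,
    RNNSize = 128,
    VocabularySize = 65,        // e.g. the character vocabulary of the training text
    KeepInputProbability = 1,
    KeepOutputProbability = 1,  // values < 1 enable the DropoutWrapper during training
    GradientClip = 5,
    // ModelType = <one of the supported cell types>,
};
var trainingModel = new CharRNNModel(trainParameters, training: true);
// For sampling, pass training: false; the constructor then forces BatchSize = SeqLength = 1
// and the decoder feeds each predicted symbol back in through the Loop function.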