/// <summary>
/// Trains a GPT-2 model on the dataset given as the first positional argument.
/// Blocks until the process receives Ctrl+C, which requests a graceful stop.
/// </summary>
/// <param name="remainingArguments">Positional command line arguments;
/// the first one is the dataset path and is required.</param>
/// <returns>Process exit code: 0 on success.</returns>
public override int Run(string[] remainingArguments) {
    this.CheckRequiredArguments();
    if (remainingArguments.Length < 1) {
        throw new ArgumentNullException("dataset");
    }

    string datasetName = remainingArguments[0];
    string checkpoint = Gpt2Trainer.ProcessCheckpointConfig(this.Checkpoint,
        modelName: this.ModelName, runName: this.RunName);
    var encoder = Gpt2Encoder.LoadEncoder(this.ModelName);

    string searchPattern = this.Include ?? "*";
    // BUGFIX: EndsWith without a comparison type is culture-sensitive (CA1310);
    // a file glob must be compared ordinally.
    // NOTE(review): this routes to CSV loading only when the glob itself ends in
    // "*.csv" — an explicit include like "data.csv" would not match; confirm intended.
    var dataset = searchPattern.EndsWith("*.csv", StringComparison.Ordinal)
        ? LoadCsv(encoder, root: datasetName, field: this.ColumnName)
        : Gpt2Trainer.LoadDataset(encoder, path: datasetName, pattern: searchPattern);

    var hParams = Gpt2Model.LoadHParams(this.ModelName);
    // fixed seed reproduces a training run; otherwise use a time-seeded generator
    var random = this.Seed == null ? new Random() : new Random(this.Seed.Value);

    // Ctrl+C cancels training instead of killing the process outright
    var stop = new CancellationTokenSource();
    Console.CancelKeyPress += delegate { stop.Cancel(); };

    new Gpt2Trainer(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random)
        .Train(checkpoint, this.RunName, stop.Token);

    return 0;
}
/// <summary>
/// Trains a GPT-2 model on the dataset given as the first positional argument,
/// periodically saving checkpoints and generated samples.
/// Blocks until the process receives Ctrl+C, which requests a graceful stop.
/// </summary>
/// <param name="remainingArguments">Positional command line arguments;
/// the first one is the dataset path and is required.</param>
/// <returns>Process exit code: 0 on success, -1 when the dataset is empty.</returns>
public override int Run(string[] remainingArguments) {
    this.CheckRequiredArguments();
    if (remainingArguments.Length < 1) {
        throw new ArgumentNullException("dataset");
    }

    string datasetName = remainingArguments[0];
    string checkpoint = Gpt2Checkpoints.ProcessCheckpointConfig(
        gpt2Root: Environment.CurrentDirectory,
        checkpoint: this.Checkpoint,
        modelName: this.ModelName,
        runName: this.RunName);
    var encoder = Gpt2Encoder.LoadEncoder(this.ModelName);

    string searchPattern = this.Include ?? "*";
    // BUGFIX: EndsWith without a comparison type is culture-sensitive (CA1310);
    // a file glob must be compared ordinally.
    // NOTE(review): this routes to CSV loading only when the glob itself ends in
    // "*.csv" — an explicit include like "data.csv" would not match; confirm intended.
    var dataset = searchPattern.EndsWith("*.csv", StringComparison.Ordinal)
        ? LoadCsv(encoder, root: datasetName, field: this.ColumnName)
        : Gpt2Dataset.LoadDataset(encoder, path: datasetName, pattern: searchPattern);
    if (dataset.Count == 0) {
        Console.Error.WriteLine("The dataset is empty!");
        return -1;
    }

    var hParams = Gpt2Model.LoadHParams(this.ModelName);
    // fixed seed reproduces a training run; otherwise use a time-seeded generator
    var random = this.Seed == null ? new Random() : new Random(this.Seed.Value);

    // Ctrl+C cancels training instead of killing the process outright
    var stop = new CancellationTokenSource();
    Console.CancelKeyPress += delegate { stop.Cancel(); };

    // let TensorFlow grow GPU memory on demand instead of reserving it all upfront
    dynamic config = config_pb2.ConfigProto();
    config.gpu_options.allow_growth = true;

    new Gpt2Trainer(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random) {
        SaveEvery = this.SaveEvery,
        SampleNum = this.SampleNum,
        SampleEvery = this.SampleEvery,
    }.Train(checkpoint, this.RunName,
        sessionConfig: config,
        // a "fresh" run starts its global step at 1; otherwise resume from the stored counter
        counter: checkpoint == "fresh" ? 1 : (int?)null,
        cancellation: stop.Token);

    return 0;
}
/// <summary>
/// Interactively run the model
/// </summary>
/// <param name="modelName">Which model to use</param>
/// <param name="checkpoint">Which checkpoint to load</param>
/// <param name="seed">Seed for random number generators, fix seed to reproduce results</param>
/// <param name="sampleCount">Number of samples to return total</param>
/// <param name="batchSize">Number of batches (only affects speed/memory). Must divide sampleCount.</param>
/// <param name="length">Number of tokens in generated text, if null (default), is
/// determined by model hyperparameters</param>
/// <param name="temperature">Randomness in Boltzmann distribution.
/// Lower temperature results in less random completions. As the
/// temperature approaches zero, the model will become deterministic and
/// repetitive. Higher temperature results in more random completions.</param>
/// <param name="topK">Controls diversity. 1 means only 1 word is
/// considered for each step (token), resulting in deterministic completions,
/// while 40 means 40 words are considered at each step. 0 (default) is a
/// special setting meaning no restrictions. 40 generally is a good value.
/// </param>
/// <exception cref="ArgumentException"><paramref name="sampleCount"/> is not divisible
/// by <paramref name="batchSize"/>, or <paramref name="length"/> exceeds
/// the model's context window.</exception>
public static void Run(string modelName = "117M", string checkpoint = null, int? seed = null,
    int sampleCount = 1, int batchSize = 1, int? length = null,
    float temperature = 1, int topK = 0) {
    if (sampleCount % batchSize != 0) {
        // BUGFIX: was a parameterless ArgumentException, giving the caller no hint what was wrong
        throw new ArgumentException(
            $"{nameof(sampleCount)} must be divisible by {nameof(batchSize)}",
            nameof(sampleCount));
    }

    var encoder = Gpt2Encoder.LoadEncoder(modelName);
    var hParams = Gpt2Model.LoadHParams(modelName);
    int nCtx = ((dynamic)hParams).n_ctx;
    if (length is null) {
        // default to the model's attention window
        length = nCtx;
    } else if (length > nCtx) {
        throw new ArgumentException("Can't get samples longer than window size: " + hParams.get("n_ctx"));
    }

    new Session(graph: new Graph()).UseSelf(sess => {
        var context = tf.placeholder(tf.int32, new TensorShape(batchSize, null));
        tf.set_random_seed(seed);

        var output = Gpt2Sampler.SampleSequence(
            hParams: hParams,
            length: length.Value,
            context: context,
            batchSize: batchSize,
            temperature: temperature,
            topK: topK);

        var saver = new Saver();
        // when no checkpoint is given, pick up the latest one for this model
        checkpoint = checkpoint ?? tf.train.latest_checkpoint(Path.Combine("models", modelName));
        saver.restore(sess, checkpoint);

        // interactive loop: prompt, sample, print, repeat until the process is terminated
        while (true) {
            string text;
            do {
                Console.Write("Model prompt >>> ");
                text = Console.ReadLine();
                if (string.IsNullOrEmpty(text)) {
                    Console.WriteLine("Prompt should not be empty");
                }
            } while (string.IsNullOrEmpty(text));

            var contextTokens = encoder.Encode(text);
            int generated = 0;
            foreach (var _ in Enumerable.Range(0, sampleCount / batchSize)) {
                // drop the prompt tokens from the output, keeping only the continuation
                var @out = sess.run(output, feed_dict: new PythonDict<object, object> {
                    [context] = Enumerable.Repeat(contextTokens, batchSize),
                })[Range.All, Range.StartAt(contextTokens.Count)];
                foreach (int i in Enumerable.Range(0, batchSize)) {
                    generated++;
                    ndarray part = @out[i];
                    text = encoder.Decode(part);
                    Console.WriteLine($"{Delimiter} SAMPLE {generated} {Delimiter}");
                    Console.WriteLine(text);
                }
            }
            Console.Write(Delimiter);
            Console.WriteLine(Delimiter);
        }
    });
}
/// <summary>
/// Runs the training loop until <paramref name="cancellation"/> is requested,
/// saving checkpoints every <c>SaveEvery</c> steps and samples every <c>SampleEvery</c> steps.
/// A final checkpoint is always saved on exit.
/// </summary>
/// <param name="checkpoint">Checkpoint to restore model weights from.</param>
/// <param name="run">Name of this training run; used for checkpoint and sample subdirectories.</param>
/// <param name="counter">Global step to start from; when null, resumes from the run's
/// stored "counter" file if present, otherwise starts at 1.</param>
/// <param name="sessionConfig">Optional TensorFlow session configuration.</param>
/// <param name="cancellation">Requests a graceful stop of the training loop.</param>
public void Train(string checkpoint, string run, int? counter, dynamic sessionConfig = null, CancellationToken cancellation = default) {
    // BUGFIX: the condition was inverted — a non-null sessionConfig was silently
    // dropped (plain `new Session()`), while a null one was passed to Session.NewDyn.
    Session sess = sessionConfig == null
        ? new Session()
        : Session.NewDyn(config: sessionConfig);
    sess.UseSelf(session => {
        var context = tf.placeholder(tf.int32, new TensorShape(this.batchSize, null));
        var output = Gpt2Model.Model(this.hParams, input: context);

        // next-token objective: labels are inputs shifted left by one,
        // and the logits for the final position (which has no label) are dropped
        Tensor labels = context[Range.All, Range.StartAt(1)];
        Tensor logits = output["logits"][Range.All, Range.EndAt(new Index(1, fromEnd: true))];
        var loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits_dyn(
                labels: labels,
                logits: logits));

        var sample = Gpt2Sampler.SampleSequence(
            this.hParams,
            length: this.sampleLength,
            context: context,
            batchSize: this.batchSize,
            temperature: 1.0f,
            topK: 40);

        // only optimize the model's own variables
        var trainVars = tf.trainable_variables().Where((dynamic var) => var.name.Contains("model"));
        var optimizer = new AdamOptimizer(learning_rate: 0.0002).minimize(loss, var_list: trainVars);

        var saver = new Saver(
            var_list: trainVars,
            max_to_keep: 5,
            keep_checkpoint_every_n_hours: 1);
        session.run(tf.global_variables_initializer());

        Console.WriteLine("Loading checkpoint " + checkpoint);
        saver.restore(session, checkpoint);

        Console.WriteLine("Loading dataset...");
        var sampler = new TrainingSampler(this.dataset, this.random);
        Console.WriteLine($"Dataset has {sampler.TokenCount} tokens");

        // resume the global step from the previous run of the same name, unless given explicitly
        string counterFile = Path.Combine(Gpt2Checkpoints.CheckpointDir, run, "counter");
        if (counter == null && File.Exists(counterFile)) {
            counter = int.Parse(File.ReadAllText(counterFile), CultureInfo.InvariantCulture) + 1;
        }
        counter = counter ?? 1;

        string runCheckpointDir = Path.Combine(Gpt2Checkpoints.CheckpointDir, run);
        string runSampleDir = Path.Combine(SampleDir, run);

        // persists the model weights and the current global step
        void Save() {
            Directory.CreateDirectory(runCheckpointDir);
            Console.WriteLine("Saving " + Path.Combine(runCheckpointDir, Invariant($"model-{counter}")));
            saver.save(session, Path.Combine(runCheckpointDir, "model"), global_step: counter.Value);
            File.WriteAllText(path: counterFile, contents: Invariant($"{counter}"));
        }

        // generates SampleNum unconditional samples and writes them to the run's sample dir
        void GenerateSamples() {
            var contextTokens = np.array(new[] { this.encoder.EncodedEndOfText });
            var allText = new List<string>();
            int index = 0;
            string text = null;
            while (index < this.SampleNum) {
                var @out = session.run(sample, feed_dict: new PythonDict<object, object> {
                    [context] = Enumerable.Repeat(contextTokens, this.batchSize),
                });
                foreach (int i in Enumerable.Range(0, Math.Min(this.SampleNum - index, this.batchSize))) {
                    text = this.encoder.Decode(@out[i]);
                    text = Invariant($"======== SAMPLE {index + 1} ========\n{text}\n");
                    allText.Add(text);
                    index++;
                }
            }
            Console.WriteLine(text);
            Directory.CreateDirectory(runSampleDir);
            File.WriteAllLines(
                path: Path.Combine(runSampleDir, Invariant($"samples-{counter}")),
                contents: allText);
        }

        // exponentially decayed running average of the loss (numerator, weight)
        var avgLoss = (0.0, 0.0);
        var startTime = DateTime.Now;

        while (!cancellation.IsCancellationRequested) {
            if (counter % this.SaveEvery == 0) {
                Save();
            }
            if (counter % this.SampleEvery == 0) {
                GenerateSamples();
            }

            // NOTE(review): the per-step sample window is hard-coded to 1024 tokens
            // rather than derived from the model hyperparameters — confirm intended.
            var batch = Enumerable.Range(0, this.batchSize)
                .Select(_ => sampler.Sample(1024))
                .ToArray();
            var placeholderValues = new PythonDict<object, object> {
                [context] = batch,
            };
            var tuple = session.run_dyn((optimizer, loss), feed_dict: placeholderValues);
            var lv = tuple.Item2;

            avgLoss = (avgLoss.Item1 * 0.99 + lv, avgLoss.Item2 * 0.99 + 1);
            Console.WriteLine($"[{counter} | {DateTime.Now-startTime}] loss={lv} avg={avgLoss.Item1/avgLoss.Item2}");

            counter++;
        }
        Console.WriteLine("Interrupted");
        Save();
    });
}
/// <summary>
/// Builds a TensorFlow graph node that autoregressively samples tokens from the model,
/// returning the context followed by the sampled continuation.
/// </summary>
/// <param name="hParams">Model hyperparameters.</param>
/// <param name="length">Number of tokens to sample (drives the loop's maximum_iterations).</param>
/// <param name="startToken">Mutually exclusive with <paramref name="context"/>.</param>
/// <param name="batchSize">Number of sequences sampled in parallel; null leaves it dynamic.</param>
/// <param name="context">Seed tokens tensor; mutually exclusive with <paramref name="startToken"/>.</param>
/// <param name="temperature">Divisor applied to logits before sampling.</param>
/// <param name="topK">Passed to TopLogits to restrict sampling to the top K tokens.</param>
/// <exception cref="ArgumentException">Both or neither of <paramref name="startToken"/>
/// and <paramref name="context"/> were specified.</exception>
// NOTE(review): startToken is validated but never consumed in this body — only context
// is used below; confirm whether startToken support is implemented elsewhere or missing.
public static Tensor SampleSequence(HParams hParams, int length, string startToken = null,
    int? batchSize = null, dynamic context = null, float temperature = 1, int topK = 0) {
    // exactly-one-of check: XOR is false when both or neither argument is null
    if (((startToken == null) ^ (context == null)) == false) {
        throw new ArgumentException($"Exactly one of {nameof(startToken)} or {nameof(context)} has to be specified");
    }

    // Runs the transformer over `tokens`, reusing cached attention state in `past`.
    // Returns the vocabulary logits and the updated attention state ("presents").
    SortedDictionary<string, dynamic> Step(HParams @params, Tokens tokens, dynamic past = null) {
        var lmOutput = Gpt2Model.Model(hParams: @params, input: tokens, past: past, reuse: _ReuseMode.AUTO_REUSE);

        // clip the logits to the vocabulary size
        var logits = lmOutput["logits"][Range.All, Range.All, Range.EndAt((int)@params.get("n_vocab"))];
        Tensor presents = lmOutput["present"];
        int?[] pastShape = Gpt2Model.PastShape(hParams: @params, batchSize: batchSize);
        presents.set_shape_(pastShape.Cast<object>());
        return (new SortedDictionary<string, object> {
            ["logits"] = logits,
            ["presents"] = presents,
        });
    }

    Tensor result = null;
    new name_scope("sample_sequence").Use(_ => {
        // Don't feed the last context token -- leave that to the loop below
        // TODO: Would be slightly faster if we called step on the entire context,
        // rather than leaving the last token transformer calculation to the while loop.
        var contextOutput = Step(hParams, context[Range.All, Range.EndAt(new Index(1, fromEnd: true))]);

        // One sampling step for tf.while_loop: takes (attention cache, previous token,
        // output so far) and returns the three updated loop variables.
        Tensor[] Body(object past, dynamic prev, object output) {
            var nextOutputs = Step(hParams, prev[Range.All, tf.newaxis], past: past);
            // logits of the last position only, scaled by temperature
            Tensor logits = nextOutputs["logits"][Range.All, -1, Range.All] / tf.to_float(temperature);
            logits = TopLogits(logits, topK: topK);
            var samples = tf.multinomial_dyn(logits, num_samples: 1, output_dtype: tf.int32);
            return (new Tensor[] {
                // grow the attention cache along the sequence axis
                tf.concat(new [] { past, nextOutputs["presents"] }, axis: -2),
                tf.squeeze(samples, axis: new[] { 1 }),
                // append the freshly sampled token to the running output
                tf.concat(new [] { output, samples }, axis: 1),
            });
        }

        // loop termination is driven by maximum_iterations, so the condition is constant true
        bool True(object _a, object _b, object _c) => true;

        dynamic[] loopVars = new[] {
            contextOutput["presents"],
            context[Range.All, -1],
            context,
        };
        // the attention cache and output sequence grow each iteration,
        // hence the partially unknown shape invariants
        TensorShape[] shapeInvariants = new[] {
            new TensorShape(Gpt2Model.PastShape(hParams: hParams, batchSize: batchSize)),
            new TensorShape(batchSize),
            new TensorShape((int?)batchSize, (int?)null),
        };

        result = tf.while_loop(
            cond: PythonFunctionContainer.Of<object, object, object, bool>(True),
            body: PythonFunctionContainer.Of(new Func<object, object, object, Tensor[]>(Body)),
            parallel_iterations: 10,
            swap_memory: false,
            name: null,
            maximum_iterations: tf.constant(length),
            loop_vars: loopVars,
            shape_invariants: shapeInvariants,
            back_prop: false)
            [2]; // loop variable #2 is the output token sequence
    });
    return (result);
}