public override int Run(string[] remainingArguments) {
    this.CheckRequiredArguments();
    if (remainingArguments.Length < 1) { throw new ArgumentNullException("dataset"); }

    // resolve the model directory and the checkpoint to resume from
    string modelPath = CommonCommandOptions.ExpandModelNameToPathOrExit(this.ModelName);
    string checkpoint = Gpt2Checkpoints.ProcessCheckpointConfig(
        modelPath: modelPath,
        checkpoint: this.Checkpoint,
        runName: this.RunName);
    var encoder = Gpt2Encoder.LoadEncoder(modelPath);

    // load the training corpus: either a single column of a .csv file,
    // or plain text files matching the search pattern
    string searchPattern = this.Include ?? "*";
    string datasetName = remainingArguments[0];
    var dataset = searchPattern.EndsWith("*.csv")
        ? LoadCsv(encoder, root: datasetName,
                  field: this.ColumnName
                         ?? throw new ArgumentException("column must be specified for training on .csv files"))
        : Gpt2Dataset.LoadDataset(encoder, path: datasetName, pattern: searchPattern);
    if (dataset.Count == 0) {
        Console.Error.WriteLine("The dataset is empty!");
        return -1;
    }

    var hParams = Gpt2Model.LoadHParams(modelPath);

    var random = this.Seed is null ? new Random() : new Random(this.Seed.Value);
    tf.random.set_seed(this.Seed);

    // Ctrl+C requests a graceful stop of fine-tuning instead of killing the process
    var stop = new CancellationTokenSource();
    Console.CancelKeyPress += delegate { stop.Cancel(); };

    // let TensorFlow grow GPU memory on demand instead of reserving it all upfront
    dynamic config = config_pb2.ConfigProto.CreateInstance();
    config.gpu_options.allow_growth = true;

    var trainer = new Gpt2TunerLegacy(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random) {
        SaveEvery = this.SaveEvery,
        SampleNum = this.SampleNum,
        SampleEvery = this.SampleEvery,
    };
    string checkpointOutputDirectory = Path.Combine(modelPath, Gpt2Checkpoints.CheckpointDir);
    trainer.FineTune(
        checkpointsDir: checkpointOutputDirectory,
        checkpoint: checkpoint,
        run: this.RunName,
        counter: checkpoint == "fresh" ? 1 : (int?)null,
        sessionConfig: config,
        cancellation: stop.Token);
    return 0;
}
/// <summary>
/// Interactively run the model
/// </summary>
/// <param name="modelName">Which model to use</param>
/// <param name="checkpoint">Which checkpoint to load</param>
/// <param name="seed">Seed for random number generators; fix the seed to reproduce results</param>
/// <param name="sampleCount">Number of samples to return in total</param>
/// <param name="batchSize">Number of batches (only affects speed/memory). Must divide <paramref name="sampleCount"/>.</param>
/// <param name="length">Number of tokens in generated text; if null (default), it is
/// determined by the model hyperparameters</param>
/// <param name="temperature">Randomness of the Boltzmann distribution.
/// Lower temperature results in less random completions. As the
/// temperature approaches zero, the model becomes deterministic and
/// repetitive. Higher temperature results in more random completions.</param>
/// <param name="topK">Controls diversity. 1 means only 1 word is
/// considered at each step (token), resulting in deterministic completions,
/// while 40 means 40 words are considered at each step. 0 (default) is a
/// special setting meaning no restrictions. 40 is generally a good value.</param>
/// <remarks>An illustrative sketch of how <paramref name="temperature"/> and
/// <paramref name="topK"/> shape token selection follows this method.</remarks>
public static int Run(string modelName = "117M", string? checkpoint = null, int? seed = null,
                      int sampleCount = 1, int batchSize = 1, int? length = null,
                      float temperature = 1, int topK = 0) {
    if (sampleCount % batchSize != 0)
        throw new ArgumentException("sampleCount must be divisible by batchSize");

    string modelPath = CommonCommandOptions.ExpandModelNameToPathOrExit(modelName);

    var encoder = Gpt2Encoder.LoadEncoder(modelPath);
    var hParams = Gpt2Model.LoadHParams(modelPath);
    int nCtx = hParams.ContextTokens;
    if (length is null)
        length = nCtx;
    else if (length > nCtx)
        throw new ArgumentException("Can't get samples longer than window size: " + nCtx);

    // let TensorFlow grow GPU memory on demand instead of reserving it all upfront
    foreach (var gpu in tf.config.list_physical_devices("GPU"))
        tf.config.experimental.set_memory_growth(gpu, true);

    var sess = new Session(graph: new Graph());
    using var sessionContext = sess.StartUsing();

    Tensor context = v1.placeholder(tf.int32, new TensorShape(batchSize, null));
    tf.random.set_seed(seed);
    Tensor output = Gpt2Sampler.SampleSequence(
        hParams: hParams,
        length: length.Value,
        context: context,
        batchSize: batchSize,
        temperature: temperature,
        topK: topK);

    var saver = new Saver();
    checkpoint ??= tf.train.latest_checkpoint(modelPath);
    saver.restore(sess, checkpoint);

    // Ctrl+C ends the interactive loop instead of killing the process
    bool interrupted = false;
    Console.CancelKeyPress += (object sender, ConsoleCancelEventArgs args)
        => Volatile.Write(ref interrupted, args.Cancel = true);

    while (!interrupted) {
        string text;
        do {
            Console.Write("Model prompt >>> ");
            text = Console.ReadLine();
            if (Volatile.Read(ref interrupted)) break;
            if (string.IsNullOrEmpty(text))
                Console.WriteLine("Prompt should not be empty");
        } while (!Volatile.Read(ref interrupted) && string.IsNullOrEmpty(text));
        if (Volatile.Read(ref interrupted)) break;

        var contextTokens = encoder.Encode(text);
        if (!tf.test.is_gpu_available() && contextTokens.Count >= length.Value) {
            Console.Error.WriteLine();
            Console.Error.WriteLine("Prompt is too long.");
            Console.Error.WriteLine();
            continue;
        }

        int generated = 0;
        foreach (int _ in Enumerable.Range(0, sampleCount / batchSize)) {
            ndarray<int> @out;
            try {
                // feed the same prompt to every batch element, then drop the prompt tokens from the output
                @out = sess.run(output, feed_dict: new Dictionary<object, object> {
                    [context] = Enumerable.Repeat(contextTokens, batchSize).ToArray(),
                })[.., contextTokens.Count..];
            } catch (InvalidArgumentError ex) {
                throw new ArgumentOutOfRangeException(
                    "Unable to generate sequence of desired length. "
                    + "Try lowering length by passing -l (-sample-length) parameter. "
                    + "Current length: " + length.Value,
                    innerException: ex);
            }

            foreach (int i in Enumerable.Range(0, batchSize)) {
                generated++;
                var part = @out[i].AsArray();
                text = encoder.Decode(part);
                Console.WriteLine($"{Delimiter} SAMPLE {generated} {Delimiter}");
                Console.WriteLine(text);
            }
        }
        Console.Write(Delimiter);
        Console.WriteLine(Delimiter);
    }
    return 0;
}