/// <summary>
/// Loads every *.csv file under <paramref name="root"/> (recursively), extracts the
/// named column from each data row, and encodes the collected texts into a DataSet.
/// </summary>
/// <param name="encoder">Tokenizer used by <c>Load</c> to encode the collected texts.</param>
/// <param name="root">Directory searched recursively for CSV files.</param>
/// <param name="field">Header name of the column to extract from each row.</param>
static DataSet LoadCsv(Gpt2Encoder encoder, string root, string field) {
    var texts = new List<string>();
    foreach (string csvFile in Directory.EnumerateFiles(root, "*.csv", SearchOption.AllDirectories)) {
        var configuration = new CsvHelper.Configuration.Configuration {
            Delimiter = ",",
            HasHeaderRecord = true,
        };
        // CsvReader takes ownership of the StreamReader and disposes it with the reader.
        using (var csv = new CsvHelper.CsvReader(new StreamReader(csvFile, Encoding.UTF8), configuration)) {
            // Advance onto the header row and parse it so fields can be fetched by name.
            csv.Read();
            csv.ReadHeader();
            while (csv.Read()) {
                string value = csv.GetField(field);
                // Debug-only sanity check on the first column's length (original invariant;
                // NOTE(review): meaning of the 300 limit is not evident from this file).
                System.Diagnostics.Debug.Assert(csv.GetField(0).Length < 300);
                if (!string.IsNullOrWhiteSpace(value))
                    texts.Add(value);
            }
        }
    }
    return Load(encoder, texts);
}
/// <summary>
/// CLI entry point: loads the dataset named by the first positional argument,
/// resolves the checkpoint, and runs GPT-2 fine-tuning until Ctrl+C.
/// </summary>
/// <param name="remainingArguments">Positional CLI arguments; [0] is the dataset path.</param>
/// <returns>0 on success; -1 when the dataset is empty.</returns>
public override int Run(string[] remainingArguments) {
    this.CheckRequiredArguments();
    if (remainingArguments.Length < 1)
        throw new ArgumentNullException("dataset");

    string datasetName = remainingArguments[0];
    string checkpoint = Gpt2Trainer.ProcessCheckpointConfig(this.Checkpoint,
        modelName: this.ModelName,
        runName: this.RunName);

    var encoder = Gpt2Encoder.LoadEncoder(this.ModelName);
    string searchPattern = this.Include ?? "*";
    // CSV inputs are handled by the column-extracting loader; everything else
    // goes through the generic file-based loader.
    var dataset = searchPattern.EndsWith("*.csv")
        ? LoadCsv(encoder, root: datasetName, field: this.ColumnName)
        : Gpt2Trainer.LoadDataset(encoder, path: datasetName, pattern: searchPattern);
    // Fix: fail fast on an empty dataset (consistent with the other Run overload)
    // instead of starting a training loop that has nothing to sample from.
    if (dataset.Count == 0) {
        Console.Error.WriteLine("The dataset is empty!");
        return -1;
    }

    var hParams = Gpt2Model.LoadHParams(this.ModelName);
    var random = this.Seed == null ? new Random() : new Random(this.Seed.Value);

    // Ctrl+C requests a graceful stop through the cancellation token.
    var stop = new CancellationTokenSource();
    Console.CancelKeyPress += delegate { stop.Cancel(); };

    new Gpt2Trainer(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random)
        .Train(checkpoint, this.RunName, stop.Token);

    return 0;
}
/// <summary>
/// Creates a trainer over the given pre-encoded dataset.
/// </summary>
/// <param name="dataset">Encoded token chunks to sample training batches from.</param>
/// <param name="encoder">Tokenizer matching the model.</param>
/// <param name="hParams">Model hyperparameters.</param>
/// <param name="batchSize">Number of samples per training step.</param>
/// <param name="sampleLength">Token length of each training sample.</param>
/// <param name="random">Source of randomness for batch sampling.</param>
public Gpt2Trainer(DataSet dataset, Gpt2Encoder encoder, HParams hParams,
                   int batchSize, int sampleLength, Random random) {
    // Guard clauses in the same order as the assignments so the thrown
    // ArgumentNullException identifies the first offending argument.
    if (dataset is null) throw new ArgumentNullException(nameof(dataset));
    this.dataset = dataset;
    if (encoder is null) throw new ArgumentNullException(nameof(encoder));
    this.encoder = encoder;
    if (hParams is null) throw new ArgumentNullException(nameof(hParams));
    this.hParams = hParams;
    this.batchSize = batchSize;
    this.sampleLength = sampleLength;
    if (random is null) throw new ArgumentNullException(nameof(random));
    this.random = random;
}
/// <summary>
/// CLI entry point: loads the dataset named by the first positional argument,
/// resolves the checkpoint, configures the TF session, and fine-tunes GPT-2
/// until Ctrl+C is pressed.
/// </summary>
/// <param name="remainingArguments">Positional CLI arguments; [0] is the dataset path.</param>
/// <returns>0 on success; -1 when the dataset is empty.</returns>
public override int Run(string[] remainingArguments) {
    this.CheckRequiredArguments();
    if (remainingArguments.Length < 1)
        throw new ArgumentNullException("dataset");

    string datasetPath = remainingArguments[0];
    string checkpoint = Gpt2Checkpoints.ProcessCheckpointConfig(
        gpt2Root: Environment.CurrentDirectory,
        checkpoint: this.Checkpoint,
        modelName: this.ModelName,
        runName: this.RunName);

    var encoder = Gpt2Encoder.LoadEncoder(this.ModelName);
    string filePattern = this.Include ?? "*";
    // CSV inputs go through the column-extracting loader; anything else
    // through the generic file loader.
    var dataset = filePattern.EndsWith("*.csv")
        ? LoadCsv(encoder, root: datasetPath, field: this.ColumnName)
        : Gpt2Dataset.LoadDataset(encoder, path: datasetPath, pattern: filePattern);
    if (dataset.Count == 0) {
        Console.Error.WriteLine("The dataset is empty!");
        return -1;
    }

    var hParams = Gpt2Model.LoadHParams(this.ModelName);
    var random = this.Seed == null ? new Random() : new Random(this.Seed.Value);

    // Ctrl+C requests a graceful stop through the cancellation token.
    var stop = new CancellationTokenSource();
    Console.CancelKeyPress += delegate { stop.Cancel(); };

    // Let TensorFlow grow GPU memory on demand instead of grabbing it all upfront.
    dynamic config = config_pb2.ConfigProto();
    config.gpu_options.allow_growth = true;

    var trainer = new Gpt2Trainer(dataset, encoder, hParams,
                                  this.BatchSize, this.SampleLength, random) {
        SaveEvery = this.SaveEvery,
        SampleNum = this.SampleNum,
        SampleEvery = this.SampleEvery,
    };
    // A "fresh" checkpoint restarts the step counter at 1; otherwise the
    // counter is restored from the checkpoint.
    trainer.Train(checkpoint, this.RunName,
                  sessionConfig: config,
                  counter: checkpoint == "fresh" ? 1 : (int?)null,
                  cancellation: stop.Token);

    return 0;
}
/// <summary>
/// Resolves <paramref name="path"/> to a list of input files — all files matching
/// <paramref name="pattern"/> (recursively) when it is a directory, or the path
/// itself otherwise — and loads them as encoded token chunks.
/// </summary>
/// <param name="encoder">Tokenizer used to encode plain-text files.</param>
/// <param name="path">A directory to search, or a single file path.</param>
/// <param name="pattern">File search pattern applied when <paramref name="path"/> is a directory.</param>
internal static List<ndarray> LoadDataset(Gpt2Encoder encoder, string path, string pattern = "*") {
    if (string.IsNullOrEmpty(path))
        throw new ArgumentNullException(nameof(path));

    var files = new List<string>();
    if (Directory.Exists(path))
        files.AddRange(Directory.EnumerateFiles(path, searchPattern: pattern, SearchOption.AllDirectories));
    else
        files.Add(path);

    return LoadDataset(encoder, files);
}
/// <summary>
/// Encodes the given texts into token chunks of roughly <c>TrimAfter</c> characters,
/// separating texts with the encoder's end-of-text token.
/// </summary>
/// <param name="encoder">Tokenizer used to encode each text.</param>
/// <param name="texts">Texts to encode; null/whitespace entries are skipped.</param>
static DataSet Load(Gpt2Encoder encoder, IEnumerable<string> texts) {
    dynamic numpy = Py.Import("numpy");
    var result = new DataSet();
    string encodedEndOfText = encoder.EncodedEndOfText;
    var chunk = new List<string>();
    int chunkSize = 0;
    // Finalizes the accumulated tokens into one ndarray chunk.
    void AddChunk() {
        // Fix: guard the empty case — numpy.stack on an empty sequence raises,
        // and the first text can trigger a flush before anything is accumulated.
        if (chunk.Count == 0)
            return;
        PyObject tokens = numpy.stack(chunk);
        chunk.Clear();
        chunkSize = 0;
        result.Add(ndarray.Wrap(tokens));
    }
    foreach (string text in texts) {
        if (string.IsNullOrWhiteSpace(text))
            continue;
        // Fix: the original flushed the chunk in an if/else and silently DROPPED
        // the text that crossed the TrimAfter threshold (one text lost per chunk).
        // Now we flush first and still append the current text to the new chunk.
        if (chunkSize + text.Length + encodedEndOfText.Length >= TrimAfter)
            AddChunk();
        chunkSize += text.Length + encodedEndOfText.Length;
        var encoded = encoder.Encode(text);
        chunk.AddRange(encoded);
        chunk.Add(encodedEndOfText);
    }
    // Flush any remaining partial chunk.
    if (chunk.Count > 0)
        AddChunk();
    return result;
}
/// <summary>
/// Loads the given files as token chunks: <c>.npz</c> files are treated as
/// pre-encoded numpy archives and every array inside is taken as a chunk;
/// any other file is read as text, tokenized with <paramref name="encoder"/>,
/// and stacked into a single chunk.
/// </summary>
/// <param name="encoder">Tokenizer used for plain-text files.</param>
/// <param name="fileNames">Paths of the files to load.</param>
/// <returns>One ndarray of tokens per non-empty file / archive entry.</returns>
internal static List <ndarray> LoadDataset(Gpt2Encoder encoder, List <string> fileNames) {
    if (encoder == null) { throw new ArgumentNullException(nameof(encoder)); }
    var tokenChunks = new List <ndarray>();
    foreach (string file in fileNames) {
        Debug.WriteLine($"Reading {file}");
        if (Path.GetExtension(file) == ".npz") {
            // pre-encoded: drive numpy's NpzFile context-manager protocol by hand.
            // NOTE(review): __exit__ is not in a try/finally, so an exception while
            // reading leaves the archive open; also CPython's __exit__ normally takes
            // three arguments — confirm the interop layer tolerates the no-arg call.
            dynamic npzObject = np.load(file);
            var npz = npzObject.__enter__();
            // Each named array in the archive becomes its own token chunk.
            foreach (var item in npz.files) {
                tokenChunks.Add(npz[item]);
            }
            npzObject.__exit__();
        } else {
            string rawText = File.ReadAllText(file);
            // Skip files with no usable text.
            if (String.IsNullOrWhiteSpace(rawText)) {
                continue;
            }
            dynamic numpy = Py.Import("numpy");
            // Encode the whole file and stack the tokens into one ndarray chunk.
            PyObject tokens = numpy.stack(encoder.Encode(rawText));
            tokenChunks.Add(ndarray.Wrap(tokens));
        }
    }
    return(tokenChunks);
}
/// <summary>
/// Interactively run the model
/// </summary>
/// <param name="modelName">Which model to use</param>
/// <param name="checkpoint">Which checkpoint to load</param>
/// <param name="seed">Seed for random number generators, fix seed to reproduce results</param>
/// <param name="sampleCount">Number of samples to return total</param>
/// <param name="batchSize">Number of batches (only affects speed/memory). Must divide sampleCount.</param>
/// <param name="length">Number of tokens in generated text, if null (default), is
/// determined by model hyperparameters</param>
/// <param name="temperature">randomness in boltzmann distribution.
/// Lower temperature results in less random completions. As the
/// temperature approaches zero, the model will become deterministic and
/// repetitive. Higher temperature results in more random completions.</param>
/// <param name="topK">Controls diversity. 1 means only 1 word is
/// considered for each step (token), resulting in deterministic completions,
/// while 40 means 40 words are considered at each step. 0 (default) is a
/// special setting meaning no restrictions. 40 generally is a good value.
/// </param>
public static void Run(string modelName = "117M", string checkpoint = null,
                       int? seed = null, int sampleCount = 1, int batchSize = 1,
                       int? length = null, float temperature = 1, int topK = 0) {
    // Fix: the original threw a message-less ArgumentException here, giving the
    // caller no hint about which precondition failed.
    if (sampleCount % batchSize != 0)
        throw new ArgumentException($"{nameof(batchSize)} must divide {nameof(sampleCount)}");

    var encoder = Gpt2Encoder.LoadEncoder(modelName);
    var hParams = Gpt2Model.LoadHParams(modelName);
    // Context window size from the model's hyperparameters.
    int nCtx = ((dynamic)hParams).n_ctx;
    if (length is null)
        length = nCtx;
    else if (length > nCtx)
        throw new ArgumentException("Can't get samples longer than window size: " + hParams.get("n_ctx"));

    new Session(graph: new Graph()).UseSelf(sess => {
        // Placeholder for a batch of variable-length prompt token sequences.
        var context = tf.placeholder(tf.int32, new TensorShape(batchSize, null));
        tf.set_random_seed(seed);
        var output = Gpt2Sampler.SampleSequence(
            hParams: hParams,
            length: length.Value,
            context: context,
            batchSize: batchSize,
            temperature: temperature,
            topK: topK);

        var saver = new Saver();
        // Default to the newest checkpoint of the selected model.
        checkpoint = checkpoint ?? tf.train.latest_checkpoint(Path.Combine("models", modelName));
        saver.restore(sess, checkpoint);

        // Interactive prompt loop; runs until the process is terminated.
        while (true) {
            string text;
            do {
                Console.Write("Model prompt >>> ");
                text = Console.ReadLine();
                if (string.IsNullOrEmpty(text))
                    Console.WriteLine("Prompt should not be empty");
            } while (string.IsNullOrEmpty(text));

            var contextTokens = encoder.Encode(text);
            int generated = 0;
            foreach (var _ in Enumerable.Range(0, sampleCount / batchSize)) {
                // Run the sampler, then slice off the prompt tokens so only the
                // newly generated continuation remains.
                var @out = sess.run(output, feed_dict: new PythonDict<object, object> {
                    [context] = Enumerable.Repeat(contextTokens, batchSize),
                })[Range.All, Range.StartAt(contextTokens.Count)];
                foreach (int i in Enumerable.Range(0, batchSize)) {
                    generated++;
                    ndarray part = @out[i];
                    text = encoder.Decode(part);
                    Console.WriteLine($"{Delimiter} SAMPLE {generated} {Delimiter}");
                    Console.WriteLine(text);
                }
            }
            Console.Write(Delimiter);
            Console.WriteLine(Delimiter);
        }
    });
}