/// <summary>
/// Recursively loads all <c>*.csv</c> files under <paramref name="root"/> and collects the
/// non-blank values of the named column into a GPT-2 training dataset.
/// </summary>
/// <param name="encoder">Tokenizer used to encode the collected texts.</param>
/// <param name="root">Directory searched recursively for <c>*.csv</c> files.</param>
/// <param name="field">Header name of the column whose values become training texts.</param>
/// <returns>A dataset built from every non-empty value of <paramref name="field"/>.</returns>
static DataSet LoadCsv(Gpt2Encoder encoder, string root, string field) {
    var texts = new List<string>();
    var csvConfiguration = new CsvHelper.Configuration.Configuration {
        Delimiter = ",",
        HasHeaderRecord = true,
    };
    foreach (string file in Directory.EnumerateFiles(root, "*.csv", SearchOption.AllDirectories)) {
        // Dispose the StreamReader explicitly so the file handle is released even if the
        // CsvReader constructor throws. (CsvReader also disposes it; double-dispose is safe.)
        using var streamReader = new StreamReader(file, Encoding.UTF8);
        using var reader = new CsvHelper.CsvReader(streamReader, csvConfiguration);
        // Consume the header row so GetField(name) can resolve column names.
        reader.Read();
        reader.ReadHeader();
        while (reader.Read()) {
            string entry = reader.GetField(field);
            // Removed leftover Debug.Assert on field length: it crashed Debug builds
            // on any legitimate CSV whose first column reached 300 characters.
            if (!string.IsNullOrWhiteSpace(entry)) {
                texts.Add(entry);
            }
        }
    }
    return Gpt2Dataset.FromTexts(encoder, texts);
}
/// <summary>
/// Smoke test for fine-tuning: builds a tiny GPT-2 model, runs two consecutive
/// training batches, and verifies that the loss decreased between them.
/// </summary>
public void Tune() {
    var hyperparams = new GptHParams(
        embeddingDim: 16,
        attentionHeads: 2,
        encoderLayers: 2,
        contextTokens: 16,
        vocabularySize: TestEncoder.Count);

    var encoder = new Gpt2Encoder(TestEncoder, TestBPE);
    var dataset = Gpt2Dataset.FromTexts(encoder, new[] { EncoderJson });

    var session = new Session();
    using var _ = session.StartUsing();

    const int batchSize = 4;
    var input = tf.placeholder(tf.int32, new TensorShape(batchSize, null));
    var outputs = Gpt2Model.Model(hyperparams, input);
    var tuner = new Gpt2Tuner(hyperparams, session,
        inputPlaceholder: input,
        outputs,
        new GptTrainingSampler(dataset, new Random()),
        batchSize: batchSize);

    session.run(tf.global_variables_initializer());

    // Two steps on the same tiny dataset: the second should fit better.
    float initialLoss = tuner.FineTuneOnBatch();
    float improvedLoss = tuner.FineTuneOnBatch();
    Assert.True(improvedLoss < initialLoss);
}
/// <summary>
/// CLI entry point: fine-tunes a GPT-2 model on the dataset given as the first
/// remaining argument, writing checkpoints until cancelled (Ctrl+C) or finished.
/// </summary>
/// <param name="remainingArguments">Positional arguments; the first is the dataset path.</param>
/// <returns>0 on success, -1 when the dataset is empty.</returns>
public override int Run(string[] remainingArguments) {
    this.CheckRequiredArguments();
    if (remainingArguments.Length < 1) {
        throw new ArgumentNullException("dataset");
    }

    string modelPath = CommonCommandOptions.ExpandModelNameToPathOrExit(this.ModelName);
    string checkpoint = Gpt2Checkpoints.ProcessCheckpointConfig(
        modelPath: modelPath,
        checkpoint: this.Checkpoint,
        runName: this.RunName);
    var encoder = Gpt2Encoder.LoadEncoder(modelPath);

    string searchPattern = this.Include ?? "*";
    string datasetName = remainingArguments[0];
    // Fixed: the old check was EndsWith("*.csv"), which missed concrete file patterns
    // such as "reviews.csv" and routed them to the plain-text loader. Any pattern
    // ending in ".csv" (ordinal, case-insensitive) now uses the CSV loader.
    var dataset = searchPattern.EndsWith(".csv", StringComparison.OrdinalIgnoreCase)
        ? LoadCsv(encoder, root: datasetName,
            field: this.ColumnName
                ?? throw new ArgumentException("column must be specified for training on .csv files"))
        : Gpt2Dataset.LoadDataset(encoder, path: datasetName, pattern: searchPattern);
    if (dataset.Count == 0) {
        Console.Error.WriteLine("The dataset is empty!");
        return -1;
    }

    var hParams = Gpt2Model.LoadHParams(modelPath);

    var random = this.Seed is null ? new Random() : new Random(this.Seed.Value);
    tf.random.set_seed(this.Seed);

    // Allow graceful shutdown on Ctrl+C: training observes this token and stops.
    var stop = new CancellationTokenSource();
    Console.CancelKeyPress += delegate { stop.Cancel(); };

    dynamic config = config_pb2.ConfigProto.CreateInstance();
    config.gpu_options.allow_growth = true;

    var trainer = new Gpt2TunerLegacy(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random) {
        SaveEvery = this.SaveEvery,
        SampleNum = this.SampleNum,
        SampleEvery = this.SampleEvery,
    };
    string checkpointOutputDirectory = Path.Combine(modelPath, Gpt2Checkpoints.CheckpointDir);
    trainer.FineTune(
        checkpointsDir: checkpointOutputDirectory,
        checkpoint: checkpoint,
        run: this.RunName,
        // A "fresh" run starts its step counter at 1; otherwise resume from the checkpoint.
        counter: checkpoint == "fresh" ? 1 : (int?)null,
        sessionConfig: config,
        cancellation: stop.Token);
    return 0;
}