/// <summary>
/// Builds a dataset from one column of every CSV file found under a directory tree.
/// </summary>
/// <param name="encoder">Encoder handed to <c>Gpt2Dataset.FromTexts</c> together with the collected texts.</param>
/// <param name="root">Directory that is searched recursively for <c>*.csv</c> files.</param>
/// <param name="field">Header name of the column whose values are collected.</param>
/// <returns>The result of <c>Gpt2Dataset.FromTexts</c> over all non-blank values of the column.</returns>
static DataSet LoadCsv(Gpt2Encoder encoder, string root, string field)
{
    var csvConfiguration = new CsvHelper.Configuration.Configuration
    {
        Delimiter = ",",
        HasHeaderRecord = true,
    };

    var collected = new List<string>();
    IEnumerable<string> csvPaths = Directory.EnumerateFiles(root, "*.csv", SearchOption.AllDirectories);

    foreach (string csvPath in csvPaths)
    {
        // Each file is decoded as UTF-8; the reader is disposed before the next file is opened.
        using (var csv = new CsvHelper.CsvReader(new StreamReader(csvPath, Encoding.UTF8), csvConfiguration))
        {
            // Advance to the first row and register it as the header so that
            // columns can be looked up by name below.
            csv.Read();
            csv.ReadHeader();

            while (csv.Read())
            {
                string value = csv.GetField(field);

                // Debug-build sanity check on the first column of the record.
                System.Diagnostics.Debug.Assert(csv.GetField(0).Length < 300);

                // Blank cells contribute nothing to the dataset.
                if (string.IsNullOrWhiteSpace(value))
                {
                    continue;
                }

                collected.Add(value);
            }
        }
    }

    return Gpt2Dataset.FromTexts(encoder, collected);
}
// Smoke test for fine-tuning: builds a tiny model, runs two fine-tuning
// steps on a dataset made from a single text, and expects the loss reported
// by the second step to be lower than the loss reported by the first.
public void Tune()
{
    // Deliberately small hyperparameters; the vocabulary size follows the test encoder.
    var hParams = new GptHParams(
        embeddingDim: 16,
        attentionHeads: 2,
        encoderLayers: 2,
        contextTokens: 16,
        vocabularySize: TestEncoder.Count);

    var gptEncoder = new Gpt2Encoder(TestEncoder, TestBPE);
    var trainingData = Gpt2Dataset.FromTexts(gptEncoder, new[] { EncoderJson });

    var tfSession = new Session();
    using var sessionScope = tfSession.StartUsing();

    int batchSize = 4;

    // Placeholder for int32 input: first dimension is the batch size, second is left unspecified.
    var tokens = tf.placeholder(tf.int32, new TensorShape(batchSize, null));
    var modelOutputs = Gpt2Model.Model(hParams, tokens);

    // NOTE(review): the Random below is unseeded, so the sampled batches differ
    // between runs — consider a fixed seed if this test turns out to be flaky.
    var sampler = new GptTrainingSampler(trainingData, new Random());
    var fineTuner = new Gpt2Tuner(
        hParams,
        tfSession,
        inputPlaceholder: tokens,
        modelOutputs,
        sampler,
        batchSize: batchSize);

    // Variables must be initialized before the first training step runs.
    tfSession.run(tf.global_variables_initializer());

    float firstLoss = fineTuner.FineTuneOnBatch();
    float secondLoss = fineTuner.FineTuneOnBatch();

    Assert.True(secondLoss < firstLoss);
}