/// <summary>
/// Builds a GPT-2 graph with small hyperparameters, runs two fine-tuning steps,
/// and asserts that the loss reported by the second step is lower than the first.
/// </summary>
public void Tune()
{
    // Small hyperparameters keep graph construction and the two training steps cheap.
    var hyperparams = new GptHParams(
        embeddingDim: 16,
        attentionHeads: 2,
        encoderLayers: 2,
        contextTokens: 16,
        vocabularySize: TestEncoder.Count);
    var encoder = new Gpt2Encoder(TestEncoder, TestBPE);
    var dataset = Gpt2Dataset.FromTexts(encoder, new[] { EncoderJson });

    var session = new Session();
    using var _ = session.StartUsing();

    int batchSize = 4;
    var input = tf.placeholder(tf.int32, new TensorShape(batchSize, null));
    var outputs = Gpt2Model.Model(hyperparams, input);

    // A fixed seed makes the sampled training batches reproducible, so a failure of the
    // loss comparison below can be replayed instead of depending on an unseeded Random.
    const int samplerSeed = 1234;
    var sampler = new GptTrainingSampler(dataset, new Random(samplerSeed));

    var tuner = new Gpt2Tuner(hyperparams, session,
        inputPlaceholder: input, outputs,
        sampler,
        batchSize: batchSize);
    session.run(tf.global_variables_initializer());

    float loss0 = tuner.FineTuneOnBatch();
    float loss1 = tuner.FineTuneOnBatch();

    Assert.True(loss1 < loss0,
        $"Expected loss to decrease after a fine-tuning step, but it went from {loss0} to {loss1}.");
}
/// <summary>
/// Creates a legacy tuner from its collaborators and settings.
/// </summary>
/// <param name="dataset">Training data; must not be null.</param>
/// <param name="encoder">Token encoder; must not be null.</param>
/// <param name="hParams">Model hyperparameters; must not be null.</param>
/// <param name="batchSize">Stored as-is; not validated here.</param>
/// <param name="sampleLength">Stored as-is; not validated here.</param>
/// <param name="random">Random source; must not be null.</param>
/// <exception cref="ArgumentNullException">
/// Thrown for the first null among <paramref name="dataset"/>, <paramref name="encoder"/>,
/// <paramref name="hParams"/>, <paramref name="random"/> (checked in that order).
/// </exception>
public Gpt2TunerLegacy(DataSet dataset, Gpt2Encoder encoder, GptHParams hParams, int batchSize, int sampleLength, Random random)
{
    if (dataset is null)
    {
        throw new ArgumentNullException(nameof(dataset));
    }
    if (encoder is null)
    {
        throw new ArgumentNullException(nameof(encoder));
    }
    if (hParams is null)
    {
        throw new ArgumentNullException(nameof(hParams));
    }
    if (random is null)
    {
        throw new ArgumentNullException(nameof(random));
    }

    this.dataset = dataset;
    this.encoder = encoder;
    this.hParams = hParams;
    this.batchSize = batchSize;
    this.sampleLength = sampleLength;
    this.random = random;
}
public Gpt2Tuner(GptHParams hyperparams, ISession session, Tensor inputPlaceholder, Dictionary <string, Tensor> outputs, IGptTrainingSampleGenerator sampler, int batchSize, IOptimizer?optimizer = null) { this.Hyperparams = hyperparams ?? throw new ArgumentNullException(nameof(hyperparams)); this.session = session ?? throw new ArgumentNullException(nameof(session)); this.InputPlaceholder = inputPlaceholder ?? throw new ArgumentNullException(nameof(inputPlaceholder)); this.outputs = outputs ?? throw new ArgumentNullException(nameof(outputs)); this.Sampler = sampler ?? throw new ArgumentNullException(nameof(sampler)); this.BatchSize = batchSize; this.Optimizer = optimizer ?? new AdamOptimizer(learning_rate: 0.0002); this.ModelVariables = Enumerable.Where((PythonList <Variable>)v1.trainable_variables(), var => var.name.Contains("model")).ToArray(); Tensor labels = inputPlaceholder[.., 1..];