Пример #1
0
        /// <summary>
        /// Builds a small GPT-2 graph over the test vocabulary and runs two fine-tuning steps
        /// through <see cref="Gpt2Tuner"/>. It then asserts that the loss from the second step
        /// is lower than the loss from the first.
        /// </summary>
        public void Tune()
        {
            // Small hyperparameters keep the graph tiny.
            var hyperparams = new GptHParams(
                embeddingDim: 16,
                attentionHeads: 2,
                encoderLayers: 2,
                contextTokens: 16,
                vocabularySize: TestEncoder.Count);
            var encoder = new Gpt2Encoder(TestEncoder, TestBPE);
            var dataset = Gpt2Dataset.FromTexts(encoder, new[] { EncoderJson });

            var session = new Session();

            // Enter the session's scope for the rest of the test. The handle returned by
            // StartUsing() is disposed when the method returns.
            using var _ = session.StartUsing();

            const int batchSize = 4;

            // A fixed seed makes the sequence of sampled training batches reproducible, so a
            // failing run can be replayed.
            // NOTE(review): TensorFlow variable initialization is not seeded in this method,
            // so the initial weights may still differ between runs.
            const int samplerSeed = 0;

            var input   = tf.placeholder(tf.int32, new TensorShape(batchSize, null));
            var outputs = Gpt2Model.Model(hyperparams, input);
            var sampler = new GptTrainingSampler(dataset, new Random(samplerSeed));
            var tuner   = new Gpt2Tuner(hyperparams, session,
                                        inputPlaceholder: input,
                                        outputs,
                                        sampler,
                                        batchSize: batchSize);

            session.run(tf.global_variables_initializer());

            float loss0 = tuner.FineTuneOnBatch();
            float loss1 = tuner.FineTuneOnBatch();

            Assert.True(loss1 < loss0,
                $"Expected loss to decrease after a fine-tuning step, but it went from {loss0} to {loss1}.");
        }
Пример #2
0
 /// <summary>
 /// Initializes a legacy GPT-2 tuner by storing the supplied collaborators and settings.
 /// </summary>
 /// <param name="dataset">The dataset; must not be null.</param>
 /// <param name="encoder">The GPT-2 encoder; must not be null.</param>
 /// <param name="hParams">The model hyperparameters; must not be null.</param>
 /// <param name="batchSize">Stored as-is; not validated here.</param>
 /// <param name="sampleLength">Stored as-is; not validated here.</param>
 /// <param name="random">The random source; must not be null.</param>
 /// <exception cref="ArgumentNullException">
 /// <paramref name="dataset"/>, <paramref name="encoder"/>, <paramref name="hParams"/>
 /// or <paramref name="random"/> is null (checked in that order).
 /// </exception>
 public Gpt2TunerLegacy(DataSet dataset, Gpt2Encoder encoder, GptHParams hParams,
                        int batchSize, int sampleLength, Random random)
 {
     // Reject missing dependencies up front, in parameter order.
     if (dataset is null)
     {
         throw new ArgumentNullException(nameof(dataset));
     }
     if (encoder is null)
     {
         throw new ArgumentNullException(nameof(encoder));
     }
     if (hParams is null)
     {
         throw new ArgumentNullException(nameof(hParams));
     }
     if (random is null)
     {
         throw new ArgumentNullException(nameof(random));
     }

     this.dataset      = dataset;
     this.encoder      = encoder;
     this.hParams      = hParams;
     this.random       = random;
     this.batchSize    = batchSize;
     this.sampleLength = sampleLength;
 }
Пример #3
0
        public Gpt2Tuner(GptHParams hyperparams, ISession session,
                         Tensor inputPlaceholder, Dictionary <string, Tensor> outputs,
                         IGptTrainingSampleGenerator sampler,
                         int batchSize,
                         IOptimizer?optimizer = null)
        {
            this.Hyperparams      = hyperparams ?? throw new ArgumentNullException(nameof(hyperparams));
            this.session          = session ?? throw new ArgumentNullException(nameof(session));
            this.InputPlaceholder = inputPlaceholder ?? throw new ArgumentNullException(nameof(inputPlaceholder));
            this.outputs          = outputs ?? throw new ArgumentNullException(nameof(outputs));
            this.Sampler          = sampler ?? throw new ArgumentNullException(nameof(sampler));
            this.BatchSize        = batchSize;

            this.Optimizer = optimizer ?? new AdamOptimizer(learning_rate: 0.0002);

            this.ModelVariables = Enumerable.Where((PythonList <Variable>)v1.trainable_variables(), var => var.name.Contains("model")).ToArray();

            Tensor labels = inputPlaceholder[.., 1..];