示例#1
0
文件: GptTests.cs 项目: losttech/GPT
        /// <summary>
        /// Smoke test for <c>Gpt2Tuner</c>: builds a very small model, runs two
        /// fine-tuning steps and expects the second reported loss to be lower
        /// than the first.
        /// </summary>
        public void Tune()
        {
            int batchSize = 4;

            // Tiny hyperparameters keep graph construction and both steps cheap.
            var hParams = new GptHParams(
                embeddingDim: 16,
                attentionHeads: 2,
                encoderLayers: 2,
                contextTokens: 16,
                vocabularySize: TestEncoder.Count);

            var testEncoder  = new Gpt2Encoder(TestEncoder, TestBPE);
            var trainingData = Gpt2Dataset.FromTexts(testEncoder, new[] { EncoderJson });

            var session = new Session();

            // The session scope is opened before any graph node is created,
            // exactly as in the original ordering.
            using var sessionScope = session.StartUsing();

            var tokens      = tf.placeholder(tf.int32, new TensorShape(batchSize, null));
            var modelOutput = Gpt2Model.Model(hParams, tokens);
            var sampler     = new GptTrainingSampler(trainingData, new Random());
            var tuner       = new Gpt2Tuner(hParams, session,
                                            inputPlaceholder: tokens,
                                            modelOutput,
                                            sampler,
                                            batchSize: batchSize);

            session.run(tf.global_variables_initializer());

            float firstLoss  = tuner.FineTuneOnBatch();
            float secondLoss = tuner.FineTuneOnBatch();

            Assert.True(secondLoss < firstLoss);
        }
示例#2
0
        /// <summary>
        /// Fine-tunes the selected model on a dataset and keeps training until
        /// the user presses Ctrl+C.
        /// </summary>
        /// <param name="remainingArguments">Positional command line arguments;
        /// the first one is the dataset location.</param>
        /// <returns><c>0</c> after training was stopped, <c>-1</c> when the
        /// loaded dataset is empty.</returns>
        /// <exception cref="ArgumentNullException">No dataset argument was given.</exception>
        /// <exception cref="ArgumentException">The include pattern selects
        /// <c>*.csv</c> files, but no column name was specified.</exception>
        public override int Run(string[] remainingArguments)
        {
            this.CheckRequiredArguments();
            if (remainingArguments.Length < 1)
            {
                throw new ArgumentNullException("dataset");
            }

            string modelPath = CommonCommandOptions.ExpandModelNameToPathOrExit(this.ModelName);

            string checkpoint = Gpt2Checkpoints.ProcessCheckpointConfig(
                modelPath: modelPath,
                checkpoint: this.Checkpoint,
                runName: this.RunName);

            var encoder = Gpt2Encoder.LoadEncoder(modelPath);

            string searchPattern = this.Include ?? "*";
            string datasetName   = remainingArguments[0];

            // File name patterns are not linguistic text: compare ordinally.
            var dataset = searchPattern.EndsWith("*.csv", StringComparison.Ordinal)
                ? LoadCsv(encoder, root: datasetName, field: this.ColumnName ?? throw new ArgumentException("column must be specified for training on .csv files"))
                : Gpt2Dataset.LoadDataset(encoder, path: datasetName, pattern: searchPattern);

            if (dataset.Count == 0)
            {
                Console.Error.WriteLine("The dataset is empty!");
                return -1;
            }

            var hParams = Gpt2Model.LoadHParams(modelPath);

            var random = this.Seed is null ? new Random() : new Random(this.Seed.Value);

            tf.random.set_seed(this.Seed);

            using var stop = new CancellationTokenSource();

            // Cancel must be set to true, otherwise the runtime terminates the
            // process as soon as this handler returns and the trainer never
            // gets to write its final checkpoint.
            ConsoleCancelEventHandler onCancelKey = (sender, args) =>
            {
                args.Cancel = true;
                stop.Cancel();
            };

            Console.CancelKeyPress += onCancelKey;
            try
            {
                dynamic config = config_pb2.ConfigProto.CreateInstance();

                config.gpu_options.allow_growth = true;
                var trainer = new Gpt2TunerLegacy(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random)
                {
                    SaveEvery   = this.SaveEvery,
                    SampleNum   = this.SampleNum,
                    SampleEvery = this.SampleEvery,
                };
                string checkpointOutputDirectory = Path.Combine(modelPath, Gpt2Checkpoints.CheckpointDir);

                trainer.FineTune(
                    checkpointsDir: checkpointOutputDirectory, checkpoint: checkpoint,
                    run: this.RunName,
                    // A "fresh" run starts counting from 1; otherwise the trainer picks the step.
                    counter: checkpoint == "fresh" ? 1 : (int?)null,
                    sessionConfig: config,
                    cancellation: stop.Token);
            }
            finally
            {
                // Detach before the token source is disposed at the end of the method.
                Console.CancelKeyPress -= onCancelKey;
            }

            return 0;
        }
示例#3
0
        /// <summary>
        /// Interactively run the model: repeatedly reads a prompt from the
        /// console and prints generated samples until interrupted with Ctrl+C
        /// or until the input ends.
        /// </summary>
        /// <param name="modelName">Which model to use</param>
        /// <param name="checkpoint">Which checkpoint to load. If null, the latest
        ///     checkpoint found in the model directory is used.</param>
        /// <param name="seed">Seed for random number generators, fix seed to reproduce results</param>
        /// <param name="sampleCount">Number of samples to return total</param>
        /// <param name="batchSize">Size of a batch (only affects speed/memory).
        ///     Must be positive and must divide sampleCount.</param>
        /// <param name="length">Number of tokens in generated text, if null (default), is
        ///     determined by model hyperparameters</param>
        /// <param name="temperature">randomness in boltzmann distribution.
        ///     Lower temperature results in less random completions. As the
        ///     temperature approaches zero, the model will become deterministic and
        ///     repetitive. Higher temperature results in more random completions.</param>
        /// <param name="topK">Controls diversity. 1 means only 1 word is
        ///     considered for each step (token), resulting in deterministic completions,
        ///     while 40 means 40 words are considered at each step. 0 (default) is a
        ///     special setting meaning no restrictions. 40 generally is a good value.
        /// </param>
        /// <returns>Always <c>0</c>.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is not
        ///     positive, or the backend could not generate a sequence of the requested length.</exception>
        /// <exception cref="ArgumentException"><paramref name="sampleCount"/> is not divisible by
        ///     <paramref name="batchSize"/>, or <paramref name="length"/> exceeds the model's
        ///     context window.</exception>
        public static int Run(string modelName = "117M", string? checkpoint = null, int? seed = null,
                              int sampleCount  = 1,
                              int batchSize    = 1, int? length = null, float temperature = 1, int topK = 0)
        {
            if (batchSize <= 0)
            {
                // Previously surfaced as a DivideByZeroException from the check below.
                throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be positive");
            }
            if (sampleCount % batchSize != 0)
            {
                throw new ArgumentException(
                    $"{nameof(sampleCount)} must be divisible by {nameof(batchSize)}",
                    nameof(sampleCount));
            }

            string modelPath = CommonCommandOptions.ExpandModelNameToPathOrExit(modelName);

            var encoder = Gpt2Encoder.LoadEncoder(modelPath);
            var hParams = Gpt2Model.LoadHParams(modelPath);

            int nCtx         = hParams.ContextTokens;
            int sampleLength = length ?? nCtx;

            if (sampleLength > nCtx)
            {
                throw new ArgumentException("Can't get samples longer than window size: " + nCtx);
            }

            // TensorFlow device type names are upper-case ("CPU", "GPU"); the
            // lower-case spelling is not expected to match any device, which
            // would silently skip the memory growth setup.
            foreach (var gpu in tf.config.list_physical_devices("GPU"))
            {
                tf.config.experimental.set_memory_growth(gpu, true);
            }

            var sess = new Session(graph: new Graph());

            using var sessionContext = sess.StartUsing();

            Tensor context = v1.placeholder(tf.int32, new TensorShape(batchSize, null));

            tf.random.set_seed(seed);

            Tensor output = Gpt2Sampler.SampleSequence(
                hParams: hParams,
                length: sampleLength,
                context: context,
                batchSize: batchSize,
                temperature: temperature,
                topK: topK);

            var saver = new Saver();

            checkpoint ??= tf.train.latest_checkpoint(modelPath);
            saver.restore(sess, checkpoint);

            bool interrupted = false;

            // Cancel = true keeps the process alive so the loop below can exit cleanly.
            Console.CancelKeyPress += (object sender, ConsoleCancelEventArgs args) =>
                                      Volatile.Write(ref interrupted, args.Cancel = true);

            int batchCount = sampleCount / batchSize;

            while (!Volatile.Read(ref interrupted))
            {
                Console.Write("Model prompt >>> ");
                string? prompt = Console.ReadLine();

                // ReadLine returns null when the input stream has ended (and may
                // do so when Ctrl+C interrupts the read). Re-prompting would
                // spin forever, so both cases end the session.
                if (prompt is null || Volatile.Read(ref interrupted))
                {
                    break;
                }

                if (prompt.Length == 0)
                {
                    Console.WriteLine("Prompt should not be empty");
                    continue;
                }

                var contextTokens = encoder.Encode(prompt);
                if (!tf.test.is_gpu_available() && contextTokens.Count >= sampleLength)
                {
                    Console.Error.WriteLine();
                    Console.Error.WriteLine("Prompt is too long.");
                    Console.Error.WriteLine();
                    continue;
                }

                int generated = 0;
                for (int batch = 0; batch < batchCount; batch++)
                {
                    ndarray<int> @out;
                    try
                    {
                        // Feed the same prompt to every row of the batch and
                        // slice the prompt tokens off the front of the result.
                        @out = sess.run(output, feed_dict: new Dictionary<object, object> {
                            [context] = Enumerable.Repeat(contextTokens, batchSize).ToArray(),
                        })[.., contextTokens.Count..];
                    }
                    catch (InvalidArgumentError ex)
                    {
                        throw new ArgumentOutOfRangeException(
                                  "Unable to generate sequence of desired length. "
                                  + "Try lowering length by passing -l (-sample-length) parameter. "
                                  + "Current length: " + sampleLength,
                                  innerException: ex);
                    }

                    for (int i = 0; i < batchSize; i++)
                    {
                        generated++;
                        var part = @out[i].AsArray();
                        string sampleText = encoder.Decode(part);
                        Console.WriteLine($"{Delimiter} SAMPLE {generated} {Delimiter}");
                        Console.WriteLine(sampleText);
                    }
                }
                Console.Write(Delimiter);
                Console.WriteLine(Delimiter);
            }

            return 0;
        }
示例#4
0
        /// <summary>
        /// Fine-tunes the model on this tuner's dataset until
        /// <paramref name="cancellation"/> is signalled, periodically saving
        /// checkpoints and generated text samples. A final checkpoint is saved
        /// when the loop is interrupted.
        /// </summary>
        /// <param name="checkpointsDir">Root directory for checkpoints; the run's
        ///     checkpoints and its step counter file go to a subdirectory named
        ///     after <paramref name="run"/>.</param>
        /// <param name="checkpoint">Checkpoint to restore before training starts.</param>
        /// <param name="run">Name of the run, used for the checkpoint and sample subdirectories.</param>
        /// <param name="counter">Number of the first training step. If null, it is
        ///     the value stored in the run's counter file plus one, or 1 when
        ///     that file does not exist.</param>
        /// <param name="topK">topK value passed to the sampler that generates text samples.</param>
        /// <param name="temperature">Temperature passed to the sampler that generates text samples.</param>
        /// <param name="sessionConfig">Optional configuration for the session.
        ///     When null, a session with default settings is created.</param>
        /// <param name="cancellation">Stops the training loop when signalled.</param>
        public void FineTune(string checkpointsDir, string checkpoint, string run, int? counter,
                             int topK = 40, float temperature = 1.0f,
                             dynamic? sessionConfig = null, CancellationToken cancellation = default)
        {
            // The original condition was inverted: a supplied configuration was
            // ignored and null was passed on when none was given.
            Session session = sessionConfig is null
                ? new Session()
                : Session.NewDyn(config: sessionConfig);

            using var _ = session.StartUsing();

            Tensor context = v1.placeholder(tf.int32, new TensorShape(this.batchSize, null));
            var    output  = Gpt2Model.Model(this.hParams, input: context);

            var sampler   = new GptTrainingSampler(this.dataset, this.random);
            var optimizer = new AdamOptimizer(learning_rate: 0.0002);
            var tuner     = new Gpt2Tuner(this.hParams, session, context, output, sampler, this.batchSize, optimizer);

            Tensor sample = Gpt2Sampler.SampleSequence(
                this.hParams,
                length: this.sampleLength,
                context: context,
                batchSize: this.batchSize,
                temperature: temperature,
                topK: topK);

            var saver = new Saver(
                var_list: tuner.ModelVariables,
                max_to_keep: 5,
                keep_checkpoint_every_n_hours: 1);

            session.run(v1.global_variables_initializer());

            Console.WriteLine("Loading checkpoint " + checkpoint);
            saver.restore(session, checkpoint);

            Console.WriteLine("Loading dataset...");

            Console.WriteLine($"Dataset has {sampler.TokenCount} tokens");

            string runCheckpointDir = Path.Combine(checkpointsDir, run);
            string runSampleDir     = Path.Combine(SampleDir, run);
            string counterFile      = Path.Combine(checkpointsDir, run, "counter");

            // Resolve the starting step once: explicit value, else continue
            // after the last recorded step, else start from 1.
            int step;
            if (counter is int requestedStep)
            {
                step = requestedStep;
            }
            else if (File.Exists(counterFile))
            {
                step = int.Parse(File.ReadAllText(counterFile), CultureInfo.InvariantCulture) + 1;
            }
            else
            {
                step = 1;
            }

            // Writes a checkpoint for the current step and records the step number.
            void Save()
            {
                Directory.CreateDirectory(runCheckpointDir);
                Console.WriteLine("Saving " + Path.Combine(runCheckpointDir, Invariant($"model-{step}")));
                saver.save(session,
                           Path.Combine(runCheckpointDir, "model"),
                           global_step: step);
                File.WriteAllText(path: counterFile, contents: Invariant($"{step}"));
            }

            // Generates SampleNum samples seeded with the end-of-text token,
            // prints the last one and writes all of them to the run's sample file.
            void GenerateSamples()
            {
                var     contextTokens = np.array(new[] { this.encoder.EncodedEndOfText });
                var     allText       = new List<string>();
                int     index         = 0;
                string? text          = null;

                while (index < this.SampleNum)
                {
                    ndarray<int> @out = session.run(sample, feed_dict: new Dictionary<object, object> {
                        [context] = Enumerable.Repeat(contextTokens, this.batchSize),
                    });

                    // The last batch may contain more rows than are still needed.
                    int take = Math.Min(this.SampleNum - index, this.batchSize);
                    for (int i = 0; i < take; i++)
                    {
                        text = this.encoder.Decode((ndarray<int>)@out[i]);
                        text = Invariant($"======== SAMPLE {index + 1} ========\n{text}\n");
                        allText.Add(text);
                        index++;
                    }
                }
                Console.WriteLine(text);
                Directory.CreateDirectory(runSampleDir);
                File.WriteAllLines(
                    path: Path.Combine(runSampleDir, Invariant($"samples-{step}")),
                    contents: allText);
            }

            // Exponentially decayed running sum of losses and of weights;
            // their ratio is the reported average.
            (double Sum, double Weight) avgLoss = (0.0, 0.0);

            // UTC is immune to daylight saving time shifts during long runs.
            var startTime = DateTime.UtcNow;

            while (!cancellation.IsCancellationRequested)
            {
                // An interval of 0 disables the periodic action instead of
                // crashing with a division by zero.
                if (this.SaveEvery != 0 && step % this.SaveEvery == 0)
                {
                    Save();
                }
                if (this.SampleEvery != 0 && step % this.SampleEvery == 0)
                {
                    GenerateSamples();
                }

                float lv = tuner.FineTuneOnBatch();

                avgLoss = (avgLoss.Sum * 0.99 + lv, avgLoss.Weight * 0.99 + 1);

                Console.WriteLine($"[{step} | {DateTime.UtcNow - startTime}] loss={lv} avg={avgLoss.Sum / avgLoss.Weight}");

                step++;
            }

            Console.WriteLine("Interrupted");
            Save();
        }