示例#1
0
        /// <summary>
        /// Builds a training dataset from one column of every <c>*.csv</c> file found under
        /// <paramref name="root"/>, searching all subdirectories.
        /// </summary>
        /// <param name="encoder">Encoder handed to <c>Gpt2Dataset.FromTexts</c> together with the collected texts.</param>
        /// <param name="root">Directory that is searched recursively for <c>*.csv</c> files.</param>
        /// <param name="field">Header name of the column whose values become training texts.</param>
        /// <returns>
        /// The result of <c>Gpt2Dataset.FromTexts</c> over every value of <paramref name="field"/>
        /// that is not null, empty, or whitespace-only.
        /// </returns>
        static DataSet LoadCsv(Gpt2Encoder encoder, string root, string field)
        {
            var texts            = new List<string>();
            var csvConfiguration = new CsvHelper.Configuration.Configuration {
                Delimiter       = ",",
                HasHeaderRecord = true,
            };

            foreach (string file in Directory.EnumerateFiles(root, "*.csv", SearchOption.AllDirectories))
            {
                // Own the text reader explicitly so it is disposed regardless of whether the
                // CsvReader closes it (disposing a StreamReader twice is harmless).
                using var textReader = new StreamReader(file, Encoding.UTF8);
                using var reader     = new CsvHelper.CsvReader(textReader, csvConfiguration);

                // An empty file has no header row; CsvHelper's ReadHeader fails when no row was
                // read, so skip such files instead of aborting the whole load.
                if (!reader.Read())
                {
                    continue;
                }
                reader.ReadHeader();

                while (reader.Read())
                {
                    string entry = reader.GetField(field);
                    // NOTE(review): debug-only sanity check on column 0, which is not necessarily
                    // the requested `field`; looks like a leftover check - confirm it is still wanted.
                    System.Diagnostics.Debug.Assert(reader.GetField(0).Length < 300);
                    if (!string.IsNullOrWhiteSpace(entry))
                    {
                        texts.Add(entry);
                    }
                }
            }

            return(Gpt2Dataset.FromTexts(encoder, texts));
        }
示例#2
0
文件: GptTests.cs 项目: losttech/GPT
        /// <summary>
        /// Builds a small GPT model over the test vocabulary, runs two fine-tuning steps on a
        /// dataset made from <c>EncoderJson</c>, and checks that the second loss is lower than the first.
        /// </summary>
        public void Tune()
        {
            // Small hyperparameters; the vocabulary size must match the test encoder.
            var tinyParams = new GptHParams(
                embeddingDim: 16,
                attentionHeads: 2,
                encoderLayers: 2,
                contextTokens: 16,
                vocabularySize: TestEncoder.Count);
            var gptEncoder  = new Gpt2Encoder(TestEncoder, TestBPE);
            var corpus      = new[] { EncoderJson };
            var trainingSet = Gpt2Dataset.FromTexts(gptEncoder, corpus);

            var tfSession = new Session();

            // The graph below is built while this session is the active one.
            using var sessionScope = tfSession.StartUsing();

            const int batchSize = 4;
            var tokens       = tf.placeholder(tf.int32, new TensorShape(batchSize, null));
            var modelOutputs = Gpt2Model.Model(tinyParams, tokens);
            var sampler      = new GptTrainingSampler(trainingSet, new Random());
            var tuner        = new Gpt2Tuner(tinyParams, tfSession,
                                             inputPlaceholder: tokens,
                                             modelOutputs,
                                             sampler,
                                             batchSize: batchSize);

            tfSession.run(tf.global_variables_initializer());

            float lossBefore = tuner.FineTuneOnBatch();
            float lossAfter  = tuner.FineTuneOnBatch();

            Assert.True(lossAfter < lossBefore);
        }
示例#3
0
        /// <summary>
        /// Fine-tunes the configured GPT-2 model on the dataset named by the first remaining argument.
        /// </summary>
        /// <param name="remainingArguments">Positional arguments; element 0 is the dataset path.</param>
        /// <returns><c>0</c> once fine-tuning returns; <c>-1</c> if the loaded dataset is empty.</returns>
        /// <exception cref="ArgumentNullException">No dataset argument was supplied.</exception>
        /// <exception cref="ArgumentException">
        /// The include pattern ends with <c>*.csv</c> but no column name was specified.
        /// </exception>
        public override int Run(string[] remainingArguments)
        {
            this.CheckRequiredArguments();
            if (remainingArguments.Length < 1)
            {
                throw new ArgumentNullException("dataset");
            }

            string modelPath = CommonCommandOptions.ExpandModelNameToPathOrExit(this.ModelName);

            string checkpoint = Gpt2Checkpoints.ProcessCheckpointConfig(
                modelPath: modelPath,
                checkpoint: this.Checkpoint,
                runName: this.RunName);

            var encoder = Gpt2Encoder.LoadEncoder(modelPath);

            string searchPattern = this.Include ?? "*";
            string datasetName   = remainingArguments[0];
            // Ordinal: the pattern is a file-system glob, not linguistic text.
            var    dataset       = searchPattern.EndsWith("*.csv", StringComparison.Ordinal)
                ? LoadCsv(encoder, root: datasetName, field: this.ColumnName ?? throw new ArgumentException("column must be specified for training on .csv files"))
                : Gpt2Dataset.LoadDataset(encoder, path: datasetName, pattern: searchPattern);

            if (dataset.Count == 0)
            {
                Console.Error.WriteLine("The dataset is empty!");
                return(-1);
            }

            var hParams = Gpt2Model.LoadHParams(modelPath);

            var random = this.Seed is null ? new Random() : new Random(this.Seed.Value);

            tf.random.set_seed(this.Seed);

            using var stop = new CancellationTokenSource();

            // First Ctrl+C: keep the process alive (args.Cancel = true) and signal `stop`, whose
            // token is passed to FineTune below. Without args.Cancel the runtime terminates the
            // process as soon as the handler returns, so the token would never be observed.
            // A second Ctrl+C leaves args.Cancel false and therefore terminates the process.
            ConsoleCancelEventHandler onCancelKeyPress = (_, args) =>
            {
                if (stop.IsCancellationRequested)
                {
                    return;
                }
                args.Cancel = true;
                stop.Cancel();
            };
            Console.CancelKeyPress += onCancelKeyPress;
            try
            {
                dynamic config = config_pb2.ConfigProto.CreateInstance();

                config.gpu_options.allow_growth = true;
                var trainer = new Gpt2TunerLegacy(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random)
                {
                    SaveEvery   = this.SaveEvery,
                    SampleNum   = this.SampleNum,
                    SampleEvery = this.SampleEvery,
                };
                string checkpointOutputDirectory = Path.Combine(modelPath, Gpt2Checkpoints.CheckpointDir);

                trainer.FineTune(
                    checkpointsDir: checkpointOutputDirectory, checkpoint: checkpoint,
                    run: this.RunName,
                    counter: checkpoint == "fresh" ? 1 : (int?)null,
                    sessionConfig: config,
                    cancellation: stop.Token);
            }
            finally
            {
                // Detach before `stop` is disposed so a later Ctrl+C cannot touch a disposed source.
                Console.CancelKeyPress -= onCancelKeyPress;
            }

            return(0);
        }