/// <summary>
/// Runs the fine-tuning command: resolves the model path and checkpoint,
/// loads the encoder, dataset and hyperparameters, then calls
/// <c>FineTune</c> on a <c>Gpt2TunerLegacy</c> instance.
/// </summary>
/// <param name="remainingArguments">
/// Positional command-line arguments. The first element is used as the dataset
/// location (passed as <c>root</c> to <c>LoadCsv</c> or as <c>path</c> to
/// <c>Gpt2Dataset.LoadDataset</c>).
/// </param>
/// <returns>
/// <c>0</c> after <c>FineTune</c> returns; <c>-1</c> when the loaded dataset
/// contains no entries.
/// </returns>
/// <exception cref="ArgumentNullException">
/// No dataset argument was supplied.
/// </exception>
/// <exception cref="ArgumentException">
/// The include pattern ends with <c>*.csv</c>, but no column name was specified.
/// </exception>
public override int Run(string[] remainingArguments) {
    this.CheckRequiredArguments();

    // A missing argument array is reported the same way as an empty one,
    // instead of surfacing as a NullReferenceException.
    if (remainingArguments is null || remainingArguments.Length < 1) {
        throw new ArgumentNullException("dataset");
    }

    string modelPath = CommonCommandOptions.ExpandModelNameToPathOrExit(this.ModelName);
    string checkpoint = Gpt2Checkpoints.ProcessCheckpointConfig(
        modelPath: modelPath,
        checkpoint: this.Checkpoint,
        runName: this.RunName);

    var encoder = Gpt2Encoder.LoadEncoder(modelPath);
    string searchPattern = this.Include ?? "*";
    string datasetName = remainingArguments[0];

    // The include pattern is a file glob, not linguistic text, so it is compared
    // ordinally rather than with the current culture's rules.
    var dataset = searchPattern.EndsWith("*.csv", StringComparison.Ordinal)
        ? LoadCsv(encoder, root: datasetName,
            field: this.ColumnName
                ?? throw new ArgumentException("column must be specified for training on .csv files"))
        : Gpt2Dataset.LoadDataset(encoder, path: datasetName, pattern: searchPattern);
    if (dataset.Count == 0) {
        Console.Error.WriteLine("The dataset is empty!");
        return -1;
    }

    var hParams = Gpt2Model.LoadHParams(modelPath);
    // Without an explicit seed the Random instance uses its default seeding.
    var random = this.Seed is null ? new Random() : new Random(this.Seed.Value);
    tf.random.set_seed(this.Seed);

    var stop = new CancellationTokenSource();
    // The handler is kept in a local so it can be detached once training is over.
    // NOTE(review): the handler does not set ConsoleCancelEventArgs.Cancel, so the
    // runtime still terminates the process after the handler returns - confirm
    // that this is the intended Ctrl+C behavior.
    ConsoleCancelEventHandler cancelOnCtrlC = delegate { stop.Cancel(); };
    Console.CancelKeyPress += cancelOnCtrlC;
    try {
        dynamic config = config_pb2.ConfigProto.CreateInstance();
        config.gpu_options.allow_growth = true;

        var trainer = new Gpt2TunerLegacy(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random) {
            SaveEvery = this.SaveEvery,
            SampleNum = this.SampleNum,
            SampleEvery = this.SampleEvery,
        };

        string checkpointOutputDirectory = Path.Combine(modelPath, Gpt2Checkpoints.CheckpointDir);
        trainer.FineTune(
            checkpointsDir: checkpointOutputDirectory,
            checkpoint: checkpoint,
            run: this.RunName,
            // The counter is pinned to 1 only for the "fresh" checkpoint value;
            // otherwise null is passed and FineTune decides.
            counter: checkpoint == "fresh" ? 1 : (int?)null,
            sessionConfig: config,
            cancellation: stop.Token);

        return 0;
    } finally {
        // Detach before disposing so the static event can no longer reach
        // a disposed token source.
        Console.CancelKeyPress -= cancelOnCtrlC;
        stop.Dispose();
    }
}
/// <summary>
/// Resolves the model path and the checkpoint to use, then forwards this
/// command's options to the parameterized <c>Run</c> overload.
/// </summary>
/// <param name="remainingArguments">Positional arguments; not read by this method.</param>
/// <returns>The value returned by the parameterized <c>Run</c> overload.</returns>
public override int Run(string[] remainingArguments) {
    // The expanded path is needed only to resolve the checkpoint; the overload
    // below receives the unexpanded model name.
    string modelDirectory = CommonCommandOptions.ExpandModelNameToPathOrExit(this.ModelName);
    string resolvedCheckpoint = Gpt2Checkpoints.ProcessCheckpointConfig(
        modelPath: modelDirectory,
        checkpoint: this.Checkpoint,
        runName: this.RunName);

    int exitCode = Run(
        modelName: this.ModelName,
        checkpoint: resolvedCheckpoint,
        seed: this.Seed,
        sampleCount: this.SampleCount,
        batchSize: this.BatchSize,
        length: this.Length,
        temperature: this.Temperature,
        topK: this.TopK);
    return exitCode;
}