/// <summary>
/// Runs the training command: validates arguments, resolves the checkpoint setting,
/// loads the encoder, dataset and hyper-parameters for <c>ModelName</c>, and trains
/// with a cancellation token that is signalled by Ctrl+C.
/// </summary>
/// <param name="remainingArguments">
/// Positional arguments; the first one is the dataset location (used as the CSV root
/// when the include pattern ends with <c>*.csv</c>, otherwise as the dataset path).
/// </param>
/// <returns>Always 0 once <c>Train</c> returns.</returns>
/// <exception cref="ArgumentNullException">No dataset argument was supplied.</exception>
public override int Run(string[] remainingArguments) {
    this.CheckRequiredArguments();
    if (remainingArguments == null || remainingArguments.Length < 1) {
        throw new ArgumentNullException("dataset");
    }

    string datasetName = remainingArguments[0];
    string checkpoint = Gpt2Trainer.ProcessCheckpointConfig(this.Checkpoint,
        modelName: this.ModelName,
        runName: this.RunName);

    var encoder = Gpt2Encoder.LoadEncoder(this.ModelName);
    string searchPattern = this.Include ?? "*";
    // Ordinal: the pattern is a file-system glob, not linguistic text.
    var dataset = searchPattern.EndsWith("*.csv", StringComparison.Ordinal)
        ? LoadCsv(encoder, root: datasetName, field: this.ColumnName)
        : Gpt2Trainer.LoadDataset(encoder, path: datasetName, pattern: searchPattern);
    var hParams = Gpt2Model.LoadHParams(this.ModelName);

    var random = this.Seed == null ? new Random() : new Random(this.Seed.Value);

    using (var stop = new CancellationTokenSource()) {
        ConsoleCancelEventHandler onCancel = (sender, e) => {
            if (stop.IsCancellationRequested) {
                // Second Ctrl+C: abandon the graceful stop and let the runtime terminate the process.
                return;
            }
            // Without e.Cancel = true the runtime terminates the process as soon as this
            // handler returns, so Train would never get to observe the token.
            e.Cancel = true;
            stop.Cancel();
        };

        Console.CancelKeyPress += onCancel;
        try {
            new Gpt2Trainer(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random)
                .Train(checkpoint, this.RunName, stop.Token);
        } finally {
            // CancelKeyPress is a static event: unsubscribe so it does not keep this
            // handler (and the soon-to-be-disposed token source) reachable.
            Console.CancelKeyPress -= onCancel;
        }
    }

    return 0;
}
/// <summary>
/// Resolves the symbolic checkpoint settings <c>"latest"</c> and <c>"fresh"</c>
/// into concrete checkpoints; any other value (including null) is passed through untouched.
/// </summary>
/// <param name="checkpoint">
/// <c>"latest"</c> → <see cref="Gpt2Trainer.GetLatestCheckpoint"/>(<paramref name="modelName"/>, <paramref name="runName"/>);
/// <c>"fresh"</c> → <see cref="Gpt2Trainer.GetOriginalCheckpoint"/>(<paramref name="modelName"/>);
/// otherwise returned as-is. Matching is exact and case-sensitive.
/// </param>
/// <param name="modelName">Model name forwarded to the checkpoint lookups.</param>
/// <param name="runName">Run name forwarded to the "latest" lookup only.</param>
/// <returns>The resolved checkpoint value.</returns>
public static string ProcessCheckpointConfig(string checkpoint, string modelName, string runName) {
    if (checkpoint == "latest") {
        return Gpt2Trainer.GetLatestCheckpoint(modelName, runName);
    }

    if (checkpoint == "fresh") {
        return Gpt2Trainer.GetOriginalCheckpoint(modelName);
    }

    // Not a recognized keyword: treat it as an explicit checkpoint.
    return checkpoint;
}