コード例 #1
0
        /// <summary>
        /// Training command entry point. Resolves the checkpoint to start from, loads the
        /// encoder, hyper-parameters and dataset for <c>ModelName</c>, then runs
        /// <see cref="Gpt2Trainer"/> with a cancellation token that is signalled when the
        /// user presses Ctrl+C.
        /// </summary>
        /// <param name="remainingArguments">
        /// Positional arguments left after option parsing. Element 0 is the dataset location:
        /// passed as the CSV root to <c>LoadCsv</c> when <c>Include</c> ends with
        /// <c>*.csv</c>, otherwise passed as the path to <c>Gpt2Trainer.LoadDataset</c>.
        /// </param>
        /// <returns>Always 0 once <c>Train</c> returns.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="remainingArguments"/> is null, or it contains no dataset argument.
        /// </exception>
        public override int Run(string[] remainingArguments)
        {
            this.CheckRequiredArguments();
            if (remainingArguments == null)
            {
                throw new ArgumentNullException(nameof(remainingArguments));
            }
            if (remainingArguments.Length < 1)
            {
                throw new ArgumentNullException("dataset");
            }
            string datasetName = remainingArguments[0];
            string checkpoint  = Gpt2Trainer.ProcessCheckpointConfig(this.Checkpoint,
                                                                     modelName: this.ModelName,
                                                                     runName: this.RunName);

            var    encoder       = Gpt2Encoder.LoadEncoder(this.ModelName);
            string searchPattern = this.Include ?? "*";
            // Ordinal comparison: this is a file-pattern suffix, not linguistic text.
            var    dataset       = searchPattern.EndsWith("*.csv", StringComparison.Ordinal)
                ? LoadCsv(encoder, root: datasetName, field: this.ColumnName)
                : Gpt2Trainer.LoadDataset(encoder, path: datasetName, pattern: searchPattern);
            var hParams = Gpt2Model.LoadHParams(this.ModelName);
            var random  = this.Seed == null ? new Random() : new Random(this.Seed.Value);

            using (var stop = new CancellationTokenSource())
            {
                // By default the runtime terminates the process right after a CancelKeyPress
                // handler returns, which would make stop.Token useless. Suppress termination
                // on the first Ctrl+C so Train can observe the token and wind down.
                // A second Ctrl+C (cancellation already requested) falls through and
                // terminates the process as usual.
                ConsoleCancelEventHandler onCancel = (sender, e) =>
                {
                    e.Cancel = !stop.IsCancellationRequested;
                    stop.Cancel();
                };
                Console.CancelKeyPress += onCancel;
                try
                {
                    // NOTE(review): assumes Train honors the token and returns promptly
                    // once cancellation is requested. Confirm in Gpt2Trainer.Train.
                    new Gpt2Trainer(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random)
                        .Train(checkpoint, this.RunName, stop.Token);
                }
                finally
                {
                    // Detach before the token source is disposed so a late Ctrl+C cannot
                    // touch a disposed CancellationTokenSource or keep this closure alive.
                    Console.CancelKeyPress -= onCancel;
                }
            }

            return(0);
        }
コード例 #2
0
        /// <summary>
        /// Expands the special checkpoint keywords accepted on the command line.
        /// </summary>
        /// <param name="checkpoint">
        /// Either the keyword <c>"latest"</c>, the keyword <c>"fresh"</c>, or any other value.
        /// Matching is exact and case-sensitive. Any other value, including null, is returned
        /// as-is.
        /// </param>
        /// <param name="modelName">Model whose checkpoints are looked up.</param>
        /// <param name="runName">Run name; only used when resolving <c>"latest"</c>.</param>
        /// <returns>
        /// The result of <c>Gpt2Trainer.GetLatestCheckpoint(modelName, runName)</c> for
        /// <c>"latest"</c>, of <c>Gpt2Trainer.GetOriginalCheckpoint(modelName)</c> for
        /// <c>"fresh"</c>, otherwise <paramref name="checkpoint"/> unchanged.
        /// </returns>
        public static string ProcessCheckpointConfig(string checkpoint, string modelName, string runName)
        {
            // string == is ordinal equality, exactly like a switch on a string value.
            if (checkpoint == "latest")
            {
                return Gpt2Trainer.GetLatestCheckpoint(modelName, runName);
            }

            if (checkpoint == "fresh")
            {
                return Gpt2Trainer.GetOriginalCheckpoint(modelName);
            }

            // Not a keyword: presumably an explicit checkpoint path supplied by the user.
            return checkpoint;
        }