Пример #1
0
        /// <summary>
        /// Fine-tunes the GPT-2 model named by <c>ModelName</c> on a dataset and trains it through
        /// <c>Gpt2Trainer.Train</c>, using the checkpoint resolved from <c>Checkpoint</c>/<c>RunName</c>.
        /// </summary>
        /// <param name="remainingArguments">
        /// Positional arguments. The first one is the dataset location: it is passed as <c>root</c> to
        /// <c>LoadCsv</c> when <c>Include</c> ends with <c>*.csv</c>, otherwise as <c>path</c> to
        /// <c>Gpt2Dataset.LoadDataset</c>.
        /// </param>
        /// <returns><c>0</c> once <c>Train</c> returns; <c>-1</c> if the loaded dataset is empty.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="remainingArguments"/> is <c>null</c>, or it contains no dataset argument
        /// (reported with parameter name <c>dataset</c>).
        /// </exception>
        public override int Run(string[] remainingArguments)
        {
            this.CheckRequiredArguments();
            if (remainingArguments is null)
            {
                throw new ArgumentNullException(nameof(remainingArguments));
            }
            if (remainingArguments.Length < 1)
            {
                // Exception type and parameter name kept as before for existing callers.
                throw new ArgumentNullException("dataset", "The dataset path must be given as the first positional argument.");
            }

            string datasetName = remainingArguments[0];
            string checkpoint  = Gpt2Checkpoints.ProcessCheckpointConfig(
                gpt2Root: Environment.CurrentDirectory,
                checkpoint: this.Checkpoint,
                modelName: this.ModelName,
                runName: this.RunName);

            var    encoder       = Gpt2Encoder.LoadEncoder(this.ModelName);
            string searchPattern = this.Include ?? "*";
            // Ordinal: a file pattern is not linguistic text, so culture rules must not affect the match.
            var    dataset       = searchPattern.EndsWith("*.csv", StringComparison.Ordinal)
                ? LoadCsv(encoder, root: datasetName, field: this.ColumnName)
                : Gpt2Dataset.LoadDataset(encoder, path: datasetName, pattern: searchPattern);

            if (dataset.Count == 0)
            {
                Console.Error.WriteLine("The dataset is empty!");
                return(-1);
            }
            var hParams = Gpt2Model.LoadHParams(this.ModelName);
            var random  = this.Seed is null ? new Random() : new Random(this.Seed.Value);

            using (var stop = new CancellationTokenSource())
            {
                // Unless e.Cancel is set, the runtime terminates the process as soon as this handler
                // returns, so the trainer would never get to observe the token. Only the first Ctrl+C
                // is intercepted; a second press falls through to the default (process termination).
                ConsoleCancelEventHandler onCancelKeyPress = (sender, e) =>
                {
                    if (stop.IsCancellationRequested)
                    {
                        return;
                    }
                    try
                    {
                        stop.Cancel();
                        e.Cancel = true;
                    }
                    catch (ObjectDisposedException)
                    {
                        // Ctrl+C raced with the end of training: the source is already disposed,
                        // so let the default handling end the process.
                    }
                };

                Console.CancelKeyPress += onCancelKeyPress;
                try
                {
                    dynamic config = config_pb2.ConfigProto.CreateInstance();
                    config.gpu_options.allow_growth = true;

                    new Gpt2Trainer(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random)
                    {
                        SaveEvery   = this.SaveEvery,
                        SampleNum   = this.SampleNum,
                        SampleEvery = this.SampleEvery,
                    }
                    .Train(checkpoint, this.RunName,
                           sessionConfig: config,
                           // A fresh run starts counting at 1; otherwise let the trainer decide.
                           counter: checkpoint == "fresh" ? 1 : (int?)null,
                           cancellation: stop.Token);
                }
                finally
                {
                    // Detach before `stop` is disposed so the static event does not keep it alive
                    // or invoke it after this method has returned.
                    Console.CancelKeyPress -= onCancelKeyPress;
                }
            }

            return(0);
        }
        /// <summary>
        /// Resolves which checkpoint to use from <c>Checkpoint</c>, <c>ModelName</c> and <c>RunName</c>
        /// (relative to the current directory), then hands off to the <c>Run</c> overload that takes
        /// explicit sampling parameters.
        /// </summary>
        /// <param name="remainingArguments">Not read by this method.</param>
        /// <returns><c>0</c> once the delegated <c>Run</c> call completes.</returns>
        public override int Run(string[] remainingArguments)
        {
            string resolvedCheckpoint = Gpt2Checkpoints.ProcessCheckpointConfig(
                gpt2Root:   Environment.CurrentDirectory,
                checkpoint: this.Checkpoint,
                modelName:  this.ModelName,
                runName:    this.RunName);

            Run(modelName:   this.ModelName,
                checkpoint:  resolvedCheckpoint,
                seed:        this.Seed,
                sampleCount: this.SampleCount,
                batchSize:   this.BatchSize,
                length:      this.Length,
                temperature: this.Temperature,
                topK:        this.TopK);

            const int success = 0;
            return success;
        }