/// <summary>
/// Generates one GPT-2 text sample and writes it to standard output.
/// Gradient's warning and regular output are redirected to standard error so that
/// standard output carries only the generated text.
/// </summary>
/// <param name="remainingArguments">Extra command-line arguments; not used by this command.</param>
/// <returns>
/// <c>0</c> when an end-of-text marker was found and the text before it was written;
/// <c>-2</c> when no end-of-text marker was found (the sample hit the length limit),
/// in which case the whole generated text is still written to standard output.
/// </returns>
/// <exception cref="InvalidOperationException">
/// <c>CondaEnv</c> is set but does not match exactly one conda environment directory name.
/// </exception>
public override int Run(string[] remainingArguments) {
    Console.OutputEncoding = Encoding.UTF8;
    GradientLog.WarningWriter = GradientLog.OutputWriter = Console.Error;

    if (!string.IsNullOrEmpty(this.CondaEnv)) {
        // Single() throws if the name matches zero or several environments.
        GradientSetup.UsePythonEnvironment(PythonEnvironment.EnumerateCondaEnvironments()
            .Single(env => Path.GetFileName(env.Home) == this.CondaEnv));
    }

    var generator = new Gpt2TextGenerator(
        modelName: this.ModelName,
        checkpoint: this.Checkpoint,
        sampleLength: this.MaxLength);

    uint seed = this.Seed ?? GetRandomSeed();
    string text = generator.GenerateSample(seed);
    string endOfText = generator.EndOfText;

    // Drop any end-of-text markers the sample starts with. The match must be ordinal:
    // a culture-sensitive StartsWith can succeed while skipping ignorable characters,
    // and then cutting endOfText.Length characters would remove the wrong text.
    while (text.StartsWith(endOfText, StringComparison.Ordinal)) {
        text = text.Substring(endOfText.Length);
    }

    int end = text.IndexOf(endOfText, StringComparison.Ordinal);
    if (end < 0) {
        Console.Error.WriteLine("Text generated from this seed is longer than max-length.");
        Console.WriteLine(text);
        return -2;
    }

    Console.Write(text.Substring(0, end));
    return 0;
}
/// <summary>
/// Builds a GPT-2 backed <see cref="ILyricsGenerator"/> from application configuration.
/// Model parameters and the latest checkpoint are downloaded when they are missing locally
/// and <c>Model:Download</c> allows it.
/// </summary>
/// <returns>A <see cref="Gpt2LyricsGenerator"/> for the resolved model and checkpoint.</returns>
/// <exception cref="ArgumentNullException"><c>Model:Type</c> or <c>Model:Run</c> is not configured.</exception>
/// <exception cref="FileNotFoundException">
/// Model parameters or the checkpoint are missing locally and cannot be downloaded.
/// </exception>
/// <exception cref="InvalidOperationException">
/// <c>PYTHON_CONDA_ENV_NAME</c> is set but does not match exactly one conda environment.
/// </exception>
private ILyricsGenerator CreateGradientLyrics() {
    string condaEnvName = this.Configuration.GetValue<string>("PYTHON_CONDA_ENV_NAME", null);
    if (!string.IsNullOrEmpty(condaEnvName)) {
        GradientSetup.UsePythonEnvironment(PythonEnvironment.EnumerateCondaEnvironments()
            .Single(env => Path.GetFileName(env.Home) == condaEnvName));
    }

    var logger = this.LoggerFactory.CreateLogger<Startup>();

    bool download = this.Configuration.GetValue("Model:Download", defaultValue: true);
    string gpt2Root = this.Configuration.GetValue("GPT2_ROOT", Environment.CurrentDirectory);
    string checkpointName = this.Configuration.GetValue("Model:Checkpoint", "latest");
    // Validate every required setting before starting any (potentially slow) download.
    string modelName = this.Configuration.GetValue<string>("Model:Type", null)
                       ?? throw new ArgumentNullException("Model:Type");
    string runName = this.Configuration.GetValue<string>("Model:Run", null)
                     ?? throw new ArgumentNullException("Model:Run");

    string modelRoot = Path.Combine(gpt2Root, "models", modelName);
    logger.LogInformation($"Using model from {modelRoot}");
    // encoder.json is used as the marker that the model parameters are present locally.
    if (!File.Exists(Path.Combine(modelRoot, "encoder.json"))) {
        if (download) {
            logger.LogInformation($"downloading {modelName} parameters");
            ModelDownloader.DownloadModelParameters(gpt2Root, modelName);
            logger.LogInformation($"downloaded {modelName} parameters");
        } else {
            throw new FileNotFoundException($"Can't find GPT-2 model in {modelRoot}");
        }
    }

    string checkpoint = Gpt2Checkpoints.ProcessCheckpointConfig(gpt2Root, checkpointName,
        modelName: modelName, runName: runName);
    if (checkpoint == null || !File.Exists(checkpoint + ".index")) {
        if (download && checkpointName == "latest") {
            logger.LogInformation($"downloading the latest checkpoint for {modelName}, run {runName}");
            checkpoint = ModelDownloader.DownloadCheckpoint(
                root: gpt2Root,
                modelName: modelName,
                runName: runName);
            logger.LogInformation("download successful");
        } else {
            if (!download) {
                logger.LogWarning("Model downloading is disabled. See corresponding appsettings file.");
            } else if (checkpointName != "latest") {
                logger.LogWarning("Only the 'latest' model can be downloaded. You wanted: " + checkpointName);
            }

            // A null checkpoint means the config could not be resolved at all; say what was requested
            // instead of producing the meaningless "Can't find checkpoint .index".
            throw new FileNotFoundException(checkpoint == null
                ? $"Can't resolve checkpoint '{checkpointName}' for model {modelName}, run {runName}"
                : "Can't find checkpoint " + checkpoint + ".index");
        }
    }

    // Logged after resolution/download so it reports the checkpoint actually used.
    logger.LogInformation($"Using model checkpoint: {checkpoint}");

    return new Gpt2LyricsGenerator(
        gpt2Root: gpt2Root,
        modelName: modelName,
        checkpoint: checkpoint,
        logger: this.LoggerFactory.CreateLogger<Gpt2LyricsGenerator>(),
        condaEnv: condaEnvName);
}