/// <summary>
/// Downloads the checkpoint archive from <c>CheckpointRootUri</c>, extracts it into
/// <c>{root}/checkpoint/{runName}</c>, and returns the latest checkpoint found for the run.
/// </summary>
/// <param name="root">GPT-2 root directory; the archive is extracted under its <c>checkpoint</c> subdirectory.</param>
/// <param name="modelName">Model name passed to <c>Gpt2Checkpoints.GetLatestCheckpoint</c>.</param>
/// <param name="runName">Run name; also used as the extraction subdirectory name.</param>
/// <returns>The checkpoint path reported by <c>Gpt2Checkpoints.GetLatestCheckpoint</c> after extraction.</returns>
/// <exception cref="IOException">No checkpoint can be located after the archive was extracted.</exception>
public static string DownloadCheckpoint(string root, string modelName, string runName) {
    string targetDirectory = Path.Combine(root, "checkpoint", runName);
    Directory.CreateDirectory(targetDirectory);

    // WebClient is IDisposable: dispose the client itself, not only the stream it returns.
    using (var client = new WebClient())
    using (var zipStream = client.OpenRead(CheckpointRootUri))
    using (var zip = new ZipArchive(zipStream, ZipArchiveMode.Read)) {
        // ZipArchive holds a non-seekable (network) stream entirely in memory in Read mode.
        // NOTE(review): ExtractToDirectory throws IOException if a destination file already
        // exists (e.g. left over from an interrupted earlier extraction) — confirm whether
        // the overwriting overload is available on the target framework.
        zip.ExtractToDirectory(targetDirectory);
    }

    string checkpoint = Gpt2Checkpoints.GetLatestCheckpoint(root, modelName, runName);
    if (checkpoint == null) {
        throw new IOException("Can't find checkpoint file after downloading");
    }
    return checkpoint;
}
/// <summary>
/// Builds the GPT-2 based lyrics generator from configuration. Model parameters and the
/// latest checkpoint are downloaded when they are missing and downloading is enabled.
/// </summary>
/// <remarks>
/// Configuration keys read: <c>PYTHON_CONDA_ENV_NAME</c> (optional), <c>Model:Download</c>
/// (default <c>true</c>), <c>GPT2_ROOT</c> (default: current directory),
/// <c>Model:Checkpoint</c> (default <c>"latest"</c>), and the required
/// <c>Model:Type</c> and <c>Model:Run</c>.
/// </remarks>
/// <exception cref="ArgumentNullException"><c>Model:Type</c> or <c>Model:Run</c> is not configured.</exception>
/// <exception cref="FileNotFoundException">
/// The model parameters or the checkpoint are missing and cannot be downloaded.
/// </exception>
private ILyricsGenerator CreateGradientLyrics() {
    string condaEnvName = this.Configuration.GetValue<string>("PYTHON_CONDA_ENV_NAME", null);
    if (!string.IsNullOrEmpty(condaEnvName)) {
        GradientSetup.UseCondaEnvironment(condaEnvName);
    }

    var logger = this.LoggerFactory.CreateLogger<Startup>();

    bool download = this.Configuration.GetValue("Model:Download", defaultValue: true);
    string gpt2Root = this.Configuration.GetValue("GPT2_ROOT", Environment.CurrentDirectory);
    string checkpointName = this.Configuration.GetValue("Model:Checkpoint", "latest");
    string modelName = this.Configuration.GetValue<string>("Model:Type", null)
        ?? throw new ArgumentNullException("Model:Type");

    // Model parameters: the presence of encoder.json is treated as "parameters available".
    string modelRoot = Path.Combine(gpt2Root, "models", modelName);
    logger.LogInformation("Using model from {ModelRoot}", modelRoot);
    string encoderPath = Path.Combine(modelRoot, "encoder.json");
    if (!File.Exists(encoderPath)) {
        if (!download) {
            throw new FileNotFoundException($"Can't find GPT-2 model in {modelRoot}", encoderPath);
        }
        logger.LogInformation("downloading {ModelName} parameters", modelName);
        ModelDownloader.DownloadModelParameters(gpt2Root, modelName);
        logger.LogInformation("downloaded {ModelName} parameters", modelName);
    }

    string runName = this.Configuration.GetValue<string>("Model:Run", null)
        ?? throw new ArgumentNullException("Model:Run");
    string checkpoint = Gpt2Checkpoints.ProcessCheckpointConfig(
        gpt2Root, checkpointName, modelName: modelName, runName: runName);
    logger.LogInformation("Using model checkpoint: {Checkpoint}", checkpoint);

    // A checkpoint is usable only if its ".index" file exists.
    if (checkpoint == null || !File.Exists(checkpoint + ".index")) {
        bool canDownload = download && checkpointName == "latest";
        if (!canDownload) {
            if (!download) {
                logger.LogWarning("Model downloading is disabled. See corresponding appsettings file.");
            } else {
                // download is enabled here, so checkpointName must differ from "latest".
                logger.LogWarning("Only the 'latest' model can be downloaded. You wanted: {CheckpointName}", checkpointName);
            }

            // When no checkpoint path was resolved at all, describe what was requested
            // instead of producing "Can't find checkpoint .index".
            if (checkpoint == null) {
                throw new FileNotFoundException(
                    $"Can't find checkpoint '{checkpointName}' for model {modelName}, run {runName}");
            }
            string indexPath = checkpoint + ".index";
            throw new FileNotFoundException("Can't find checkpoint " + indexPath, indexPath);
        }

        logger.LogInformation("downloading the latest checkpoint for {ModelName}, run {RunName}", modelName, runName);
        checkpoint = ModelDownloader.DownloadCheckpoint(
            root: gpt2Root,
            modelName: modelName,
            runName: runName);
        logger.LogInformation("download successful");
    }

    return new Gpt2LyricsGenerator(
        gpt2Root: gpt2Root,
        modelName: modelName,
        checkpoint: checkpoint,
        logger: this.LoggerFactory.CreateLogger<Gpt2LyricsGenerator>(),
        condaEnv: condaEnvName);
}