/// <summary>
/// Gathers the non-blank values of one column from every <c>*.csv</c> file found
/// (recursively) under <paramref name="root"/> and hands them to <c>Load</c> for encoding.
/// </summary>
/// <param name="encoder">Encoder forwarded to <c>Load</c>.</param>
/// <param name="root">Directory searched recursively for CSV files.</param>
/// <param name="field">Header name of the column whose values are collected.</param>
/// <returns>The dataset produced by <c>Load</c> from the collected values.</returns>
static DataSet LoadCsv(Gpt2Encoder encoder, string root, string field)
{
    var entries  = new List<string>();
    var csvFiles = Directory.EnumerateFiles(root, "*.csv", SearchOption.AllDirectories);

    foreach (string csvPath in csvFiles)
    {
        var settings = new CsvHelper.Configuration.Configuration
        {
            Delimiter       = ",",
            HasHeaderRecord = true,
        };

        // The CsvReader takes ownership of the StreamReader and closes it on dispose.
        using (var csv = new CsvHelper.CsvReader(new StreamReader(csvPath, Encoding.UTF8), settings))
        {
            // Read the first record and register it as the header row.
            csv.Read();
            csv.ReadHeader();

            while (csv.Read())
            {
                string value = csv.GetField(field);
                System.Diagnostics.Debug.Assert(csv.GetField(0).Length < 300);
                if (string.IsNullOrWhiteSpace(value))
                {
                    continue;
                }
                entries.Add(value);
            }
        }
    }

    return Load(encoder, entries);
}
Example #2
0
 /// <summary>
 /// Creates a trainer that stores the given dataset, encoder, hyperparameters,
 /// batch/sample sizes and random source for later use.
 /// </summary>
 /// <exception cref="ArgumentNullException">
 /// <paramref name="dataset"/>, <paramref name="encoder"/>, <paramref name="hParams"/>
 /// or <paramref name="random"/> is null (checked in that order).
 /// </exception>
 public Gpt2Trainer(DataSet dataset, Gpt2Encoder encoder, HParams hParams,
                    int batchSize, int sampleLength, Random random)
 {
     // Reject missing references up front, in parameter order.
     if (dataset is null)
     {
         throw new ArgumentNullException(nameof(dataset));
     }
     if (encoder is null)
     {
         throw new ArgumentNullException(nameof(encoder));
     }
     if (hParams is null)
     {
         throw new ArgumentNullException(nameof(hParams));
     }
     if (random is null)
     {
         throw new ArgumentNullException(nameof(random));
     }

     this.dataset      = dataset;
     this.encoder      = encoder;
     this.hParams      = hParams;
     this.random       = random;
     this.batchSize    = batchSize;
     this.sampleLength = sampleLength;
 }
        /// <summary>
        /// Runs the training command: loads the dataset named by the first positional
        /// argument, then trains from the configured checkpoint.
        /// </summary>
        /// <param name="remainingArguments">Positional arguments left after option parsing;
        ///     the first one is the dataset path.</param>
        /// <returns>0 after training returns; -1 if the dataset is empty.</returns>
        /// <exception cref="ArgumentNullException">No dataset argument was supplied.</exception>
        /// <remarks>
        /// The first Ctrl+C cancels the token passed to the trainer and keeps the process
        /// alive so training can stop cleanly; a second Ctrl+C terminates the process.
        /// </remarks>
        public override int Run(string[] remainingArguments)
        {
            this.CheckRequiredArguments();
            if (remainingArguments is null || remainingArguments.Length < 1)
            {
                throw new ArgumentNullException("dataset", "Path to the dataset is required.");
            }
            string datasetName = remainingArguments[0];
            string checkpoint  = Gpt2Checkpoints.ProcessCheckpointConfig(
                gpt2Root: Environment.CurrentDirectory,
                checkpoint: this.Checkpoint,
                modelName: this.ModelName,
                runName: this.RunName);

            var    encoder       = Gpt2Encoder.LoadEncoder(this.ModelName);
            string searchPattern = this.Include ?? "*";
            // Ordinal: the pattern is a file-system glob, not linguistic text.
            var    dataset       = searchPattern.EndsWith("*.csv", StringComparison.Ordinal)
                ? LoadCsv(encoder, root: datasetName, field: this.ColumnName)
                : Gpt2Dataset.LoadDataset(encoder, path: datasetName, pattern: searchPattern);

            if (dataset.Count == 0)
            {
                Console.Error.WriteLine("The dataset is empty!");
                return(-1);
            }
            var hParams = Gpt2Model.LoadHParams(this.ModelName);
            var random  = this.Seed is null ? new Random() : new Random(this.Seed.Value);

            dynamic config = config_pb2.ConfigProto.CreateInstance();
            config.gpu_options.allow_growth = true;

            using (var stop = new CancellationTokenSource())
            {
                void OnCancelKeyPress(object sender, ConsoleCancelEventArgs args)
                {
                    // Without args.Cancel = true the runtime kills the process as soon as
                    // this handler returns, so the trainer would never get to react to
                    // stop.Token. Only the first press is intercepted; once cancellation
                    // has been requested, Ctrl+C terminates the process as usual.
                    args.Cancel = !stop.IsCancellationRequested;
                    stop.Cancel();
                }

                Console.CancelKeyPress += OnCancelKeyPress;
                try
                {
                    new Gpt2Trainer(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random)
                    {
                        SaveEvery   = this.SaveEvery,
                        SampleNum   = this.SampleNum,
                        SampleEvery = this.SampleEvery,
                    }
                    .Train(checkpoint, this.RunName,
                           sessionConfig: config,
                           // A "fresh" checkpoint starts the counter at 1; otherwise null is passed.
                           counter: checkpoint == "fresh" ? 1 : (int?)null,
                           cancellation: stop.Token);
                }
                finally
                {
                    // CancelKeyPress is a static event; detach so this handler does not
                    // outlive the command (or touch the disposed token source).
                    Console.CancelKeyPress -= OnCancelKeyPress;
                }
            }

            return(0);
        }
        /// <summary>
        /// Loads token chunks from <paramref name="path"/>. If it is a directory, every file
        /// under it (recursively) matching <paramref name="pattern"/> is loaded; otherwise
        /// <paramref name="path"/> is treated as a single file name.
        /// </summary>
        /// <param name="encoder">Encoder used for non-pre-encoded files.</param>
        /// <param name="path">Directory or file to load.</param>
        /// <param name="pattern">File-name search pattern used when <paramref name="path"/> is a directory.</param>
        /// <exception cref="ArgumentNullException"><paramref name="path"/> is null or empty.</exception>
        internal static List <ndarray> LoadDataset(Gpt2Encoder encoder, string path, string pattern = "*")
        {
            if (string.IsNullOrEmpty(path))
            {
                throw new ArgumentNullException(nameof(path));
            }

            List<string> files = Directory.Exists(path)
                ? new List<string>(Directory.EnumerateFiles(path, pattern, SearchOption.AllDirectories))
                : new List<string> { path };

            return LoadDataset(encoder, files);
        }
        /// <summary>
        /// Encodes <paramref name="texts"/> into a dataset, packing consecutive non-blank
        /// texts (each followed by <c>encoder.EncodedEndOfText</c>) into chunks. The current
        /// chunk is closed before adding a text that would bring its budget to
        /// <c>TrimAfter</c> or more; that text then starts the next chunk.
        /// </summary>
        /// <remarks>
        /// The budget counts source characters (<c>text.Length</c> plus the length of the
        /// end-of-text marker), not encoded tokens. A single text larger than the budget
        /// still becomes a chunk of its own.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="encoder"/> or <paramref name="texts"/> is null.</exception>
        static DataSet Load(Gpt2Encoder encoder, IEnumerable <string> texts)
        {
            if (encoder is null)
            {
                throw new ArgumentNullException(nameof(encoder));
            }
            if (texts is null)
            {
                throw new ArgumentNullException(nameof(texts));
            }

            var    result           = new DataSet();
            string encodedEndOfText = encoder.EncodedEndOfText;
            var    chunk            = new List <string>();
            int    chunkSize        = 0;

            // Stacks the pending tokens into one array, appends it, and starts an empty chunk.
            void AddChunk()
            {
                var tokens = np.stack(chunk);

                chunk.Clear();
                chunkSize = 0;
                result.Add(tokens);
            }

            foreach (string text in texts)
            {
                if (string.IsNullOrWhiteSpace(text))
                {
                    continue;
                }

                int entrySize = text.Length + encodedEndOfText.Length;

                // Flush first when this text would overflow the current chunk; the text is
                // then added to the new chunk below instead of being dropped. Never flush an
                // empty chunk (stacking an empty sequence presumably fails, as NumPy's does).
                if (chunk.Count > 0 && chunkSize + entrySize >= TrimAfter)
                {
                    AddChunk();
                }

                chunkSize += entrySize;
                var encoded = encoder.Encode(text);
                chunk.AddRange(encoded);
                chunk.Add(encodedEndOfText);
            }
            if (chunk.Count > 0)
            {
                AddChunk();
            }
            return(result);
        }
        /// <summary>
        /// Loads token chunks from the given files. Files with a <c>.npz</c> extension
        /// (any case) are treated as pre-encoded and contribute one chunk per array they
        /// contain; every other file is read as text and encoded into a single chunk.
        /// Text files that are empty or whitespace-only are skipped.
        /// </summary>
        /// <param name="encoder">Encoder used for text files.</param>
        /// <param name="fileNames">Files to load, in order.</param>
        /// <returns>The token chunks, in file order.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="encoder"/> or <paramref name="fileNames"/> is null.</exception>
        internal static List <ndarray> LoadDataset(Gpt2Encoder encoder, List <string> fileNames)
        {
            if (encoder is null)
            {
                throw new ArgumentNullException(nameof(encoder));
            }
            if (fileNames is null)
            {
                throw new ArgumentNullException(nameof(fileNames));
            }

            var tokenChunks = new List <ndarray>();

            foreach (string file in fileNames)
            {
                Debug.WriteLine($"Reading {file}");
                // Case-insensitive so "data.NPZ" is not mistaken for a text file.
                if (string.Equals(Path.GetExtension(file), ".npz", StringComparison.OrdinalIgnoreCase))
                {
                    // pre-encoded
                    dynamic npzObject = np.load(file);
                    var     npz       = npzObject.__enter__();
                    try
                    {
                        foreach (var item in npz.files)
                        {
                            tokenChunks.Add(npz[item]);
                        }
                    }
                    finally
                    {
                        // Release the archive even if reading an entry throws.
                        // NOTE(review): Python's context-manager protocol defines
                        // __exit__(exc_type, exc_value, traceback); confirm the interop
                        // layer accepts this zero-argument call.
                        npzObject.__exit__();
                    }
                }
                else
                {
                    string rawText = File.ReadAllText(file);
                    if (string.IsNullOrWhiteSpace(rawText))
                    {
                        continue;
                    }
                    var tokens = np.stack(encoder.Encode(rawText));
                    tokenChunks.Add(tokens);
                }
            }

            return(tokenChunks);
        }
        /// <summary>
        /// Interactively run the model: repeatedly reads a prompt from the console and
        /// prints <paramref name="sampleCount"/> generated continuations for it, until
        /// the user presses Ctrl+C.
        /// </summary>
        /// <param name="modelName">Which model to use</param>
        /// <param name="checkpoint">Which checkpoint to load; when null, the latest checkpoint
        ///     under <c>models/{modelName}</c> is used</param>
        /// <param name="seed">Seed for random number generators, fix seed to reproduce results</param>
        /// <param name="sampleCount">Number of samples to generate for each prompt</param>
        /// <param name="batchSize">Number of samples generated per model run (only affects
        ///     speed/memory). Must be positive and divide <paramref name="sampleCount"/>.</param>
        /// <param name="length">Number of tokens in generated text, if null (default), is
        ///     determined by model hyperparameters (<c>n_ctx</c>)</param>
        /// <param name="temperature">randomness in boltzmann distribution.
        ///     Lower temperature results in less random completions. As the
        ///     temperature approaches zero, the model will become deterministic and
        ///     repetitive. Higher temperature results in more random completions.</param>
        /// <param name="topK">Controls diversity. 1 means only 1 word is
        ///     considered for each step (token), resulting in deterministic completions,
        ///     while 40 means 40 words are considered at each step. 0 (default) is a
        ///     special setting meaning no restrictions. 40 generally is a good value.
        /// </param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is not
        ///     positive, or the model fails to generate a sequence of the requested length.</exception>
        /// <exception cref="ArgumentException"><paramref name="sampleCount"/> is not a multiple of
        ///     <paramref name="batchSize"/>, or <paramref name="length"/> exceeds <c>n_ctx</c>.</exception>
        public static void Run(string modelName = "117M", string checkpoint = null, int?seed = null,
                               int sampleCount  = 1,
                               int batchSize    = 1, int?length = null, float temperature = 1, int topK = 0)
        {
            // Checked first: a zero batch size would otherwise surface as DivideByZeroException below.
            if (batchSize <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize,
                                                      "Batch size must be positive.");
            }
            if (sampleCount % batchSize != 0)
            {
                throw new ArgumentException(
                    $"Sample count ({sampleCount}) must be a multiple of batch size ({batchSize}).",
                    nameof(sampleCount));
            }

            var encoder = Gpt2Encoder.LoadEncoder(modelName);
            var hParams = Gpt2Model.LoadHParams(modelName);

            int nCtx = ((dynamic)hParams).n_ctx;

            if (length is null)
            {
                length = nCtx;
            }
            else if (length > nCtx)
            {
                throw new ArgumentException("Can't get samples longer than window size: " + nCtx,
                                            nameof(length));
            }

            new Session(graph: new Graph()).UseSelf(sess => {
                var context = tf.placeholder(tf.int32, new TensorShape(batchSize, null));
                tf.set_random_seed(seed);

                var output = Gpt2Sampler.SampleSequence(
                    hParams: hParams,
                    length: length.Value,
                    context: context,
                    batchSize: batchSize,
                    temperature: temperature,
                    topK: topK);

                var saver  = new Saver();
                checkpoint = checkpoint ?? tf.train.latest_checkpoint(Path.Combine("models", modelName));
                saver.restore(sess, checkpoint);

                // Set from the Ctrl+C handler (another thread), hence the Volatile accesses.
                bool interrupted = false;

                // Ctrl+C ends the prompt loop instead of terminating the process.
                void OnCancelKeyPress(object sender, ConsoleCancelEventArgs args) =>
                    Volatile.Write(ref interrupted, args.Cancel = true);

                Console.CancelKeyPress += OnCancelKeyPress;
                try
                {
                    while (!Volatile.Read(ref interrupted))
                    {
                        string text;
                        do
                        {
                            Console.Write("Model prompt >>> ");
                            text = Console.ReadLine();
                            if (Volatile.Read(ref interrupted))
                            {
                                break;
                            }
                            if (string.IsNullOrEmpty(text))
                            {
                                Console.WriteLine("Prompt should not be empty");
                            }
                        } while (!Volatile.Read(ref interrupted) && string.IsNullOrEmpty(text));

                        if (Volatile.Read(ref interrupted))
                        {
                            break;
                        }

                        var contextTokens = encoder.Encode(text);
                        // NOTE(review): this length check only runs when no GPU is available;
                        // on GPU an over-long prompt surfaces as InvalidArgumentError below.
                        if (!tf.test.is_gpu_available() && contextTokens.Count >= length.Value)
                        {
                            Console.Error.WriteLine();
                            Console.Error.WriteLine("Prompt is too long.");
                            Console.Error.WriteLine();
                            continue;
                        }

                        int generated = 0;
                        foreach (var _ in Enumerable.Range(0, sampleCount / batchSize))
                        {
                            ndarray <int> @out;
                            try {
                                // Same prompt for every row of the batch; keep only the new tokens.
                                @out = sess.run(output, feed_dict: new Dictionary <object, object> {
                                    [context] = Enumerable.Repeat(contextTokens, batchSize).ToArray(),
                                })[.., contextTokens.Count..];
                            } catch (InvalidArgumentError ex) {
                                throw new ArgumentOutOfRangeException(
                                    "Unable to generate sequence of desired length. "
                                    + "Try lowering length by passing -l (-sample-length) parameter. "
                                    + "Current length: " + length.Value,
                                    innerException: ex);
                            }

                            foreach (int i in Enumerable.Range(0, batchSize))
                            {
                                generated++;
                                var    part   = (ndarray <int>)@out[i];
                                string sample = encoder.Decode(part);
                                Console.WriteLine($"{Delimiter} SAMPLE {generated} {Delimiter}");
                                Console.WriteLine(sample);
                            }
                        }
                        Console.Write(Delimiter);
                        Console.WriteLine(Delimiter);
                    }
                }
                finally
                {
                    // CancelKeyPress is a static event; detach so the handler does not outlive this session.
                    Console.CancelKeyPress -= OnCancelKeyPress;
                }
            });
        }