public override int Run(string[] remainingArguments)
        {
            this.CheckRequiredArguments();
            if (remainingArguments.Length < 1)
            {
                throw new ArgumentNullException("dataset");
            }
            string datasetName = remainingArguments[0];
            string checkpoint  = Gpt2Trainer.ProcessCheckpointConfig(this.Checkpoint,
                                                                     modelName: this.ModelName,
                                                                     runName: this.RunName);

            var    encoder       = Gpt2Encoder.LoadEncoder(this.ModelName);
            string searchPattern = this.Include ?? "*";
            var    dataset       = searchPattern.EndsWith("*.csv")
                ? LoadCsv(encoder, root: datasetName, field: this.ColumnName)
                : Gpt2Trainer.LoadDataset(encoder, path: datasetName, pattern: searchPattern);
            var hParams = Gpt2Model.LoadHParams(this.ModelName);
            var random  = this.Seed == null ? new Random() : new Random(this.Seed.Value);
            var stop    = new CancellationTokenSource();

            Console.CancelKeyPress += delegate { stop.Cancel(); };
            new Gpt2Trainer(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random)
            .Train(checkpoint, this.RunName, stop.Token);

            return 0;
        }
        public override int Run(string[] remainingArguments)
        {
            this.CheckRequiredArguments();
            if (remainingArguments.Length < 1)
            {
                throw new ArgumentNullException("dataset");
            }
            string datasetName = remainingArguments[0];
            string checkpoint  = Gpt2Checkpoints.ProcessCheckpointConfig(
                gpt2Root: Environment.CurrentDirectory,
                checkpoint: this.Checkpoint,
                modelName: this.ModelName,
                runName: this.RunName);

            var    encoder       = Gpt2Encoder.LoadEncoder(this.ModelName);
            string searchPattern = this.Include ?? "*";
            var    dataset       = searchPattern.EndsWith("*.csv")
                ? LoadCsv(encoder, root: datasetName, field: this.ColumnName)
                : Gpt2Dataset.LoadDataset(encoder, path: datasetName, pattern: searchPattern);

            if (dataset.Count == 0)
            {
                Console.Error.WriteLine("The dataset is empty!");
                return -1;
            }
            var hParams = Gpt2Model.LoadHParams(this.ModelName);
            var random  = this.Seed == null ? new Random() : new Random(this.Seed.Value);
            var stop    = new CancellationTokenSource();

            Console.CancelKeyPress += delegate { stop.Cancel(); };
            // let TensorFlow grow GPU memory on demand instead of reserving
            // all GPU memory upfront
            dynamic config = config_pb2.ConfigProto();
            config.gpu_options.allow_growth = true;
            new Gpt2Trainer(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random)
            {
                SaveEvery   = this.SaveEvery,
                SampleNum   = this.SampleNum,
                SampleEvery = this.SampleEvery,
            }
            .Train(checkpoint, this.RunName,
                   sessionConfig: config,
                   counter: checkpoint == "fresh" ? 1 : (int?)null,
                   cancellation: stop.Token);

            return 0;
        }
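
        // `LoadCsv` is called by both Run variants above but is not shown in this
        // listing. The following is a minimal sketch of what such a helper might
        // look like, assuming the dataset is a List<ndarray> of encoded token
        // chunks (the return type, the naive CSV parsing, and np.array over the
        // encoder output are illustrative assumptions, not the confirmed API):
        static List<ndarray> LoadCsv(Gpt2Encoder encoder, string root, string field)
        {
            var chunks = new List<ndarray>();
            foreach (string file in Directory.EnumerateFiles(root, "*.csv", SearchOption.AllDirectories))
            {
                string[] lines      = File.ReadAllLines(file);
                int      fieldIndex = Array.IndexOf(lines[0].Split(','), field);
                foreach (string line in lines.Skip(1))
                {
                    // naive split; a real implementation should use a proper CSV
                    // parser to handle quoted fields and embedded commas
                    string[] cells = line.Split(',');
                    if (fieldIndex >= cells.Length || string.IsNullOrWhiteSpace(cells[fieldIndex]))
                        continue;
                    chunks.Add(np.array(encoder.Encode(cells[fieldIndex])));
                }
            }
            return chunks;
        }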
        /// <summary>
        /// Interactively run the model
        /// </summary>
        /// <param name="modelName">Which model to use</param>
        /// <param name="checkpoint">Which checkpoint to load</param>
        /// <param name="seed">Seed for random number generators, fix seed to reproduce results</param>
        /// <param name="sampleCount">Number of samples to return total</param>
        /// <param name="batchSize">Number of batches (only affects speed/memory).  Must divide sampleCount.</param>
        /// <param name="length">Number of tokens in generated text, if null (default), is
        ///     determined by model hyperparameters</param>
        /// <param name="temperature">randomness in boltzmann distribution.
        ///     Lower temperature results in less random completions. As the
        ///     temperature approaches zero, the model will become deterministic and
        ///     repetitive. Higher temperature results in more random completions.</param>
        /// <param name="topK">Controls diversity. 1 means only 1 word is
        ///     considered for each step (token), resulting in deterministic completions,
        ///     while 40 means 40 words are considered at each step. 0 (default) is a
        ///     special setting meaning no restrictions. 40 generally is a good value.
        /// </param>
        public static void Run(string modelName = "117M", string checkpoint = null, int? seed = null,
                               int sampleCount = 1, int batchSize = 1, int? length = null,
                               float temperature = 1, int topK = 0)
        {
            if (sampleCount % batchSize != 0)
            {
                throw new ArgumentException($"{nameof(sampleCount)} must be a multiple of {nameof(batchSize)}");
            }

            var encoder = Gpt2Encoder.LoadEncoder(modelName);
            var hParams = Gpt2Model.LoadHParams(modelName);

            int nCtx = ((dynamic)hParams).n_ctx;

            if (length is null)
            {
                length = nCtx;
            }
            else if (length > nCtx)
            {
                throw new ArgumentException("Can't get samples longer than window size: " + hParams.get("n_ctx"));
            }

            new Session(graph: new Graph()).UseSelf(sess => {
                var context = tf.placeholder(tf.int32, new TensorShape(batchSize, null));
                tf.set_random_seed(seed);

                var output = Gpt2Sampler.SampleSequence(
                    hParams: hParams,
                    length: length.Value,
                    context: context,
                    batchSize: batchSize,
                    temperature: temperature,
                    topK: topK);

                var saver  = new Saver();
                checkpoint = checkpoint ?? tf.train.latest_checkpoint(Path.Combine("models", modelName));
                saver.restore(sess, checkpoint);

                while (true)
                {
                    string text;
                    do
                    {
                        Console.Write("Model prompt >>> ");
                        text = Console.ReadLine();
                        if (string.IsNullOrEmpty(text))
                        {
                            Console.WriteLine("Prompt should not be empty");
                        }
                    } while (string.IsNullOrEmpty(text));

                    var contextTokens = encoder.Encode(text);
                    int generated     = 0;
                    foreach (var _ in Enumerable.Range(0, sampleCount / batchSize))
                    {
                        var @out = sess.run(output, feed_dict: new PythonDict<object, object> {
                            [context] = Enumerable.Repeat(contextTokens, batchSize),
                        })[Range.All, Range.StartAt(contextTokens.Count)];
                        foreach (int i in Enumerable.Range(0, batchSize))
                        {
                            generated++;
                            ndarray part = @out[i];
                            text         = encoder.Decode(part);
                            Console.WriteLine($"{Delimiter} SAMPLE {generated} {Delimiter}");
                            Console.WriteLine(text);
                        }
                    }
                    Console.Write(Delimiter);
                    Console.WriteLine(Delimiter);
                }
            });
        }
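
        // A hypothetical invocation of the interactive Run above, as it might
        // appear in a console entry point (a sketch; the enclosing class is not
        // shown in this listing). Fixing `seed` makes sampling reproducible, and
        // sampleCount must be a multiple of batchSize:
        static void Main()
        {
            Run(modelName: "117M", seed: 42,
                sampleCount: 4, batchSize: 2,
                temperature: 0.8f, topK: 40);
        }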
        public void Train(string checkpoint, string run, int? counter, dynamic sessionConfig = null, CancellationToken cancellation = default)
        {
            Session sess = sessionConfig == null
                ? new Session()
                : Session.NewDyn(config: sessionConfig);

            sess.UseSelf(session => {
                var context   = tf.placeholder(tf.int32, new TensorShape(this.batchSize, null));
                var output    = Gpt2Model.Model(this.hParams, input: context);
                Tensor labels = context[Range.All, Range.StartAt(1)];
                Tensor logits = output["logits"][Range.All, Range.EndAt(new Index(1, fromEnd: true))];
                var loss      = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits_dyn(
                        labels: labels,
                        logits: logits));

                var sample = Gpt2Sampler.SampleSequence(
                    this.hParams,
                    length: this.sampleLength,
                    context: context,
                    batchSize: this.batchSize,
                    temperature: 1.0f,
                    topK: 40);

                var trainVars = tf.trainable_variables().Where((dynamic v) => v.name.Contains("model"));
                var optimizer = new AdamOptimizer(learning_rate: 0.0002).minimize(loss, var_list: trainVars);

                var saver = new Saver(
                    var_list: trainVars,
                    max_to_keep: 5,
                    keep_checkpoint_every_n_hours: 1);

                session.run(tf.global_variables_initializer());

                Console.WriteLine("Loading checkpoint " + checkpoint);
                saver.restore(session, checkpoint);

                Console.WriteLine("Loading dataset...");
                var sampler = new TrainingSampler(this.dataset, this.random);
                Console.WriteLine($"Dataset has {sampler.TokenCount} tokens");

                string counterFile = Path.Combine(Gpt2Checkpoints.CheckpointDir, run, "counter");
                if (counter == null && File.Exists(counterFile))
                {
                    counter = int.Parse(File.ReadAllText(counterFile), CultureInfo.InvariantCulture) + 1;
                }
                counter = counter ?? 1;

                string runCheckpointDir = Path.Combine(Gpt2Checkpoints.CheckpointDir, run);
                string runSampleDir     = Path.Combine(SampleDir, run);

                void Save()
                {
                    Directory.CreateDirectory(runCheckpointDir);
                    Console.WriteLine("Saving " + Path.Combine(runCheckpointDir, Invariant($"model-{counter}")));
                    saver.save(session,
                               Path.Combine(runCheckpointDir, "model"),
                               global_step: counter.Value);
                    File.WriteAllText(path: counterFile, contents: Invariant($"{counter}"));
                }

                void GenerateSamples()
                {
                    var contextTokens = np.array(new[] { this.encoder.EncodedEndOfText });
                    var allText       = new List<string>();
                    int index         = 0;
                    string text       = null;
                    while (index < this.SampleNum)
                    {
                        var @out = session.run(sample, feed_dict: new PythonDict<object, object> {
                            [context] = Enumerable.Repeat(contextTokens, this.batchSize),
                        });
                        foreach (int i in Enumerable.Range(0, Math.Min(this.SampleNum - index, this.batchSize)))
                        {
                            text = this.encoder.Decode(@out[i]);
                            text = Invariant($"======== SAMPLE {index + 1} ========\n{text}\n");
                            allText.Add(text);
                            index++;
                        }
                    }
                    Console.WriteLine(text);
                    Directory.CreateDirectory(runSampleDir);
                    File.WriteAllLines(
                        path: Path.Combine(runSampleDir, Invariant($"samples-{counter}")),
                        contents: allText);
                }

                var avgLoss   = (0.0, 0.0);
                var startTime = DateTime.Now;

                while (!cancellation.IsCancellationRequested)
                {
                    if (counter % this.SaveEvery == 0)
                    {
                        Save();
                    }
                    if (counter % this.SampleEvery == 0)
                    {
                        GenerateSamples();
                    }

                    var batch = Enumerable.Range(0, this.batchSize)
                                .Select(_ => sampler.Sample(1024))
                                .ToArray();

                    var placeholderValues = new PythonDict<object, object> {
                        [context] = batch,
                    };
                    var tuple = session.run_dyn((optimizer, loss), feed_dict: placeholderValues);

                    var lv = tuple.Item2;

                    // exponentially weighted moving average of the loss with bias
                    // correction: numerator and denominator both decay by 0.99,
                    // so early iterations are not biased toward zero
                    avgLoss = (avgLoss.Item1 * 0.99 + lv, avgLoss.Item2 * 0.99 + 1);

                    Console.WriteLine($"[{counter} | {DateTime.Now-startTime}] loss={lv} avg={avgLoss.Item1/avgLoss.Item2}");

                    counter++;
                }

                Console.WriteLine("Interrupted");
                Save();
            });
        }
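
        // `Train` receives an already-resolved checkpoint path. A sketch of how a
        // checkpoint argument such as "latest" or "fresh" might be resolved (an
        // illustrative assumption; the actual Gpt2Checkpoints.ProcessCheckpointConfig
        // logic is not shown in this listing):
        static string ResolveCheckpoint(string checkpoint, string modelName, string runName)
        {
            switch (checkpoint)
            {
            case "latest":
                // prefer the newest checkpoint of this run; fall back to the base model
                return tf.train.latest_checkpoint(Path.Combine(Gpt2Checkpoints.CheckpointDir, runName))
                       ?? tf.train.latest_checkpoint(Path.Combine("models", modelName));
            case "fresh":
                // start over from the pretrained base model weights
                return tf.train.latest_checkpoint(Path.Combine("models", modelName));
            default:
                // treat anything else as an explicit checkpoint path
                return checkpoint;
            }
        }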
        public static Tensor SampleSequence(HParams hParams, int length,
                                            string startToken = null, int? batchSize = null, dynamic context = null,
                                            float temperature = 1, int topK = 0)
        {
            if ((startToken == null) == (context == null))
            {
                throw new ArgumentException($"Exactly one of {nameof(startToken)} or {nameof(context)} has to be specified");
            }

            SortedDictionary<string, dynamic> Step(HParams @params, Tensor tokens, dynamic past = null)
            {
                var lmOutput = Gpt2Model.Model(hParams: @params, input: tokens, past: past, reuse: _ReuseMode.AUTO_REUSE);

                var    logits   = lmOutput["logits"][Range.All, Range.All, Range.EndAt((int)@params.get("n_vocab"))];
                Tensor presents = lmOutput["present"];

                int?[] pastShape = Gpt2Model.PastShape(hParams: @params, batchSize: batchSize);
                presents.set_shape_(pastShape.Cast<object>());

                return new SortedDictionary<string, dynamic>
                {
                    ["logits"] = logits,
                    ["presents"] = presents,
                };
            }

            Tensor result = null;

            new name_scope("sample_sequence").Use(_ =>
            {
                // Don't feed the last context token -- leave that to the loop below
                // TODO: Would be slightly faster if we called step on the entire context,
                // rather than leaving the last token transformer calculation to the while loop.
                var contextOutput = Step(hParams, context[Range.All, Range.EndAt(new Index(1, fromEnd: true))]);

                Tensor[] Body(object past, dynamic prev, object output)
                {
                    var nextOutputs = Step(hParams, prev[Range.All, tf.newaxis], past: past);
                    Tensor logits   = nextOutputs["logits"][Range.All, -1, Range.All] / tf.to_float(temperature);
                    logits          = TopLogits(logits, topK: topK);
                    var samples     = tf.multinomial_dyn(logits, num_samples: 1, output_dtype: tf.int32);
                    return new Tensor[]
                    {
                        tf.concat(new[] { past, nextOutputs["presents"] }, axis: -2),
                        tf.squeeze(samples, axis: new[] { 1 }),
                        tf.concat(new[] { output, samples }, axis: 1),
                    };
                }

                bool True(object _a, object _b, object _c) => true;

                dynamic[] loopVars = new[] {
                    contextOutput["presents"],
                    context[Range.All, -1],
                    context,
                };
                TensorShape[] shapeInvariants = new[] {
                    new TensorShape(Gpt2Model.PastShape(hParams: hParams, batchSize: batchSize)),
                    new TensorShape(batchSize),
                    new TensorShape((int?)batchSize, (int?)null),
                };
                result = tf.while_loop(
                    cond: PythonFunctionContainer.Of<object, object, object, bool>(True),
                    body: PythonFunctionContainer.Of(new Func<object, object, object, Tensor[]>(Body)),
                    parallel_iterations: 10,
                    swap_memory: false,
                    name: null,
                    maximum_iterations: tf.constant(length),
                    loop_vars: loopVars,
                    shape_invariants: shapeInvariants,
                    back_prop: false)
                         [2];
            });
            return result;
        }
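
        // `TopLogits` is used by SampleSequence above but not shown in this listing.
        // A minimal sketch of top-k filtering as in the original GPT-2 sampler,
        // written in the same binding style (an illustrative reconstruction, not
        // the confirmed implementation): keep the k largest logits per step and
        // push everything else to ~-infinity so softmax sampling never picks it.
        static Tensor TopLogits(Tensor logits, int topK)
        {
            if (topK == 0)
                return logits; // special value: no restriction

            var topValues = tf.nn.top_k(logits, k: topK)[0];
            // threshold = the smallest of the top-k values, broadcast over the vocab axis
            Tensor minValues = topValues[Range.All, -1, tf.newaxis];
            return tf.where(logits < minValues,
                            tf.ones_like(logits) * -1e10f,
                            logits);
        }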