/// <summary>
/// Prints a seed excerpt taken from <c>di.text</c> and then, for each of four
/// sampling temperatures, the text window left after 400 sample-and-slide steps
/// driven by <paramref name="engine"/>.
/// </summary>
/// <param name="engine">Engine whose <c>evaluate</c> method is called with its <c>softmaxOutput</c> on every step.</param>
/// <param name="di">Supplies the source <c>text</c>, the <c>chars</c> table and the <c>char_indices</c> lookup.</param>
void generate_text(TextGeneratingTrainingEngine engine, DataInfo di) {
    const int stepsPerTemperature = 400;

    // Fixed seed: the same generator picks the excerpt and is handed to sample() below.
    var random = new Random(2018);

    // Choose where the maxlen-long seed excerpt starts.
    var startRange = di.text.Length - maxlen - 1;
    var start_index = (int)(random.NextDouble() * startRange);

    // Newlines are flattened to spaces so the seed prints on one line.
    var excerpt = di.text.Substring(start_index, maxlen);
    var seed_generated_text = excerpt.Replace('\n', ' ');
    Console.WriteLine($"\nSeed: {seed_generated_text}");

    foreach (var temperature in new[] { 0.2, 0.5, 1.0, 1.2 }) {
        // Every temperature starts again from the same seed window.
        var window = seed_generated_text;

        for (var step = 0; step < stepsPerTemperature; ++step) {
            // Encode the current window through the char -> index lookup.
            var encoded = new float[window.Length];
            for (var k = 0; k < window.Length; ++k) {
                encoded[k] = (float)(di.char_indices[window[k]]);
            }

            // Keep only the first di.chars.Length values of the first output row.
            var outputs = engine.evaluate(new float[][] { encoded }, engine.softmaxOutput);
            var preds = outputs[0].Take(di.chars.Length).ToArray();

            var pick = di.chars[sample(random, preds, temperature)];

            // Slide the window: drop the oldest character, append the new one
            // (a sampled newline is appended as a space).
            window = window.Substring(1) + (pick == '\n' ? ' ' : pick);
        }

        // NOTE(review): the window keeps a constant length, so only the last
        // maxlen characters are printed here; characters shifted out during the
        // 400 steps are never shown. Confirm this is intended.
        Console.WriteLine($"Randomly generated with temperature {temperature:F1}: {window}");
    }
}
/// <summary>
/// Builds the data set, configures a <c>TextGeneratingTrainingEngine</c>,
/// trains it, and then prints generated text via <c>generate_text</c>.
/// </summary>
void run() {
    var data = new DataInfo();

    // Configure the engine member by member, in the same order as before.
    var trainer = new TextGeneratingTrainingEngine();
    trainer.num_epochs = 32;
    trainer.batch_size = 128;
    trainer.sequence_length = maxlen;
    trainer.lossFunctionType = TrainingEngine.LossFunctionType.Custom;
    trainer.accuracyFunctionType = TrainingEngine.AccuracyFunctionType.SameAsLoss;
    trainer.metricType = TrainingEngine.MetricType.Loss;

    // The last two setData arguments are passed as null.
    trainer.setData(data.x, data.y, null, null);
    trainer.train();

    generate_text(trainer, data);
}