/// <summary>
/// Loads the MNIST data set and trains <c>network</c> through
/// <c>Model.fit_generator</c>, feeding it mini-batches produced by
/// multi-threaded generators.
/// </summary>
/// <remarks>
/// Both multi-threaded generators are disposed in a <c>finally</c> block, so
/// they are released even when generator construction or training throws.
/// </remarks>
void train_network() {
  var mnist = new Datasets.MNIST();

  // Steps per epoch: number of samples divided by the batch size.
  // NOTE(review): if batch_size is an integer, a trailing partial batch is not counted.
  var trainSteps = mnist.train_images.Length / batch_size;
  var validationSteps = mnist.test_images.Length / batch_size;
  var epochs = 6;

  // Declared outside the try block so that the finally block can release
  // whichever generators were successfully created.
  MultiThreadedGenerator trainGenerator = null;
  MultiThreadedGenerator validationGenerator = null;
  try {
    var wrapped_trainGenerator = new InMemoryMiniBatchGenerator(
      new float[][][] { mnist.train_images, mnist.train_labels },
      new C.Variable[] { imageVariable, categoricalVariable },
      batch_size,
      true,
      false,
      "training");
    trainGenerator = new MultiThreadedGenerator(4, wrapped_trainGenerator);

    var wrapped_testGenerator = new InMemoryMiniBatchGenerator(
      new float[][][] { mnist.test_images, mnist.test_labels },
      new C.Variable[] { imageVariable, categoricalVariable },
      batch_size,
      true,
      false,
      "testing");
    validationGenerator = new MultiThreadedGenerator(4, wrapped_testGenerator);

    Model.fit_generator(
      network,
      learner,
      trainer,
      evaluator,
      batch_size,
      epochs,
      trainGenerator, trainSteps,
      validationGenerator, validationSteps,
      computeDevice,
      "mnist_");
  }
  finally {
    // Same disposal order as before: training generator first, then validation.
    if (trainGenerator != null) {
      trainGenerator.Dispose();
    }
    if (validationGenerator != null) {
      validationGenerator.Dispose();
    }
  }
}
/// <summary>
/// Loads the MNIST data set and trains <c>network</c> through
/// <c>Model.fit_generator</c>, using a shuffled training generator and a
/// non-shuffled validation generator, each wrapped in a
/// <c>MultiThreadedGenerator</c> with four workers.
/// </summary>
/// <remarks>
/// Both multi-threaded generators are disposed in a <c>finally</c> block, so
/// they are released even when generator construction or training throws.
/// </remarks>
void train_network() {
  var mnist = new Datasets.MNIST();

  // Steps per epoch: number of samples divided by the batch size.
  // NOTE(review): if batch_size is an integer, a trailing partial batch is not counted.
  var trainSteps = mnist.train_images.Length / batch_size;
  var validationSteps = mnist.test_images.Length / batch_size;
  var epochs = 6;

  // Declared outside the try block so that the finally block can release
  // whichever generators were successfully created.
  MultiThreadedGenerator mtTrainGenerator = null;
  MultiThreadedGenerator mtValidationGenerator = null;
  try {
    var trainGenerator = new InMemoryMiniBatchGenerator(
      new float[][][] { mnist.train_images, mnist.train_labels },
      new C.Variable[] { imageVariable, categoricalLabel },
      batch_size,
      shuffle: true,
      only_once: false,
      name: "train");
    mtTrainGenerator = new MultiThreadedGenerator(workers: 4, generator: trainGenerator);

    var validationGenerator = new InMemoryMiniBatchGenerator(
      new float[][][] { mnist.test_images, mnist.test_labels },
      new C.Variable[] { imageVariable, categoricalLabel },
      batch_size,
      shuffle: false,
      only_once: false,
      name: "validation");
    mtValidationGenerator = new MultiThreadedGenerator(workers: 4, generator: validationGenerator);

    Model.fit_generator(
      network,
      learner,
      trainer,
      evaluator,
      batch_size,
      epochs,
      mtTrainGenerator, trainSteps,
      mtValidationGenerator, validationSteps,
      computeDevice,
      prefix: "capsules_",
      trainingLossMetricName: "Training Loss",
      trainingEvaluationMetricName: "Training Error",
      validationMetricName: "Validation Error");
  }
  finally {
    // Same disposal order as before: training generator first, then validation.
    if (mtTrainGenerator != null) {
      mtTrainGenerator.Dispose();
    }
    if (mtValidationGenerator != null) {
      mtValidationGenerator.Dispose();
    }
  }
}
/// <summary>
/// Trains <c>network</c> through <c>Model.fit_generator</c>, drawing
/// training mini-batches from <paramref name="train_data"/> and validation
/// mini-batches from <paramref name="test_data"/>.
/// </summary>
/// <param name="train_data">Data set that feeds the training generator.</param>
/// <param name="test_data">Data set that feeds the validation generator.</param>
/// <remarks>
/// Both multi-threaded generators are disposed in a <c>finally</c> block, so
/// they are released even when generator construction or training throws.
/// </remarks>
void train_network(Data train_data, Data test_data) {
  // For a discussion on the MNIST training/validation sets,
  // see https://github.com/keras-team/keras/issues/1753
  var trainSteps = train_data.numSteps(batch_size);
  var validationSteps = test_data.numSteps(batch_size);

  // Declared outside the try block so that the finally block can release
  // whichever generators were successfully created.
  MultiThreadedGenerator trainGenerator = null;
  MultiThreadedGenerator validationGenerator = null;
  try {
    var inner_trainGenerator = new InMemoryMiniBatchGenerator(
      train_data.All(),
      allInputVariables(),
      batch_size,
      true,
      false,
      "training");
    trainGenerator = new MultiThreadedGenerator(4, inner_trainGenerator);

    var inner_validationGenerator = new InMemoryMiniBatchGenerator(
      test_data.All(),
      allInputVariables(),
      batch_size,
      true,
      false,
      "validation");
    validationGenerator = new MultiThreadedGenerator(4, inner_validationGenerator);

    Model.fit_generator(
      network,
      learner,
      trainer,
      evaluator,
      batch_size,
      epochs,
      trainGenerator, trainSteps,
      validationGenerator, validationSteps,
      computeDevice,
      prefix: "ta_");
  }
  finally {
    // Same disposal order as before: training generator first, then validation.
    if (trainGenerator != null) {
      trainGenerator.Dispose();
    }
    if (validationGenerator != null) {
      validationGenerator.Dispose();
    }
  }
}