Beispiel #1
0
        /// <summary>
        /// Loads MNIST, wraps the training and test sets in multi-threaded
        /// mini-batch generators and passes them to <c>Model.fit_generator</c>
        /// for 6 epochs.
        /// </summary>
        /// <remarks>
        /// Both multi-threaded generators are disposed in a <c>finally</c> block,
        /// so they are released even when generator construction or training throws.
        /// </remarks>
        void train_network()
        {
            var mnist = new Datasets.MNIST();

            // NOTE(review): if batch_size is an integer this division truncates, so a
            // trailing partial batch is not counted as a step -- confirm this is intended.
            var trainSteps      = mnist.train_images.Length / batch_size;
            var validationSteps = mnist.test_images.Length / batch_size;

            var epochs = 6;

            // Declared outside the try so that the finally block can see them.
            MultiThreadedGenerator trainGenerator      = null;
            MultiThreadedGenerator validationGenerator = null;

            try
            {
                var wrapped_trainGenerator = new InMemoryMiniBatchGenerator(
                    new float[][][] { mnist.train_images, mnist.train_labels },
                    new C.Variable[] { imageVariable, categoricalVariable },
                    batch_size, true, false, "training");
                trainGenerator = new MultiThreadedGenerator(4, wrapped_trainGenerator);

                var wrapped_testGenerator = new InMemoryMiniBatchGenerator(
                    new float[][][] { mnist.test_images, mnist.test_labels },
                    new C.Variable[] { imageVariable, categoricalVariable },
                    batch_size, true, false, "testing");
                validationGenerator = new MultiThreadedGenerator(4, wrapped_testGenerator);

                Model.fit_generator(
                    network,
                    learner,
                    trainer,
                    evaluator,
                    batch_size,
                    epochs,
                    trainGenerator, trainSteps,
                    validationGenerator, validationSteps,
                    computeDevice, "mnist_");
            }
            finally
            {
                // A generator is still null if its construction never completed.
                if (trainGenerator != null)
                {
                    trainGenerator.Dispose();
                }

                if (validationGenerator != null)
                {
                    validationGenerator.Dispose();
                }
            }
        }
Beispiel #2
0
        /// <summary>
        /// Loads MNIST, builds a shuffled training generator and an unshuffled
        /// validation generator, wraps each in a 4-worker
        /// <c>MultiThreadedGenerator</c>, and passes them to
        /// <c>Model.fit_generator</c> for 6 epochs with the "capsules_" prefix.
        /// </summary>
        /// <remarks>
        /// The multi-threaded wrappers are disposed in a <c>finally</c> block, so
        /// they are released even when generator construction or training throws.
        /// </remarks>
        void train_network()
        {
            var mnist = new Datasets.MNIST();

            // NOTE(review): if batch_size is an integer this division truncates, so a
            // trailing partial batch is not counted as a step -- confirm this is intended.
            var trainSteps      = mnist.train_images.Length / batch_size;
            var validationSteps = mnist.test_images.Length / batch_size;

            var epochs = 6;

            // Declared outside the try so that the finally block can see them.
            MultiThreadedGenerator mtTrainGenerator      = null;
            MultiThreadedGenerator mtValidationGenerator = null;

            try
            {
                var trainGenerator = new InMemoryMiniBatchGenerator(
                    new float[][][] { mnist.train_images, mnist.train_labels },
                    new C.Variable[] { imageVariable, categoricalLabel },
                    batch_size, shuffle: true, only_once: false, name: "train");
                mtTrainGenerator = new MultiThreadedGenerator(workers: 4, generator: trainGenerator);

                var validationGenerator = new InMemoryMiniBatchGenerator(
                    new float[][][] { mnist.test_images, mnist.test_labels },
                    new C.Variable[] { imageVariable, categoricalLabel },
                    batch_size, shuffle: false, only_once: false, name: "validation");
                mtValidationGenerator = new MultiThreadedGenerator(workers: 4, generator: validationGenerator);

                Model.fit_generator(
                    network,
                    learner,
                    trainer,
                    evaluator,
                    batch_size,
                    epochs,
                    mtTrainGenerator, trainSteps,
                    mtValidationGenerator, validationSteps,
                    computeDevice,
                    prefix: "capsules_",
                    trainingLossMetricName: "Training Loss",
                    trainingEvaluationMetricName: "Training Error",
                    validationMetricName: "Validation Error");
            }
            finally
            {
                // A wrapper is still null if its construction never completed.
                if (mtTrainGenerator != null)
                {
                    mtTrainGenerator.Dispose();
                }

                if (mtValidationGenerator != null)
                {
                    mtValidationGenerator.Dispose();
                }
            }
        }
Beispiel #3
0
        /// <summary>
        /// Wraps the supplied training and test data in multi-threaded mini-batch
        /// generators and passes them to <c>Model.fit_generator</c> with the
        /// "ta_" prefix.
        /// </summary>
        /// <param name="train_data">Data used to build the training generator and to compute the training step count.</param>
        /// <param name="test_data">Data used to build the validation generator and to compute the validation step count.</param>
        /// <remarks>
        /// Both multi-threaded generators are disposed in a <c>finally</c> block,
        /// so they are released even when generator construction or training throws.
        /// </remarks>
        void train_network(Data train_data, Data test_data)
        {
            // For a discussion on the MNIST training/validation sets,
            // see https://github.com/keras-team/keras/issues/1753
            var trainSteps      = train_data.numSteps(batch_size);
            var validationSteps = test_data.numSteps(batch_size);

            // Declared outside the try so that the finally block can see them.
            MultiThreadedGenerator trainGenerator      = null;
            MultiThreadedGenerator validationGenerator = null;

            try
            {
                var inner_trainGenerator = new InMemoryMiniBatchGenerator(
                    train_data.All(), allInputVariables(),
                    batch_size, true, false, "training");
                trainGenerator = new MultiThreadedGenerator(4, inner_trainGenerator);

                var inner_validationGenerator = new InMemoryMiniBatchGenerator(
                    test_data.All(), allInputVariables(),
                    batch_size, true, false, "validation");
                validationGenerator = new MultiThreadedGenerator(4, inner_validationGenerator);

                Model.fit_generator(network, learner, trainer, evaluator, batch_size, epochs,
                                    trainGenerator, trainSteps, validationGenerator, validationSteps, computeDevice, prefix: "ta_");
            }
            finally
            {
                // A generator is still null if its construction never completed.
                if (trainGenerator != null)
                {
                    trainGenerator.Dispose();
                }

                if (validationGenerator != null)
                {
                    validationGenerator.Dispose();
                }
            }
        }