Exemple #1
0
        /// <summary>
        /// Trains <c>network</c> for a fixed number of epochs on the MNIST training split,
        /// passing the MNIST test split to <c>Model.fit_generator</c> as validation data.
        /// </summary>
        /// <remarks>
        /// Each split is fed through an <c>InMemoryMiniBatchGenerator</c> wrapped in a
        /// <c>MultiThreadedGenerator</c> with 4 workers. The multi-threaded wrappers are disposed
        /// in <c>finally</c> blocks, so they are released even if building the second generator
        /// or training throws.
        /// </remarks>
        void train_network()
        {
            var mnist           = new Datasets.MNIST();
            // Integer division (assuming batch_size is an int): a trailing partial batch is not
            // counted as a step.
            var trainSteps      = mnist.train_images.Length / batch_size;
            var validationSteps = mnist.test_images.Length / batch_size;
            var epochs          = 6;

            // NOTE(review): the positional flags presumably mean shuffle: true, only_once: false
            // (that is how they are named in the sibling capsules example) — confirm.
            var wrapped_trainGenerator = new InMemoryMiniBatchGenerator(
                new float[][][] { mnist.train_images, mnist.train_labels },
                new C.Variable[] { imageVariable, categoricalVariable },
                batch_size, true, false, "training");
            var trainGenerator = new MultiThreadedGenerator(4, wrapped_trainGenerator);
            try
            {
                // NOTE(review): the validation source is also created with the first flag = true
                // (shuffled); the capsules example uses shuffle: false for validation.
                var wrapped_testGenerator = new InMemoryMiniBatchGenerator(
                    new float[][][] { mnist.test_images, mnist.test_labels },
                    new C.Variable[] { imageVariable, categoricalVariable },
                    batch_size, true, false, "testing");
                var validationGenerator = new MultiThreadedGenerator(4, wrapped_testGenerator);
                try
                {
                    Model.fit_generator(
                        network,
                        learner,
                        trainer,
                        evaluator,
                        batch_size,
                        epochs,
                        trainGenerator, trainSteps,
                        validationGenerator, validationSteps,
                        computeDevice, "mnist_");
                }
                finally
                {
                    validationGenerator.Dispose();
                }
            }
            finally
            {
                trainGenerator.Dispose();
            }
        }
Exemple #2
0
        /// <summary>
        /// Interactive viewer for the first 1000 MNIST training images. For each image it opens
        /// OpenCV windows showing the original digit, the transformed image from
        /// <c>Data.All()</c> (titled with the first two components of its parameters <c>T</c>), and,
        /// when <c>network</c> is non-null, the network's prediction for <c>(img, T)</c>.
        /// It then blocks on a key press and closes all windows before moving to the next image.
        /// </summary>
        /// <param name="use_network">
        /// When true, <c>network</c> is first replaced with <c>load_best_network()</c>; otherwise the
        /// current <c>network</c> field is used as-is (no prediction window if it is null).
        /// </param>
        /// <returns>Always <c>false</c>.</returns>
        bool show_translations(bool use_network)
        {
            if (use_network)
            {
                network = load_best_network();
            }
            var mnist      = new Datasets.MNIST();
            var image_size = new int[] { input_shape[0], input_shape[1] };
            // allData[0] = images, allData[1] = transformation parameters T, allData[2] = transformed
            // images (inferred from how they are used below).
            var allData    = new Data(mnist.train_images, image_size).All();

            for (int index = 0; index < 1000; index++)
            {
                var img   = allData[0][index];
                var T     = allData[1][index];
                var img_t = allData[2][index];

                var label = mnist.train_labels[index];
                var digit = Util.argmax(label);

                // OpenCvSharp Mats own native memory: dispose them every iteration instead of
                // leaving them for the finalizer. Starting at null also avoids allocating a
                // throw-away empty Mat when no prediction is shown.
                OpenCvSharp.Mat mat = null, mat_t = null, predicted_mat_t = null;
                try
                {
                    mat = convertFloatRowTo2D(img, 28, 28);
                    var window_name = $"mnist_{index}_Digit_{digit}";
                    // allow the window to be resized
                    OpenCvSharp.Cv2.NamedWindow(window_name, OpenCvSharp.WindowMode.Normal);
                    OpenCvSharp.Cv2.ImShow(window_name, mat);

                    mat_t = convertFloatRowTo2D(img_t, 28, 28);
                    var window_name_T = $"{T[0]}_{T[1]}";
                    OpenCvSharp.Cv2.NamedWindow(window_name_T, OpenCvSharp.WindowMode.Normal);
                    OpenCvSharp.Cv2.ImShow(window_name_T, mat_t);

                    if (network != null)
                    {
                        var prediction      = Model.predict(network, new float[][] { img, T }, computeDevice);
                        var predicted_img_t = prediction[0].ToArray();
                        predicted_mat_t = convertFloatRowTo2D(predicted_img_t, 28, 28);
                        var window_name_predicted_T = $"predicted_{T[0]}_{T[1]}";
                        OpenCvSharp.Cv2.NamedWindow(window_name_predicted_T, OpenCvSharp.WindowMode.Normal);
                        OpenCvSharp.Cv2.ImShow(window_name_predicted_T, predicted_mat_t);
                    }

                    OpenCvSharp.Cv2.WaitKey();
                    OpenCvSharp.Cv2.DestroyAllWindows();
                }
                finally
                {
                    // Released only after the windows showing them have been destroyed.
                    predicted_mat_t?.Dispose();
                    mat_t?.Dispose();
                    mat?.Dispose();
                }
            }
            return(false);
        }
Exemple #3
0
        /// <summary>
        /// Program driver for the transforming-autoencoder experiment. It sets the console title,
        /// reports the compute device, builds the network, wraps the MNIST train/test images in
        /// <c>Data</c> objects sized to the network input, and hands both to
        /// <c>train_network</c>.
        /// </summary>
        void run()
        {
            // Debug toggle: uncomment to browse the generated translations instead of training.
            //if ( show_translations(use_network: true)==false ) { return; }

            Console.Title = "Transforming AutoEncoders";
            Console.WriteLine($"Using: {computeDevice.AsString()}");

            create_network();

            var dataset   = new Datasets.MNIST();
            var imageSize = new int[] { input_shape[0], input_shape[1] };

            // Arguments are evaluated left to right: training data first, then test data.
            train_network(
                new Data(dataset.train_images, imageSize),
                new Data(dataset.test_images, imageSize));
        }
Exemple #4
0
        /// <summary>
        /// Trains <c>network</c> for a fixed number of epochs on the MNIST training split,
        /// passing the MNIST test split to <c>Model.fit_generator</c> as validation data, with
        /// the "capsules_" prefix and explicit metric display names.
        /// </summary>
        /// <remarks>
        /// Training batches are shuffled and validation batches are not. Both in-memory generators
        /// are wrapped in a <c>MultiThreadedGenerator</c> with 4 workers. The wrappers are disposed
        /// in <c>finally</c> blocks, so they are released even if building the second generator
        /// or training throws.
        /// </remarks>
        void train_network()
        {
            var mnist           = new Datasets.MNIST();
            // Integer division (assuming batch_size is an int): a trailing partial batch is not
            // counted as a step.
            var trainSteps      = mnist.train_images.Length / batch_size;
            var validationSteps = mnist.test_images.Length / batch_size;
            var epochs          = 6;

            var trainGenerator = new InMemoryMiniBatchGenerator(
                new float[][][] { mnist.train_images, mnist.train_labels },
                new C.Variable[] { imageVariable, categoricalLabel },
                batch_size, shuffle: true, only_once: false, name: "train");
            var mtTrainGenerator = new MultiThreadedGenerator(workers: 4, generator: trainGenerator);
            try
            {
                var validationGenerator = new InMemoryMiniBatchGenerator(
                    new float[][][] { mnist.test_images, mnist.test_labels },
                    new C.Variable[] { imageVariable, categoricalLabel },
                    batch_size, shuffle: false, only_once: false, name: "validation");
                var mtValidationGenerator = new MultiThreadedGenerator(workers: 4, generator: validationGenerator);
                try
                {
                    Model.fit_generator(
                        network,
                        learner,
                        trainer,
                        evaluator,
                        batch_size,
                        epochs,
                        mtTrainGenerator, trainSteps,
                        mtValidationGenerator, validationSteps,
                        computeDevice,
                        prefix: "capsules_",
                        trainingLossMetricName: "Training Loss",
                        trainingEvaluationMetricName: "Training Error",
                        validationMetricName: "Validation Error");
                }
                finally
                {
                    mtValidationGenerator.Dispose();
                }
            }
            finally
            {
                mtTrainGenerator.Dispose();
            }
        }