/// <summary>
/// Trains <c>network</c> on MNIST for 6 epochs via <c>Model.fit_generator</c>, feeding the
/// training and test splits through 4-worker <c>MultiThreadedGenerator</c>s, and disposes
/// both generators when training finishes or fails.
/// </summary>
void train_network() {
    var mnist = new Datasets.MNIST();
    // Integer division: a trailing partial batch is not counted as a step.
    var trainSteps = mnist.train_images.Length / batch_size;
    var validationSteps = mnist.test_images.Length / batch_size;

    MultiThreadedGenerator trainGenerator = null;
    MultiThreadedGenerator validationGenerator = null;
    // try/finally so the generators (and whatever worker resources they hold) are released
    // even if building the validation generator or training itself throws.
    try {
        var wrapped_trainGenerator = new InMemoryMiniBatchGenerator(
            new float[][][] { mnist.train_images, mnist.train_labels },
            new C.Variable[] { imageVariable, categoricalVariable },
            batch_size, true, false, "training");
        trainGenerator = new MultiThreadedGenerator(4, wrapped_trainGenerator);

        // NOTE(review): the positional booleans are presumably (shuffle: true, only_once: false);
        // shuffling the validation split is usually unnecessary — confirm this is intended.
        var wrapped_testGenerator = new InMemoryMiniBatchGenerator(
            new float[][][] { mnist.test_images, mnist.test_labels },
            new C.Variable[] { imageVariable, categoricalVariable },
            batch_size, true, false, "testing");
        validationGenerator = new MultiThreadedGenerator(4, wrapped_testGenerator);

        var epochs = 6;
        Model.fit_generator(
            network, learner, trainer, evaluator, batch_size, epochs,
            trainGenerator, trainSteps,
            validationGenerator, validationSteps,
            computeDevice, "mnist_");
    } finally {
        // Disposal order matches the success path: training generator first.
        if (trainGenerator != null) { trainGenerator.Dispose(); }
        if (validationGenerator != null) { validationGenerator.Dispose(); }
    }
}
/// <summary>
/// Interactive viewer. For each of the first 1000 MNIST training samples (fewer if the
/// dataset is smaller), it shows three resizable OpenCV windows:
/// the original digit, the transformed digit, and (when <c>network</c> is not null) the
/// network's prediction for the transformed digit. It then waits for a key press before
/// closing the windows and moving on.
/// </summary>
/// <param name="use_network">If true, <c>network</c> is first replaced by <c>load_best_network()</c>.</param>
/// <returns>Always <c>false</c>.</returns>
bool show_translations(bool use_network) {
    if (use_network) {
        network = load_best_network();
    }
    var mnist = new Datasets.MNIST();
    var image_size = new int[] { input_shape[0], input_shape[1] };
    var allData = new Data(mnist.train_images, image_size).All();

    // Never index past the end of the dataset.
    var sampleCount = Math.Min(1000, mnist.train_labels.Length);
    for (int index = 0; index < sampleCount; index++) {
        var img = allData[0][index];
        // Transformation parameters for this sample; the first two values name the window below.
        var T = allData[1][index];
        var img_t = allData[2][index];
        var digit = Util.argmax(mnist.train_labels[index]);

        // Mats wrap native OpenCV memory, so dispose them deterministically every iteration
        // rather than leaving up to 3000 of them for the finalizer.
        OpenCvSharp.Mat mat = null;
        OpenCvSharp.Mat mat_t = null;
        OpenCvSharp.Mat predicted_mat_t = null;
        try {
            mat = convertFloatRowTo2D(img, 28, 28);
            var window_name = $"mnist_{index}_Digit_{digit}";
            // WindowMode.Normal allows the window to be resized.
            OpenCvSharp.Cv2.NamedWindow(window_name, OpenCvSharp.WindowMode.Normal);
            OpenCvSharp.Cv2.ImShow(window_name, mat);

            mat_t = convertFloatRowTo2D(img_t, 28, 28);
            var window_name_T = $"{T[0]}_{T[1]}";
            OpenCvSharp.Cv2.NamedWindow(window_name_T, OpenCvSharp.WindowMode.Normal);
            OpenCvSharp.Cv2.ImShow(window_name_T, mat_t);

            if (network != null) {
                var prediction = Model.predict(network, new float[][] { img, T }, computeDevice);
                var predicted_img_t = prediction[0].ToArray();
                predicted_mat_t = convertFloatRowTo2D(predicted_img_t, 28, 28);
                var window_name_predicted_T = $"predicted_{T[0]}_{T[1]}";
                OpenCvSharp.Cv2.NamedWindow(window_name_predicted_T, OpenCvSharp.WindowMode.Normal);
                OpenCvSharp.Cv2.ImShow(window_name_predicted_T, predicted_mat_t);
            }

            OpenCvSharp.Cv2.WaitKey();
            OpenCvSharp.Cv2.DestroyAllWindows();
        } finally {
            // Runs only after WaitKey/DestroyAllWindows, so the images stay alive while they
            // are displayed (the role GC.KeepAlive played before).
            if (mat != null) { mat.Dispose(); }
            if (mat_t != null) { mat_t.Dispose(); }
            if (predicted_mat_t != null) { predicted_mat_t.Dispose(); }
        }
    }
    return false;
}
/// <summary>
/// Program entry point. It sets the console title, reports the compute device, builds the
/// network, and then trains it on the MNIST train/test images wrapped in <c>Data</c> objects.
/// </summary>
void run() {
    // Uncomment to preview the transformed digits (and stop) before training:
    //if ( show_translations(use_network: true)==false ) { return; }
    Console.Title = "Transforming AutoEncoders";
    Console.WriteLine("Using: " + computeDevice.AsString());
    create_network();

    var dataset = new Datasets.MNIST();
    var imageSize = new[] { input_shape[0], input_shape[1] };
    // Arguments are evaluated left to right: training data is wrapped before test data.
    train_network(
        new Data(dataset.train_images, imageSize),
        new Data(dataset.test_images, imageSize));
}
/// <summary>
/// Trains <c>network</c> on MNIST for 6 epochs via <c>Model.fit_generator</c>.
/// Training batches are shuffled and validation batches are not. Each split is fed
/// through a 4-worker <c>MultiThreadedGenerator</c>, and both generators are disposed
/// when training finishes or fails.
/// </summary>
void train_network() {
    var mnist = new Datasets.MNIST();
    // Integer division: a trailing partial batch is not counted as a step.
    var trainSteps = mnist.train_images.Length / batch_size;
    var validationSteps = mnist.test_images.Length / batch_size;

    MultiThreadedGenerator mtTrainGenerator = null;
    MultiThreadedGenerator mtValidationGenerator = null;
    // try/finally so the multi-threaded generators are released even if building the
    // validation generator or training itself throws.
    try {
        var trainGenerator = new InMemoryMiniBatchGenerator(
            new float[][][] { mnist.train_images, mnist.train_labels },
            new C.Variable[] { imageVariable, categoricalLabel },
            batch_size, shuffle: true, only_once: false, name: "train");
        mtTrainGenerator = new MultiThreadedGenerator(workers: 4, generator: trainGenerator);

        var validationGenerator = new InMemoryMiniBatchGenerator(
            new float[][][] { mnist.test_images, mnist.test_labels },
            new C.Variable[] { imageVariable, categoricalLabel },
            batch_size, shuffle: false, only_once: false, name: "validation");
        mtValidationGenerator = new MultiThreadedGenerator(workers: 4, generator: validationGenerator);

        var epochs = 6;
        Model.fit_generator(
            network, learner, trainer, evaluator, batch_size, epochs,
            mtTrainGenerator, trainSteps,
            mtValidationGenerator, validationSteps,
            computeDevice,
            prefix: "capsules_",
            trainingLossMetricName: "Training Loss",
            trainingEvaluationMetricName: "Training Error",
            validationMetricName: "Validation Error");
    } finally {
        // Disposal order matches the success path: training generator first.
        if (mtTrainGenerator != null) { mtTrainGenerator.Dispose(); }
        if (mtValidationGenerator != null) { mtValidationGenerator.Dispose(); }
    }
}