/// <summary>
/// Initializes the state shared by every runner: the compute device, the training
/// session configuration and the message printer.
/// </summary>
/// <param name="device">Device descriptor stored in <c>Device</c>; must not be null.</param>
/// <param name="configuration">Training session configuration stored in <c>Configuration</c>; must not be null.</param>
/// <param name="printer">Message printer stored in <c>MessagePrinter</c>; must not be null.</param>
/// <exception cref="ArgumentNullException">
/// Thrown for the first null argument, checked in the order device, configuration, printer.
/// </exception>
protected DeepLearningRunner(DeviceDescriptor device, TrainingSessionConfiguration configuration, IMessagePrinter printer)
{
    if (device is null)
    {
        throw new ArgumentNullException(nameof(device));
    }
    Device = device;

    // TODO: Use more detailed validation, i.e. check all required props/fields for the runner to work.
    if (configuration is null)
    {
        throw new ArgumentNullException(nameof(configuration));
    }
    Configuration = configuration;

    if (printer is null)
    {
        throw new ArgumentNullException(nameof(printer));
    }
    MessagePrinter = printer;
}
示例#2
0
        /// <summary>
        /// Trains a convolutional network on the EMNIST dataset selected by <paramref name="choice"/>.
        /// An unrecognized choice is reported through <c>SharedConsoleCommands.InvalidCommand</c>
        /// and the method returns without starting a session.
        /// </summary>
        /// <param name="choice">
        /// Menu key; matched against LETTERS_CHOICE, DIGITS_CHOICE and UPPERCASE_LETTERS_CHOICE.
        /// </param>
        private static void RunEmnistTraining(string choice)
        {
            ITrainingDatasetDefinition dataset = SelectDataset(choice);
            if (dataset is null)
            {
                SharedConsoleCommands.InvalidCommand(choice);
                return;
            }

            TrainingSessionStart(choice);
            var printer = new ConsolePrinter();

            // Every persistence location points at a folder (relative to the working directory)
            // named after the local start time, so separate runs do not share an output folder.
            DateTime startedAt = DateTime.Now;
            string snapshotDirectory = "./" + startedAt.ToString("yyyyMMddHHmmss", CultureInfo.InvariantCulture) + "/";

            var gpu = DeviceDescriptor.GPUDevice(0);
            var sessionConfig = new TrainingSessionConfiguration
            {
                Epochs                     = 200,
                DumpModelSnapshotPerEpoch  = true,
                ProgressEvaluationSeverity = EvaluationSeverity.PerEpoch,
                MinibatchConfig            = new MinibatchConfiguration
                {
                    MinibatchSize                      = 64,
                    // NOTE(review): 60000 / 32 (= 1875) does not match MinibatchSize (64), and 60000
                    // looks like the MNIST training-set size. DumpModelSnapshotPerMinibatch is false,
                    // so this may be unused; confirm the intended value.
                    HowManyMinibatchesPerSnapshot      = 60000 / 32,
                    HowManyMinibatchesPerProgressPrint = 500,
                    DumpModelSnapshotPerMinibatch      = false,
                    AsyncMinibatchSnapshot             = false
                },
                PersistenceConfig = TrainingModelPersistenceConfiguration.CreateWithAllLocationsSetTo(snapshotDirectory)
            };

            printer.PrintMessage("\n" + sessionConfig + "\n");

            // A using block (not a using declaration) keeps the runner disposed
            // before the completion message below is emitted.
            using (var cnnRunner = new ConvolutionalNeuralNetworkRunner(gpu, sessionConfig, printer))
            {
                cnnRunner.RunUsing(dataset);
            }

            EmnistTrainingDone(choice);

            // Maps a menu key to its dataset definition; null means the key is not recognized.
            ITrainingDatasetDefinition SelectDataset(string key)
            {
                switch (key)
                {
                    case LETTERS_CHOICE:           return new EMNISTLetterDataset();
                    case DIGITS_CHOICE:            return new EMNISTDigitDataset();
                    case UPPERCASE_LETTERS_CHOICE: return new EMNISTUppercaseLetterDataset();
                    default:                       return null;
                }
            }
        }
 /// <summary>
 /// Creates a convolutional neural network runner. All three arguments are forwarded
 /// unchanged to the base-class constructor (presumably <c>DeepLearningRunner</c>; confirm
 /// against the class declaration), which is where any argument validation happens.
 /// </summary>
 /// <param name="device">Device descriptor passed to the base constructor.</param>
 /// <param name="configuration">Training session configuration passed to the base constructor.</param>
 /// <param name="printer">Message printer passed to the base constructor.</param>
 public ConvolutionalNeuralNetworkRunner(
     DeviceDescriptor device,
     TrainingSessionConfiguration configuration,
     IMessagePrinter printer)
     : base(device, configuration, printer)
 {
     // Intentionally empty: nothing beyond the base constructor call happens in this body.
 }