Exemplo n.º 1
0
        /// <summary>
        /// Builds the "MNIST_CNN_v1" convolutional network (cross-entropy loss),
        /// wires it into a <c>Context.Context</c> together with an MNIST evaluator and
        /// classifier, prepares an MNIST data view, and registers two training
        /// processes: one using <c>AdaGradOptimizer</c>, then one using
        /// <c>ExponentialDelayOptimizer</c>.
        /// </summary>
        /// <remarks>
        /// TODO: the call that would actually start training is still commented out at
        /// the end of this method, so at present it only configures the context and
        /// does not train the model.
        /// </remarks>
        public static void TrainModel()
        {
            // Model definition.
            // The tuple arguments are per-layer input sizes. The chain
            // 28x28 -> 24x24 -> 8x8 -> 4x4 -> 16 channels x 2x2 = 64 is consistent with each
            // 5-wide narrow convolution shrinking a side by 4 and each pooling window of
            // size k dividing it by k, matching the 64 inputs of the first fully connected layer.
            // NOTE(review): the meaning of the remaining positional arguments (bool flags,
            // kernel/stride/channel counts) is inferred from their values only — confirm
            // against the layer constructors.
            Model mnistCnn = new Model(new CrossEntropyLoss(), "MNIST_CNN_v1",
                                       new ConvolutionalLayer(new Direct(1), false, new Tuple<int, int>(28, 28), 5, 1, 6, ConvolutionMode.Narrow, new Normal(0, 1)),
                                       new PoolingLayer(3, false, 6, new Tuple<int, int>(24, 24), PoolingMode.Max),
                                       new ConvolutionalLayer(new Direct(1), false, new Tuple<int, int>(8, 8), 5, 6, 16, ConvolutionMode.Narrow, new Normal(0, 1)),
                                       new PoolingLayer(2, true, 16, new Tuple<int, int>(4, 4), PoolingMode.Average),
                                       new FullConnectLayer(new Direct(1), 64, 64, true, new Normal(0, 1)),
                                       new FullConnectLayer(new Sigmoid(), 64, 32, false, new Normal(0, 2)),
                                       new SoftMaxLayer(32, 10, new Normal(0, 2))
                                       );

            // Classifier: one instance is shared by the evaluator, the data converters and
            // the context, so evaluation and training can never end up with differently
            // configured classifiers.
            IClassifier classifier = new MNISTClassifier();

            // Evaluator.
            MNISTEvaluater evaluater = new MNISTEvaluater(new MNISTDataConverter(classifier));

            // Training environment.
            Context.Context context = new Context.Context(mnistCnn, evaluater, classifier);

            // Data set preparation.
            //string trainDataPath = System.IO.Path.Combine(Application.StartupPath, "MNIST", "train-images.idx3-ubyte");
            //string trainLabelPath = System.IO.Path.Combine(Application.StartupPath, "MNIST", "train-labels.idx1-ubyte");
            //string testDataPath = System.IO.Path.Combine(Application.StartupPath, "MNIST", "t10k-images.idx3-ubyte");
            //string testLabelPath = System.IO.Path.Combine(Application.StartupPath, "MNIST", "t10k-labels.idx1-ubyte");
            var dataConverter = new MNISTDataConverter(classifier);

            // Not consumed yet; intended as the data source for the pending Train call below.
            var dataView = new MNISTDataView(dataConverter);

            // Two training stages, registered in order: AdaGrad first, then the
            // exponential-delay optimizer.
            // NOTE(review): the numeric Trainer arguments (1000 / 600, 0) and the trailing
            // AddProcess argument (10) are unexplained here — confirm their meaning
            // (iteration counts? batch sizes? repetitions?) and consider naming them.
            context.AddProcess(new Trainer(1000, false, 0, new AdaGradOptimizer(0.5)), 10);
            context.AddProcess(new Trainer(600, false, 0, new ExponentialDelayOptimizer(0.2, 0.9, 0.05)), 10);

            // TODO: start training once the remaining arguments are settled, e.g.
            //context.Train(dataView.trainDataSet,dataView.verifyDataSet,)
        }