Esempio n. 1
0
        public static Mnist Load()
        {
            var x = new Mnist();

            x.ReadDataSets(@"D:\人工智能\C#代码\MNISTTensorFlowSharp\MNISTTensorFlowSharp\data");
            return(x);
        }
Esempio n. 2
0
        static void KNN()
        {
            //取得数据
            var mnist = Mnist.Load();

            //拿5000个训练数据,200个测试数据
            const int trainCount = 5000;
            const int testCount  = 200;

            //获得的数据有两个
            //一个是图片,它们都是28*28的
            //一个是one-hot的标签,它们都是1*10的
            (var trainingImages, var trainingLabels) = mnist.GetTrainReader().NextBatch(trainCount);
            (var testImages, var testLabels)         = mnist.GetTestReader().NextBatch(testCount);

            Console.WriteLine($"MNIST 1NN");

            //建立一个图表示计算任务
            using (var graph = new TFGraph())
            {
                var session = new TFSession(graph);

                //用来feed数据的占位符。trainingInput表示N张用来进行训练的图片,N是一个变量,所以这里使用-1
                TFOutput trainingInput = graph.Placeholder(TFDataType.Float, new TFShape(-1, 784));

                //xte表示一张用来测试的图片
                TFOutput xte = graph.Placeholder(TFDataType.Float, new TFShape(784));

                //计算这两张图片的L1距离。这很简单,实际上就是把784个数字逐对相减,然后取绝对值,最后加起来变成一个总和
                var distance = graph.ReduceSum(graph.Abs(graph.Sub(trainingInput, xte)), axis: graph.Const(1));

                //这里只是用了最近的那个数据
                //也就是说,最近的那个数据是什么,那pred(预测值)就是什么
                TFOutput pred = graph.ArgMin(distance, graph.Const(0));

                var accuracy = 0f;

                //开始循环进行计算,循环trainCount次
                for (int i = 0; i < testCount; i++)
                {
                    var runner = session.GetRunner();

                    //每次,对一张新的测试图,计算它和trainCount张训练图的距离,并获得最近的那张
                    var result = runner.Fetch(pred).Fetch(distance)
                                 //trainCount张训练图(数据是trainingImages)
                                 .AddInput(trainingInput, trainingImages)
                                 //testCount张测试图(数据是从testImages中拿出来的)
                                 .AddInput(xte, Extract(testImages, i))
                                 .Run();

                    //最近的点的序号
                    var nn_index = (int)(long)result[0].GetValue();

                    //从trainingLabels中找到答案(这是预测值)
                    var prediction = ArgMax(trainingLabels, nn_index);

                    //正确答案位于testLabels[i]中
                    var real = ArgMax(testLabels, i);

                    //PrintImage(testImages, i);

                    Console.WriteLine($"测试 {i}: " +
                                      $"预测: {prediction} " +
                                      $"正确答案: {real} (最近的点的序号={nn_index})");
                    //Console.WriteLine(testImages);

                    if (prediction == real)
                    {
                        accuracy += 1f / testCount;
                    }
                }
                Console.WriteLine("准确率: " + accuracy);

                session.CloseSession();
            }
        }