コード例 #1
0
ファイル: Program.cs プロジェクト: benketriel/sknn
        //Neural network trained to reproduce an image
        /// <summary>
        /// Trains a network to reproduce a fixed image and periodically writes the
        /// network's current reconstruction of that image to disk.
        /// </summary>
        /// <remarks>
        /// The training loop has no stopping criterion; it is meant to run until the
        /// process is terminated. Every 100 iterations the network is queried once per
        /// pixel with the pixel position scaled by the image width/height, the result is
        /// saved as <c>savePath(iter) + ".bmp"</c>, and the mean absolute per-channel
        /// error (on the 0-255 scale) is printed.
        /// </remarks>
        private static void LearnImage()
        {
            //Alternative source images:
            //Bitmap img = new Bitmap(@"C:\SK\Programming\VSProjects\Stuff\Files\icon_smile2.bmp");
            //Bitmap img = new Bitmap(@"C:\SK\Programming\VSProjects\Stuff\Files\twitter.png");

            // Bitmap wraps a native GDI+ handle, so release it deterministically if the method ever unwinds.
            using (var img = new Bitmap(@"C:\SK\Programming\VSProjects\Stuff\Files\chrome.png"))
            {
                // Presumably the arguments are layer sizes (2 inputs ... 3 outputs) - confirm against Network.
                //var ntwk = new Network(2, 1024*32, 3);
                //var ntwk = new Network(2, 50, 30, 15, 6, 3);
                //var ntwk = new Network(2, 1024, 1024, 1024, 3);
                var ntwk = new Network(2, 200, 100, 50, 25, 15, 10, 3);

                //Load existing network
                //NOTE: BinaryFormatter is obsolete and unsafe on untrusted data; pick another serializer before re-enabling.
                //if (File.Exists(ntwkPath))
                //{
                //    using (Stream stream = File.Open(ntwkPath, FileMode.Open))
                //    {
                //        var binaryFormatter = new System.Runtime.Serialization.Formatters.Binary.BinaryFormatter();
                //        ntwk = Serialization.DeSerialize((float[])binaryFormatter.Deserialize(stream));
                //    }
                //}

                //var cuda = new CudaGo(Enumerable.Range(0, ntwk.Layers.Length - 1).Max(i => ntwk.Layers[i].Length * ntwk.Layers[i + 1].Length));
                //cuda.PushToGpu(ntwk, img);

                // Materialize the training set once; it is reused unchanged on every iteration.
                var xx = EnumerateImgCoords(img).ToArray();
                var yy = EnumerateImgColors(img).ToArray();

                // Bitmap.Width/Height are property calls into GDI+; read them once.
                int width = img.Width;
                int height = img.Height;
                long channelCount = (long)width * height * 3;

                var time = new Stopwatch();
                time.Start();
                int iter = 0;
                while (++iter > 0)
                {
                    ntwk.Train(xx, yy, iter);

                    //cuda.Learn(alpha);

                    // Rewrite the current console line in place with the progress summary.
                    Console.SetCursorPosition(0, Console.CursorTop);
                    Console.Write("It " + iter + ", " + (time.ElapsedMilliseconds / iter) + " mill/it. Alpha=" + Hyperparams.Alpha(iter));
                    if (iter % 100 != 0)
                    {
                        continue;
                    }

                    //Test
                    //ntwk = cuda.PeekFromGpu();

                    // Integer accumulator: a float sum stops being exact above 2^24,
                    // which a few hundred thousand pixels can easily exceed.
                    long totalAbsError = 0;

                    //Guess
                    using (var resImg = new Bitmap(width, height))
                    {
                        for (int w = 0; w < width; ++w)
                        {
                            for (int h = 0; h < height; ++h)
                            {
                                var v = ntwk.Test(new float[] {
                                    ((float)w) / width,
                                    ((float)h) / height });

                                // Scale each output to 0-255 and clamp, since the network may overshoot.
                                var r = Math.Max(0, Math.Min(255, (int)(v[0] * 255.0)));
                                var g = Math.Max(0, Math.Min(255, (int)(v[1] * 255.0)));
                                var b = Math.Max(0, Math.Min(255, (int)(v[2] * 255.0)));

                                var p = img.GetPixel(w, h);
                                totalAbsError += Math.Abs(p.R - r);
                                totalAbsError += Math.Abs(p.G - g);
                                totalAbsError += Math.Abs(p.B - b);

                                resImg.SetPixel(w, h, Color.FromArgb(r, g, b));
                            }
                        }

                        resImg.Save(savePath(iter) + ".bmp");

                        //Per-neuron activation dump
                        //int name = 0;
                        //foreach (var layer in ntwk.Layers.Skip(1).Take(ntwk.Layers.Count - 2))
                        //{
                        //    foreach (var neur in layer)
                        //    {
                        //        for (int w = 0; w < img.Width; ++w)
                        //        {
                        //            for (int h = 0; h < img.Height; ++h)
                        //            {
                        //                var v = ntwk.Test(new float[] { ((float)w) / img.Width, ((float)h) / img.Height });
                        //                var a = neur.Activation;
                        //                var c = Color.FromArgb(
                        //                    Math.Max(0, Math.Min(255, (int)(a * 255.0))),
                        //                    Math.Max(0, Math.Min(255, (int)(a * 255.0))),
                        //                    Math.Max(0, Math.Min(255, (int)(a * 255.0))));
                        //                resImg.SetPixel(w, h, c);
                        //            }
                        //        }
                        //        resImg.Save(savePath(iter) + name++ + ".bmp");
                        //
                        //    }
                        //}
                    }

                    var avgError = (float)((double)totalAbsError / channelCount);
                    Console.SetCursorPosition(0, Console.CursorTop);
                    Console.WriteLine("It " + iter + ", " + (time.ElapsedMilliseconds / iter) + " mill/it. Alpha=" + Hyperparams.Alpha(iter) + ", err=" + avgError);

                    //Save network
                    //{
                    //    using (Stream stream = File.Open(ntwkPath, FileMode.Create))
                    //    {
                    //        var binaryFormatter = new System.Runtime.Serialization.Formatters.Binary.BinaryFormatter();
                    //        binaryFormatter.Serialize(stream, Serialization.Serialize(ntwk));
                    //    }
                    //}
                }
            }

            //Console.ReadKey();
        }
コード例 #2
0
ファイル: Program.cs プロジェクト: benketriel/sknn
        //Neural network trained on the MNIST digit images
        /// <summary>
        /// Trains a network on the MNIST training files and, after each training
        /// iteration, reports how many of the first test samples it classifies correctly.
        /// </summary>
        /// <remarks>
        /// The training loop has no stopping criterion; it is meant to run until the
        /// process is terminated. For each evaluation pass, up to ten misclassified
        /// samples are written to disk as 28x28 greyscale bitmaps named
        /// <c>savePath(iter) + index + "WRONG" + guess + ".bmp"</c>.
        /// </remarks>
        private static void LearnMNIST()
        {
            Console.WriteLine("Loading files");
            // Pixel values are divided by 255.0 on load.
            var trainX = Matrix.Build.DenseOfRows(MNIST.LoadImages(@"C:\SK\Programming\VSProjects\Stuff\Files\data\common\train-images-idx3-ubyte")).Map(x => x / 255.0);
            var trainY = Vector.Build.Dense(MNIST.LoadLabels(@"C:\SK\Programming\VSProjects\Stuff\Files\data\common\train-labels-idx1-ubyte"));
            var testX = Matrix.Build.DenseOfRows(MNIST.LoadImages(@"C:\SK\Programming\VSProjects\Stuff\Files\data\common\t10k-images-idx3-ubyte")).Map(x => x / 255.0);
            var testY = Vector.Build.Dense(MNIST.LoadLabels(@"C:\SK\Programming\VSProjects\Stuff\Files\data\common\t10k-labels-idx1-ubyte"));
            Console.WriteLine("Learning");

            // Presumably the arguments are layer sizes (one input per pixel, 10 outputs) - confirm against Network.
            var ntwk = new Network(trainX.ColumnCount, 1024, 10);
            //Load existing network
            //NOTE: BinaryFormatter is obsolete and unsafe on untrusted data; pick another serializer before re-enabling.
            //if (File.Exists(ntwkPath))
            //{
            //    using (Stream stream = File.Open(ntwkPath, FileMode.Open))
            //    {
            //        var binaryFormatter = new System.Runtime.Serialization.Formatters.Binary.BinaryFormatter();
            //        ntwk = Serialization.DeSerialize((float[])binaryFormatter.Deserialize(stream));
            //    }
            //}

            // Convert everything to float arrays once, up front.
            var xx = trainX.EnumerateRows().Select(row => row.Select(value => (float)value).ToArray())/*.Take(1000)*/.ToArray();
            var yy = trainY.Select(label => new[] { (float)label })/*.Take(1000)*/.ToArray();

            var testXx = testX.EnumerateRows().Select(row => row.Select(value => (float)value).ToArray()).ToArray();

            const int testEvery = 1;      // evaluate after every N training iterations
            const int maxTestSamples = 1000;
            const int maxShow = 10;       // misclassified samples saved per evaluation pass
            const int side = 28;          // images are rendered as side x side bitmaps

            var time = new Stopwatch();
            time.Start();
            int iter = 0;

            while (++iter > 0)
            {
                ntwk.TrainSoftMax(xx, yy, iter);

                //Save network
                //{
                //    using (Stream stream = File.Open(ntwkPath, FileMode.Create))
                //    {
                //        var binaryFormatter = new System.Runtime.Serialization.Formatters.Binary.BinaryFormatter();
                //        binaryFormatter.Serialize(stream, Serialization.Serialize(ntwk));
                //    }
                //}
                //Console.SetCursorPosition(0, Console.CursorTop);
                if (iter % testEvery != 0)
                {
                    continue;
                }

                Console.WriteLine("It " + iter + ", " + (time.ElapsedMilliseconds / iter) + " mill/it. Alpha=" + Hyperparams.Alpha(iter));

                // Never index past the samples/labels that were actually loaded.
                int howManyToTest = Math.Min(maxTestSamples, Math.Min(testXx.Length, testY.Count));
                int shown = 0;
                int rightCount = 0;

                // One scratch bitmap per pass, fully overwritten for each saved sample and
                // disposed afterwards so its GDI+ handle is not leaked.
                using (var resImg = new Bitmap(side, side))
                {
                    for (int testI = 0; testI < howManyToTest; ++testI)
                    {
                        var sample = testXx[testI];
                        var v = ntwk.Test(sample);

                        // The predicted class is the index of the largest output.
                        var guess = Vector<float>.Build.DenseOfArray(v).MaximumIndex();
                        var right = testY[testI];
                        if (guess == right)
                        {
                            ++rightCount;
                            //resImg.Save(savePath(iter) + (testI) + "RIGHT" + guess + ".bmp");
                            continue;
                        }

                        if (shown >= maxShow)
                        {
                            continue;
                        }
                        ++shown;

                        for (int h = 0; h < side; ++h)
                        {
                            for (int w = 0; w < side; ++w)
                            {
                                // Clamp so an out-of-range value cannot make Color.FromArgb throw.
                                int level = Math.Max(0, Math.Min(255, (int)(sample[h * side + w] * 255.0)));
                                resImg.SetPixel(w, h, Color.FromArgb(level, level, level));
                            }
                        }
                        resImg.Save(savePath(iter) + (testI) + "WRONG" + guess + ".bmp");
                    }
                }

                Console.WriteLine("Hit ratio " + rightCount + "/" + howManyToTest + " (" + ((float)rightCount / howManyToTest) + ")");
            }
        }