Ejemplo n.º 1
0
        /// <summary>
        /// Runs the backward (error-gradient) pass of the network.
        /// If the net does not exist yet it is created via createNet() first.
        /// </summary>
        /// <param name="lr"> learning rate</param>
        /// <param name="gradTns"> tensor holding the error gradient</param>
        /// <returns> result of snBackward; false if the net could not be created</returns>
        public bool backward(float lr, Tensor gradTns)
        {
            bool netReady = (net_ != null) || createNet();
            if (!netReady)
            {
                return false;
            }

            return snBackward(net_, lr, gradTns.size(), gradTns.data());
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Runs the forward pass of the network.
        /// If the net does not exist yet it is created via createNet() first.
        /// </summary>
        /// <param name="isLern"> learning-mode flag, passed through to snForward</param>
        /// <param name="inTns"> input tensor</param>
        /// <param name="outTns"> output tensor whose buffer is handed to snForward for the result</param>
        /// <returns> result of snForward; false if the net could not be created</returns>
        public bool forward(bool isLern, Tensor inTns, Tensor outTns)
        {
            if (net_ != null || createNet())
            {
                return snForward(net_, isLern,
                                 inTns.size(), inTns.data(),
                                 outTns.size(), outTns.data());
            }

            return false;
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Sets the weights of a named node of the network.
        /// </summary>
        /// <remarks>
        /// This method does not call createNet(): it returns false if the net has not been created yet.
        /// </remarks>
        /// <param name="name"> node name in the network architecture</param>
        /// <param name="weight"> tensor holding the weights to set</param>
        /// <returns> result of snSetWeightNode; false if the net is not created</returns>
        public bool setWeightNode(string name, Tensor weight)
        {
            if (net_ == null)
            {
                return false;
            }

            // The node name is passed to snSetWeightNode as an unmanaged ANSI string.
            IntPtr cname = Marshal.StringToHGlobalAnsi(name);
            try
            {
                return snSetWeightNode(net_, cname, weight.size(), weight.data());
            }
            finally
            {
                // Release the unmanaged copy even if weight.data() or the native call throws.
                Marshal.FreeHGlobal(cname);
            }
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Runs one training cycle (forward pass followed by backward pass) via snTraining.
        /// If the net does not exist yet it is created via createNet() first.
        /// </summary>
        /// <param name="lr"> learning rate</param>
        /// <param name="inTns"> input tensor</param>
        /// <param name="outTns"> output tensor whose buffer is handed to snTraining</param>
        /// <param name="targetTns"> target (expected output) tensor</param>
        /// <param name="outAccurate"> receives the accuracy value reported by snTraining;
        /// left untouched if the net could not be created</param>
        /// <returns> result of snTraining; false if the net could not be created</returns>
        public bool training(float lr, Tensor inTns, Tensor outTns, Tensor targetTns, ref float outAccurate)
        {
            if (net_ == null)
            {
                if (!createNet())
                {
                    return false;
                }
            }

            // Use a local: taking the address of a ref parameter would require pinning it.
            float acc = 0F;
            bool result = snTraining(net_, lr,
                                     inTns.size(), inTns.data(),
                                     outTns.size(), outTns.data(),
                                     targetTns.data(), &acc);

            outAccurate = acc;
            return result;
        }
Ejemplo n.º 5
0
        static void Main(string[] args)
        {
            sn.Net snet = new sn.Net();

            string ver = snet.versionLib();

            Console.WriteLine("Version snlib " + ver);

            // Network: Input -> 2x Convolution -> Pooling -> 2x FullyConnected -> softmax/cross-entropy loss.
            // Note: `sn.` (not `sn::`) — the `::` qualifier only compiles when `sn` is an alias.
            snet.addNode("Input", new sn.Input(), "C1")
            .addNode("C1", new sn.Convolution(15, 0, sn.calcMode.type.CUDA), "C2")
            .addNode("C2", new sn.Convolution(15, 0, sn.calcMode.type.CUDA), "P1")
            .addNode("P1", new sn.Pooling(sn.calcMode.type.CUDA), "FC1")
            .addNode("FC1", new sn.FullyConnected(128, sn.calcMode.type.CUDA), "FC2")
            .addNode("FC2", new sn.FullyConnected(10, sn.calcMode.type.CUDA), "LS")
            .addNode("LS", new sn.LossFunction(sn.lossType.type.softMaxToCrossEntropy), "Output");

            string imgPath = "c://C++//skyNet//example//mnist//images//";


            uint batchSz = 100, classCnt = 10, w = 28, h = 28; float lr = 0.001F;
            List <List <string> >       imgName   = new List <List <string> >();
            List <int>                  imgCntDir = new List <int>(10);
            // Decoded images are cached by file name so each file is read from disk only once.
            Dictionary <string, Bitmap> images    = new Dictionary <string, Bitmap>();

            if (!loadImage(imgPath, classCnt, imgName, imgCntDir))
            {
                Console.WriteLine("Error 'loadImage' path: " + imgPath);
                Console.ReadKey();
                return;
            }

            // Without at least one non-empty class directory the sampling loop below would never end.
            if (!imgCntDir.Exists(cnt => cnt > 0))
            {
                Console.WriteLine("Error 'loadImage' no images found, path: " + imgPath);
                Console.ReadKey();
                return;
            }

            string wpath = "c:/C++/w.dat";

            if (snet.loadAllWeightFromFile(wpath))
            {
                Console.WriteLine("Load weight ok path: " + wpath);
            }
            else
            {
                Console.WriteLine("Load weight err path: " + wpath);
            }


            sn.Tensor inLayer     = new sn.Tensor(new sn.snLSize(w, h, 1, batchSz));
            sn.Tensor targetLayer = new sn.Tensor(new sn.snLSize(classCnt, 1, 1, batchSz));
            sn.Tensor outLayer    = new sn.Tensor(new sn.snLSize(classCnt, 1, 1, batchSz));

            float accuratSumm = 0;

            // Create the generator once. On .NET Framework `new Random()` is time-seeded, so
            // re-creating it every iteration can repeat the same sequence (and the same batch).
            Random rnd = new Random();

            for (int k = 0; k < 1000; ++k)
            {
                targetLayer.reset();

                for (int i = 0; i < batchSz; ++i)
                {
                    // directory (its index is used as the class label)
                    int ndir = rnd.Next(0, (int)classCnt);
                    while (imgCntDir[ndir] == 0)
                    {
                        ndir = rnd.Next(0, (int)classCnt);
                    }

                    // image
                    int nimg = rnd.Next(0, imgCntDir[ndir]);

                    // read (from cache when possible)
                    string fn = imgName[ndir][nimg];
                    Bitmap img;
                    if (!images.TryGetValue(fn, out img))
                    {
                        img = new Bitmap(fn);
                        images.Add(fn, img);
                    }

                    unsafe
                    {
                        float *refData = inLayer.data() + i * w * h;
                        // Copy at most w x h pixels so a larger image cannot overrun this sample's slot.
                        int    nr = Math.Min(img.Height, (int)h), nc = Math.Min(img.Width, (int)w);

                        // Request 32bpp ARGB explicitly so every pixel is exactly 4 bytes,
                        // whatever the file's native pixel format is.
                        System.Drawing.Imaging.BitmapData bmd = img.LockBits(new Rectangle(0, 0, img.Width, img.Height),
                                                                             System.Drawing.Imaging.ImageLockMode.ReadOnly,
                                                                             System.Drawing.Imaging.PixelFormat.Format32bppArgb);
                        try
                        {
                            for (int r = 0; r < nr; ++r)
                            {
                                // Locate each row via Stride rather than assuming packed rows.
                                IntPtr pt = bmd.Scan0 + r * bmd.Stride;
                                for (int c = 0; c < nc; ++c)
                                {
                                    // First byte of the pixel (blue channel); presumably the
                                    // digit images are grayscale, so any channel would do.
                                    refData[r * (int)w + c] = Marshal.ReadByte(pt);
                                    pt += 4;
                                }
                            }
                        }
                        finally
                        {
                            img.UnlockBits(bmd);
                        }

                        // one-hot target for this sample
                        float *tarData = targetLayer.data() + classCnt * i;
                        tarData[ndir] = 1;
                    }
                }

                // training
                float accurat = 0;
                snet.training(lr, inLayer, outLayer, targetLayer, ref accurat);

                // calc error: count samples whose arg-max output matches the arg-max target
                int accCnt = 0;
                unsafe
                {
                    float *targetData = targetLayer.data();
                    float *outData    = outLayer.data();
                    int    bsz        = (int)batchSz;
                    for (int i = 0; i < bsz; ++i)
                    {
                        float *refOutput = outData + i * classCnt;

                        float maxval    = refOutput[0];
                        int   maxOutInx = 0;
                        for (int j = 1; j < classCnt; ++j)
                        {
                            if (refOutput[j] > maxval)
                            {
                                maxval    = refOutput[j];
                                maxOutInx = j;
                            }
                        }

                        float *refTarget = targetData + i * classCnt;

                        maxval = refTarget[0];
                        int maxTargInx = 0;
                        for (int j = 1; j < classCnt; ++j)
                        {
                            if (refTarget[j] > maxval)
                            {
                                maxval     = refTarget[j];
                                maxTargInx = j;
                            }
                        }

                        if (maxTargInx == maxOutInx)
                        {
                            ++accCnt;
                        }
                    }
                }

                accuratSumm += (float)accCnt / batchSz;

                Console.WriteLine(k.ToString() + " accurate " + (accuratSumm / (k + 1)).ToString() + " " +
                                  snet.getLastErrorStr());
            }

            // Release the GDI resources held by the cached bitmaps.
            foreach (Bitmap cached in images.Values)
            {
                cached.Dispose();
            }

            if (snet.saveAllWeightToFile(wpath))
            {
                Console.WriteLine("Save weight ok path: " + wpath);
            }
            else
            {
                Console.WriteLine("Save weight err path: " + wpath);
            }

            Console.ReadKey();
            return;
        }
Ejemplo n.º 6
0
        static void Main(string[] args)
        {
            sn.Net snet = new sn.Net();

            string ver = snet.versionLib();

            Console.WriteLine("Version snlib " + ver);

            snet.addNode("In", new sn.Input(), "C1")
            .addNode("C1", new sn.Convolution(10, -1), "C2")
            .addNode("C2", new sn.Convolution(10, 0), "P1 Crop1")
            .addNode("Crop1", new sn.Crop(new sn.rect(0, 0, 487, 487)), "Rsz1")
            .addNode("Rsz1", new sn.Resize(new sn.diap(0, 10), new sn.diap(0, 10)), "Conc1")
            .addNode("P1", new sn.Pooling(), "C3")

            .addNode("C3", new sn.Convolution(10, -1), "C4")
            .addNode("C4", new sn.Convolution(10, 0), "P2 Crop2")
            .addNode("Crop2", new sn.Crop(new sn.rect(0, 0, 247, 247)), "Rsz2")
            .addNode("Rsz2", new sn.Resize(new sn.diap(0, 10), new sn.diap(0, 10)), "Conc2")
            .addNode("P2", new sn.Pooling(), "C5")

            .addNode("C5", new sn.Convolution(10, 0), "C6")
            .addNode("C6", new sn.Convolution(10, 0), "DC1")
            .addNode("DC1", new sn.Deconvolution(10, 0), "Rsz3")
            .addNode("Rsz3", new sn.Resize(new sn.diap(0, 10), new sn.diap(10, 20)), "Conc2")

            .addNode("Conc2", new sn.Concat("Rsz2 Rsz3"), "C7")

            .addNode("C7", new sn.Convolution(10, 0), "C8")
            .addNode("C8", new sn.Convolution(10, 0), "DC2")
            .addNode("DC2", new sn.Deconvolution(10, 0), "Rsz4")
            .addNode("Rsz4", new sn.Resize(new sn.diap(0, 10), new sn.diap(10, 20)), "Conc1")

            .addNode("Conc1", new sn.Concat("Rsz1 Rsz4"), "C9")

            .addNode("C9", new sn.Convolution(10, 0), "C10");

            // Output layer: single-channel convolution with sigmoid, trained with binary cross-entropy.
            sn.Convolution convOut = new sn.Convolution(1, 0);
            convOut.act = new sn.active(sn.active.type.sigmoid);
            snet.addNode("C10", convOut, "LS")
            .addNode("LS", new sn.LossFunction(sn.lossType.type.binaryCrossEntropy), "Output");

            string imgPath  = "c://cpp//other//sunnet//example//unet//images//";
            string targPath = "c://cpp//other//sunnet//example//unet//labels//";


            uint          batchSz = 3, w = 512, h = 512, wo = 483, ho = 483; float lr = 0.001F;
            List <string> imgName  = new List <string>();
            List <string> targName = new List <string>();

            // Report whichever path actually failed.
            if (!loadImage(imgPath, imgName))
            {
                Console.WriteLine("Error 'loadImage' path: " + imgPath);
                Console.ReadKey();
                return;
            }
            if (!loadImage(targPath, targName))
            {
                Console.WriteLine("Error 'loadImage' path: " + targPath);
                Console.ReadKey();
                return;
            }

            // Images and labels are paired by index below, so every image needs a label.
            if (imgName.Count == 0 || targName.Count < imgName.Count)
            {
                Console.WriteLine("Error 'loadImage' image/label count mismatch: " +
                                  imgName.Count.ToString() + " / " + targName.Count.ToString());
                Console.ReadKey();
                return;
            }

            string wpath = "c:/cpp/w.dat";

            //  if (snet.loadAllWeightFromFile(wpath))
            //     Console.WriteLine("Load weight ok path: " + wpath);
            // else
            //     Console.WriteLine("Load weight err path: " + wpath);


            sn.Tensor inLayer     = new sn.Tensor(new sn.snLSize(w, h, 1, batchSz));
            sn.Tensor targetLayer = new sn.Tensor(new sn.snLSize(wo, ho, 1, batchSz));
            sn.Tensor outLayer    = new sn.Tensor(new sn.snLSize(wo, ho, 1, batchSz));

            float accuratSumm = 0;

            // Create the generator once. On .NET Framework `new Random()` is time-seeded, so
            // re-creating it every iteration can repeat the same sequence (and the same batch).
            Random rnd = new Random();

            for (int k = 0; k < 1000; ++k)
            {
                targetLayer.reset();

                for (int i = 0; i < batchSz; ++i)
                {
                    // image
                    int nimg = rnd.Next(0, imgName.Count);

                    unsafe
                    {
                        // input: first byte (blue channel) of each pixel, raw 0..255
                        float *refData = inLayer.data() + i * w * h;
                        using (Bitmap img = new Bitmap(imgName[nimg]))
                        {
                            // Copy at most w x h pixels so a larger image cannot overrun this sample's slot.
                            int nr = Math.Min(img.Height, (int)h), nc = Math.Min(img.Width, (int)w);

                            // Request 32bpp ARGB explicitly so every pixel is exactly 4 bytes.
                            System.Drawing.Imaging.BitmapData bmd = img.LockBits(new Rectangle(0, 0, img.Width, img.Height),
                                                                                 System.Drawing.Imaging.ImageLockMode.ReadOnly,
                                                                                 System.Drawing.Imaging.PixelFormat.Format32bppArgb);
                            try
                            {
                                for (int r = 0; r < nr; ++r)
                                {
                                    // Locate each row via Stride rather than assuming packed rows.
                                    IntPtr ptData = bmd.Scan0 + r * bmd.Stride;
                                    for (int c = 0; c < nc; ++c)
                                    {
                                        refData[r * (int)w + c] = Marshal.ReadByte(ptData);

                                        ptData += 4;
                                    }
                                }
                            }
                            finally
                            {
                                img.UnlockBits(bmd);
                            }
                        }

                        // target: label resized to the network output size (wo x ho), scaled to [0, 1]
                        float *targData = targetLayer.data() + i * wo * ho;
                        using (Bitmap srcTrg = new Bitmap(targName[nimg]))
                        using (Bitmap imgTrg = new Bitmap(srcTrg, new Size((int)wo, (int)ho)))
                        {
                            int nr = imgTrg.Height, nc = imgTrg.Width;

                            System.Drawing.Imaging.BitmapData bmdTrg = imgTrg.LockBits(new Rectangle(0, 0, nc, nr),
                                                                                       System.Drawing.Imaging.ImageLockMode.ReadOnly,
                                                                                       System.Drawing.Imaging.PixelFormat.Format32bppArgb);
                            try
                            {
                                for (int r = 0; r < nr; ++r)
                                {
                                    IntPtr ptTrg = bmdTrg.Scan0 + r * bmdTrg.Stride;
                                    for (int c = 0; c < nc; ++c)
                                    {
                                        targData[r * nc + c] = (float)(Marshal.ReadByte(ptTrg) / 255.0);

                                        ptTrg += 4;
                                    }
                                }
                            }
                            finally
                            {
                                imgTrg.UnlockBits(bmdTrg);
                            }
                        }
                    }
                }

                // training
                float accurat = 0;
                snet.training(lr, inLayer, outLayer, targetLayer, ref accurat);

                // running mean of the accuracy reported by training()
                accuratSumm += accurat;

                Console.WriteLine(k.ToString() + " accurate " + (accuratSumm / (k + 1)).ToString() + " " +
                                  snet.getLastErrorStr());
            }

            if (snet.saveAllWeightToFile(wpath))
            {
                Console.WriteLine("Save weight ok path: " + wpath);
            }
            else
            {
                Console.WriteLine("Save weight err path: " + wpath);
            }

            Console.ReadKey();
            return;
        }
Ejemplo n.º 7
0
        static void Main(string[] args)
        {
            // using python for create file 'resNet50Weights.dat' as:
            // CMD: cd c:\cpp\other\skyNet\example\resnet50\
            // CMD: python createNet.py

            string arch = File.ReadAllText(@"c:\cpp\other\skyNet\example\resnet50\resNet50Struct.json", Encoding.UTF8);

            sn.Net snet = new sn.Net(arch, @"c:\cpp\other\skyNet\example\resnet50\resNet50Weights.dat");

            string loadErr = snet.getLastErrorStr();
            if (!string.IsNullOrEmpty(loadErr))
            {
                Console.WriteLine("Error loadAllWeightFromFile: " + loadErr);
                Console.ReadKey();
                return;
            }

            string imgPath = @"c:\cpp\other\skyNet\example\resnet50\images\elephant.jpg";

            int classCnt = 1000, w = 224, h = 224;

            sn.Tensor inLayer  = new sn.Tensor(new snLSize((UInt64)w, (UInt64)h, 3, 1));
            sn.Tensor outLayer = new sn.Tensor(new snLSize((UInt64)classCnt, 1, 1, 1));

            // read and resize to the network input size
            using (Image src = Image.FromFile(imgPath))
            using (Bitmap img = new Bitmap(src, new Size(w, h)))
            {
                unsafe
                {
                    float *refData = inLayer.data();

                    System.Drawing.Imaging.BitmapData bmd = img.LockBits(new Rectangle(0, 0, w, h),
                                                                         System.Drawing.Imaging.ImageLockMode.ReadOnly,
                                                                         System.Drawing.Imaging.PixelFormat.Format32bppArgb);
                    try
                    {
                        // Format32bppArgb pixels are stored little-endian: bytes B, G, R, A.
                        // Fill the planar tensor as three h x w planes in B, G, R order
                        // (byte offsets 0, 1, 2; offset 3 is alpha and is skipped).
                        int[] channelOffset = { 0, 1, 2 };
                        int   plane         = w * h;
                        for (int ch = 0; ch < channelOffset.Length; ++ch)
                        {
                            float *dst = refData + ch * plane;
                            for (int r = 0; r < h; ++r)
                            {
                                // Locate each row via Stride rather than assuming packed rows.
                                IntPtr pt = bmd.Scan0 + r * bmd.Stride + channelOffset[ch];
                                for (int c = 0; c < w; ++c)
                                {
                                    dst[r * w + c] = Marshal.ReadByte(pt);
                                    pt += 4;
                                }
                            }
                        }
                    }
                    finally
                    {
                        img.UnlockBits(bmd);
                    }
                }
            }

            // inference (isLern = false)
            snet.forward(false, inLayer, outLayer);

            // arg-max over the class scores
            float maxval;
            int   maxOutInx = 0;

            unsafe {
                float *refOutput = outLayer.data();

                maxval = refOutput[0];
                for (int j = 1; j < classCnt; ++j)
                {
                    if (refOutput[j] > maxval)
                    {
                        maxval    = refOutput[j];
                        maxOutInx = j;
                    }
                }
            }

            // for check: c:\cpp\other\skyNet\example\resnet50\imagenet_class_index.json

            Console.WriteLine("inx " + maxOutInx.ToString() + " accurate " + maxval.ToString() + " " + snet.getLastErrorStr());
            Console.ReadKey();
            return;
        }