// Example No. 1
// 0
        /// <summary>
        /// Builds a convolutional network for bitmaps loaded from the "Sorted" set (resized to 50x25),
        /// trains it with <c>MomentumParallel</c> for 1000 passes of batch size 32, and reports whether
        /// the training error went down.
        /// </summary>
        /// <returns>
        /// <c>true</c> when the last recorded training error is lower than the first one;
        /// <c>false</c> otherwise, including when training returned fewer than two error values.
        /// </returns>
        static bool CatTrain()
        {
            BitmapCatEnumerator enums   = new BitmapCatEnumerator("Sorted", new System.Drawing.Size(50, 25));
            Network             network = new Network();

            // Feature extractor: two convolution + max-pool stages.
            network.AddLayer(new Conv2D(new Relu(), 7, 7, 32));
            network.AddLayer(new MaxPool2D(new Relu(), 2, 2));

            network.AddLayer(new Conv2D(new Relu(), 5, 5, 64));
            network.AddLayer(new MaxPool2D(new Relu(), 2, 2));

            // Classifier head: 256 ReLU units followed by 2 sigmoid outputs.
            network.AddLayer(new FullyConnLayar(new Relu(), new Size(1, 1, 256)));
            network.AddLayer(new FullyConnLayar(new Sigmoid(), new Size(1, 1, 2)));

            network.Compile(new Size(3, 25, 50), true);

            network.Normalization();

            // The sample drawn here used to be copied into a OneEnumerator that was never used.
            // The call is kept because it receives the network by ref and may have side effects
            // (e.g. on the enumerator's random state). NOTE(review): confirm and drop it if not.
            enums.GetRandom(ref network);

            MomentumParallel sgd = new MomentumParallel(network, 0.9, 1e-6);

            double[] errors = sgd.TrainBatch(enums, 32, 1000);

            // Without at least two error values there is no trend to compare; report "not improved"
            // instead of throwing on an empty array.
            if (errors == null || errors.Length < 2)
            {
                return false;
            }

            return errors[0] > errors[errors.Length - 1];
        }
// Example No. 2
// 0
        /// <summary>
        /// Smoke test that trains a tiny two-layer network on the truth table of logical AND
        /// (two inputs, one output) for 3000 passes of batch size 320.
        /// </summary>
        /// <remarks>
        /// The training errors are not inspected. The method reports success as soon as training
        /// finishes without throwing.
        /// </remarks>
        /// <returns>Always <c>true</c>.</returns>
        static bool TestFullConnect()
        {
            // One row per sample: input A, input B, expected output (A AND B).
            double[,] andTable =
            {
                { 0, 0, 0 },
                { 1, 0, 0 },
                { 0, 1, 0 },
                { 1, 1, 1 },
            };

            ArrDataEnumerator samples = new ArrDataEnumerator();
            for (int row = 0; row < andTable.GetLength(0); row++)
            {
                // Fresh buffers for every sample, since the enumerator may keep the references it is given.
                double[,,] input  = new double[1, 1, 2];
                double[,,] target = new double[1, 1, 1];

                input[0, 0, 0]  = andTable[row, 0];
                input[0, 0, 1]  = andTable[row, 1];
                target[0, 0, 0] = andTable[row, 2];

                samples.AddSample(input, target);
            }

            Network net = new Network();
            net.AddLayer(new TSConv2D(new Sigmoid(), 2, 1, 1, 10, 1));
            net.AddLayer(new FullyConnLayar(new Sigmoid(), new Size(1, 1, 1)));
            net.Compile(new Size(1, 1, 2), true);
            net.Normalization();

            MomentumParallel trainer = new MomentumParallel(net, 0.9, 1e-2);
            trainer.TrainBatch(samples, 320, 3000);

            return true;
        }
// Example No. 3
// 0
        /// <summary>
        /// Builds a convolutional network for bitmaps resized to 24x12 and pre-trains it with
        /// <c>PretrainAutoEncoder.Action</c> on the "Sorted" (training) and "Val" (validation) sets.
        /// It then runs a long second training phase with a new <c>MomentumParallel</c>.
        /// </summary>
        /// <remarks>
        /// During the second phase the network is written to <c>train_&lt;iteration&gt;.neural</c>
        /// whenever more than five minutes have passed since the previous save. The final weights are
        /// saved once the loop ends.
        /// </remarks>
        /// <returns>Always <c>true</c> once training finishes.</returns>
        static bool PreTrain()
        {
            const int    BatchSize           = 256;
            const int    TrainIterations     = 100000;
            const double SaveIntervalMinutes = 5;

            BitmapCatEnumerator enums   = new BitmapCatEnumerator("Sorted", new System.Drawing.Size(24, 12));
            BitmapCatEnumerator val     = new BitmapCatEnumerator("Val", new System.Drawing.Size(24, 12));
            Network             network = new Network();

            network.AddLayer(new Conv2D(new Relu(), 3, 3, 10));
            network.AddLayer(new MaxPool2D(new Relu(), 2, 2));

            network.AddLayer(new Conv2D(new Relu(), 5, 5, 30));
            network.AddLayer(new MaxPool2D(new Relu(), 2, 2));

            network.AddLayer(new FullyConnLayar(new Relu(), new Size(1, 1, 256)));
            network.AddLayer(new FullyConnLayar(new Sigmoid(), new Size(1, 1, 2)));

            network.Compile(new Size(3, 12, 24), true);

            // Normalization() is called twice, as in the original code.
            // NOTE(review): confirm the second call is intended.
            network.Normalization();
            network.Normalization();

            MomentumParallel sgd = new MomentumParallel(network, 0.9, 1e-4);

            sgd.need_max = false;

            var pair = PretrainAutoEncoder.Action(network, sgd, enums, val, 2000, 32);

            Console.WriteLine("{0}\n{1}", pair.Key, pair.Value);
            Console.WriteLine("Start train");

            // Second phase uses a new optimizer. Presumably the arguments are (momentum = 0,
            // learning rate = 1e-5); confirm against the MomentumParallel constructor.
            sgd = new MomentumParallel(network, 0, 1e-5);

            // Stopwatch measures elapsed time monotonically. DateTime.Now is local wall-clock time
            // and can jump (DST changes, clock adjustments), which skips or bunches snapshots.
            System.Diagnostics.Stopwatch sinceLastSave = System.Diagnostics.Stopwatch.StartNew();

            sgd.TrainBatch(enums, BatchSize, 1);
            for (int i = 0; i < TrainIterations; i++)
            {
                sgd.TrainBatchContinue(enums, BatchSize, 1);

                if (sinceLastSave.Elapsed.TotalMinutes > SaveIntervalMinutes)
                {
                    string snapshotFile = "train_" + i + ".neural";
                    System.IO.File.WriteAllText(snapshotFile, network.SaveJSON());
                    sinceLastSave.Restart();
                    Console.WriteLine("Saved at " + snapshotFile);
                }
            }

            // Save the final weights as well. Otherwise up to SaveIntervalMinutes of training since
            // the last snapshot would be lost when the loop ends.
            string finalFile = "train_" + TrainIterations + ".neural";
            System.IO.File.WriteAllText(finalFile, network.SaveJSON());
            Console.WriteLine("Saved at " + finalFile);

            return true;
        }