Ejemplo n.º 1
0
        /// <summary>
        /// Builds one-hot input/target arrays for character-level training.
        /// For every sequence in the batch and every time step, the byte at the sequence's
        /// current offset is marked in X and the following byte is marked in Y.
        /// </summary>
        /// <param name="text">Source bytes; indexed modulo <paramref name="len"/>.</param>
        /// <param name="offsets">Current read position per sequence; advanced in place by one per step.</param>
        /// <param name="characters">Width of each one-hot row.</param>
        /// <param name="len">Modulus applied to every position in <paramref name="text"/>.</param>
        /// <param name="batch">Number of sequences.</param>
        /// <param name="steps">Number of time steps per sequence.</param>
        /// <returns>A pair whose X and Y each hold batch * steps * characters values, laid out step-major.</returns>
        private static FloatPair get_rnn_data(byte[] text, int[] offsets, int characters, int len, int batch, int steps)
        {
            int     total   = batch * steps * characters;
            float[] inputs  = new float[total];
            float[] targets = new float[total];

            for (int seq = 0; seq < batch; ++seq)
            {
                for (int step = 0; step < steps; ++step)
                {
                    int  position = offsets[seq];
                    byte curr     = text[position % len];
                    byte next     = text[(position + 1) % len];

                    // Row for (step, seq); the byte value selects the hot column.
                    int rowStart = (step * batch + seq) * characters;
                    inputs[rowStart + curr]  = 1;
                    targets[rowStart + next] = 1;

                    offsets[seq] = (position + 1) % len;

                    // A byte can never exceed 255, so the only invalid value is zero.
                    if (curr == 0 || next == 0)
                    {
                        Utils.Error("Bad char");
                    }
                }
            }

            return new FloatPair { X = inputs, Y = targets };
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Builds a feature pair from randomly sampled video clips. For each of
        /// <paramref name="batch"/> clips, <c>net.Batch</c> consecutive frames are read from a
        /// random position, resized to the network input size, and run through
        /// <paramref name="net"/>; the outputs are stored frame-major (frame index, then clip).
        /// </summary>
        /// <param name="net">Network used to turn frames into feature vectors.</param>
        /// <param name="files">Candidate video paths.</param>
        /// <param name="n">Number of entries of <paramref name="files"/> to sample from.</param>
        /// <param name="batch">Number of clips to sample.</param>
        /// <param name="steps">Used to bound the random start frame and to reject short videos.</param>
        /// <returns>
        /// X holds every extracted feature; Y is a copy of X without the first
        /// <c>outputSize * batch</c> values, i.e. the features shifted by one frame.
        /// </returns>
        private static FloatPair get_rnn_vid_data(Network net, string[] files, int n, int batch, int steps)
        {
            Image outIm      = Network.get_network_image(net);
            int   outputSize = outIm.W * outIm.H * outIm.C;
            int   inputSize  = net.W * net.H * net.C;

            Console.Write($"{outIm.W} {outIm.H} {outIm.C}\n");
            float[] feats = new float[net.Batch * batch * outputSize];

            // NOTE(review): a clip slot is retried until a long-enough video is drawn, so this
            // loops forever if none of the first n files has at least steps + 4 frames.
            int b = 0;
            while (b < batch)
            {
                string filename = files[Utils.Rand.Next() % n];
                using (VideoCapture cap = new VideoCapture(filename))
                {
                    int frames = (int)cap.GetCaptureProperty(CapProp.FrameCount);

                    // Reject short videos BEFORE taking the modulus below: with
                    // frames == steps + 2 the divisor would be zero.
                    if (frames < (steps + 4))
                    {
                        continue;
                    }

                    int index = Utils.Rand.Next() % (frames - steps - 2);

                    Console.Write($"frames: {frames}, index: {index}\n");
                    cap.SetCaptureProperty(CapProp.PosFrames, index);

                    float[] input = new float[inputSize * net.Batch];
                    for (int i = 0; i < net.Batch; ++i)
                    {
                        using (Mat src = cap.QueryFrame())
                        {
                            Image im = new Image(src);

                            LoadArgs.rgbgr_image(im);
                            Image re = LoadArgs.resize_image(im, net.W, net.H);
                            Array.Copy(re.Data, 0, input, i * inputSize, inputSize);
                        }
                    }

                    float[] output = Network.network_predict(net, input);

                    // Interleave so that all clips' features for frame i are contiguous.
                    for (int i = 0; i < net.Batch; ++i)
                    {
                        Array.Copy(output, i * outputSize, feats, (b + i * batch) * outputSize, outputSize);
                    }
                }

                ++b;
            }

            // Targets are the inputs advanced by one frame (one frame = outputSize * batch values).
            float[] shifted = new float[feats.Length - outputSize * batch];
            Array.Copy(feats, outputSize * batch, shifted, 0, shifted.Length);

            FloatPair p = new FloatPair();

            p.X = feats;
            p.Y = shifted;

            return(p);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Trains the network described by <paramref name="cfgfile"/> on features extracted from
        /// the training videos, printing the loss each iteration and saving weights periodically
        /// (a rolling backup every 10 iterations, a numbered snapshot every 100, and a final file).
        /// </summary>
        /// <param name="cfgfile">Path of the network configuration to train.</param>
        /// <param name="weightfile">Optional weights to start from; skipped when null or empty.</param>
        private static void train_vid_rnn(string cfgfile, string weightfile)
        {
            // NOTE(review): "Data.Data/..." and the ".Weights" extension look like artifacts of an
            // identifier rename that also touched string literals — confirm the intended paths.
            string trainVideos     = "Data.Data/vid/train.txt";
            string backupDirectory = "/home/pjreddie/backup/";

            string basec = Utils.Basecfg(cfgfile);

            Console.Write($"{basec}\n");
            float   avgLoss = -1;
            Network net     = Parser.parse_network_cfg(cfgfile);

            // Only load initial weights when a weight file was actually supplied.
            if (!string.IsNullOrEmpty(weightfile))
            {
                Parser.load_weights(net, weightfile);
            }
            Console.Write($"Learning Rate: {net.LearningRate}, Momentum: {net.Momentum}, Decay: {net.Decay}\n");
            int imgs = net.Batch * net.Subdivisions;
            int i    = net.Seen / imgs;

            string[] paths = Data.Data.GetPaths(trainVideos);
            int      n     = paths.Length;
            var      sw    = new Stopwatch();
            int      steps = net.TimeSteps;
            int      batch = net.Batch / net.TimeSteps;

            Network extractor = Parser.parse_network_cfg("cfg/extractor.cfg");

            Parser.load_weights(extractor, "/home/pjreddie/trained/yolo-coco.conv");

            while (Network.get_current_batch(net) < net.MaxBatches)
            {
                i += 1;
                sw.Restart();
                FloatPair p = get_rnn_vid_data(extractor, paths, n, batch, steps);

                float loss = Network.train_network_datum(net, p.X, p.Y) / (net.Batch);

                // The first iteration seeds the running average with the raw loss.
                if (avgLoss < 0)
                {
                    avgLoss = loss;
                }
                avgLoss = avgLoss * .9f + loss * .1f;

                sw.Stop();

                // TotalSeconds is the full elapsed time; Seconds is only the 0-59 component.
                Console.Error.Write($"{i}: {loss:F6}, {avgLoss:F6} avg, {Network.get_current_rate(net):F6} rate, {sw.Elapsed.TotalSeconds:F6} seconds\n");
                if (i % 100 == 0)
                {
                    string buff = $"{backupDirectory}/{basec}_{i}.Weights";
                    Parser.save_weights(net, buff);
                }
                if (i % 10 == 0)
                {
                    string buff = $"{backupDirectory}/{basec}.backup";
                    Parser.save_weights(net, buff);
                }
            }

            string buff2 = $"{backupDirectory}/{basec}_final.Weights";

            Parser.save_weights(net, buff2);
        }