Exemplo n.º 1
0
        /// <summary>
        /// Handles each testing iteration of the sequence model: plots the averaged training and
        /// testing accuracies and reports the rendered graph through the background worker.
        /// </summary>
        /// <param name="sender">Specifies the event sender.</param>
        /// <param name="e">Specifies the event args; its Accuracy is pushed into the testing accuracy window.</param>
        private void m_mycaffe_OnTestingIteration(object sender, TestingIterationArgs<float> e)
        {
            // Training accuracy: mean of the training accuracy window, plotted as a percentage.
            float fTrainAccuracy = m_rgAccuracyTraining.Average();
            m_plotsSequenceAccuracyTrain.Add(m_nTotalSequences, fTrainAccuracy * 100);

            // Keep at most 100 points on the training plot.
            if (m_plotsSequenceAccuracyTrain.Count > 100)
            {
                m_plotsSequenceAccuracyTrain.RemoveAt(0);
            }

            // Testing accuracy: slide the window (append newest, drop oldest) and then average it.
            float fLatestTestAccuracy = (float)e.Accuracy;
            m_rgAccuracyTesting.Add(fLatestTestAccuracy);
            m_rgAccuracyTesting.RemoveAt(0);

            float fTestAccuracy = m_rgAccuracyTesting.Average();
            m_plotsSequenceAccuracyTest.Add(m_nTotalSequences, fTestAccuracy * 100);

            // Keep at most 100 points on the testing plot.
            if (m_plotsSequenceAccuracyTest.Count > 100)
            {
                m_plotsSequenceAccuracyTest.RemoveAt(0);
            }

            // Render both accuracy plots together, sized to the accuracy picture box.
            PlotCollectionSet accuracyPlots = new PlotCollectionSet();
            accuracyPlots.Add(m_plotsSequenceAccuracyTrain);
            accuracyPlots.Add(m_plotsSequenceAccuracyTest);

            int nWidth  = pbImageAccuracy.Width;
            int nHeight = pbImageAccuracy.Height;
            Image imgAccuracy = SimpleGraphingControl.QuickRender(accuracyPlots, nWidth, nHeight, false, null, null, false, m_rgZeroLine);

            // The second tuple item (1) tags this image for the progress handler.
            m_bw.ReportProgress(1, Tuple.Create(imgAccuracy, 1));
        }
Exemplo n.º 2
0
        /// <summary>
        /// Creates and configures the graphing control when the form loads, adds one plot
        /// configuration per plot collection in the set, then builds and displays the graph.
        /// </summary>
        /// <param name="sender">Specifies the event sender.</param>
        /// <param name="e">Specifies the event args.</param>
        private void FormPlotCollection2_Load(object sender, EventArgs e)
        {
            // Nothing to build when hosted in the designer.
            if (DesignMode)
            {
                return;
            }

            simpleGraphingControl1      = new SimpleGraphingControl();
            simpleGraphingControl1.Name = "SimpleGraphing";
            this.Controls.Add(simpleGraphingControl1);
            simpleGraphingControl1.Dock = DockStyle.Fill;

            simpleGraphingControl1.Configuration = new Configuration();
            simpleGraphingControl1.Configuration.Frames.Add(new ConfigurationFrame());
            simpleGraphingControl1.EnableCrossHairs = true;

            // All plots live in the single frame created above.
            ConfigurationFrame frame = simpleGraphingControl1.Configuration.Frames[0];

            frame.XAxis.LabelFont = new Font("Century Gothic", 7.0f);
            frame.XAxis.Visible   = true;
            frame.XAxis.Margin    = 100;
            frame.YAxis.LabelFont = new Font("Century Gothic", 7.0f);
            frame.YAxis.Decimals  = 3;

            for (int i = 0; i < m_set.Count; i++)
            {
                ConfigurationPlot plotConfig = new ConfigurationPlot();
                plotConfig.DataIndexOnRender = i;

                // Choose the plot type from THIS collection's data (previously only Plots[0] was ever
                // configured, always from m_set[0]). Points carrying four Y values are drawn as candles.
                bool bCandleData = (m_set[i].Count > 0 && m_set[i][0].Y_values.Length == 4);
                plotConfig.PlotType = bCandleData ? ConfigurationPlot.PLOTTYPE.CANDLE : ConfigurationPlot.PLOTTYPE.LINE;

                frame.Plots.Add(plotConfig);
            }

            // The X axis is shared by the whole frame, so it is configured once from the first
            // collection (same result as before, without repeating the work on every iteration).
            if (m_set.Count > 0)
            {
                bool bFirstIsCandle = (m_set[0].Count > 0 && m_set[0][0].Y_values.Length == 4);
                frame.XAxis.ValueType = bFirstIsCandle ? ConfigurationAxis.VALUE_TYPE.TIME : ConfigurationAxis.VALUE_TYPE.NUMBER;

                // An explicit "ValueType" parameter on the first collection overrides the inferred axis type.
                if (m_set[0].Parameters.ContainsKey("ValueType"))
                {
                    frame.XAxis.ValueType = (ConfigurationAxis.VALUE_TYPE)m_set[0].Parameters["ValueType"];
                }
            }

            frame.EnableRelativeScaling(true, true);

            List<PlotCollectionSet> rgSet = new List<PlotCollectionSet>()
            {
                m_set
            };

            simpleGraphingControl1.BuildGraph(rgSet);
            simpleGraphingControl1.Invalidate();
            simpleGraphingControl1.ScrollToEnd(true);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Handles each training iteration of the input model used to detect each hand written character.
        /// At most once per second it updates the log progress, writes a status line and reports a
        /// freshly rendered loss graph through the background worker.
        /// </summary>
        /// <param name="sender">Specifies the event sender.</param>
        /// <param name="e">Specifies the event args supplying the iteration number and smoothed loss.</param>
        private void m_mycaffeInput_OnTrainingIteration(object sender, TrainingIterationArgs<float> e)
        {
            // Throttle: do nothing until more than one second has passed since the last update.
            if (m_sw.Elapsed.TotalMilliseconds <= 1000)
            {
                return;
            }

            m_log.Progress = e.Iteration / (double)m_model.Iterations;

            string strStatus = "MNIST Iteration " + e.Iteration.ToString() + " of " + m_model.Iterations.ToString() + ", loss = " + e.SmoothedLoss.ToString();
            m_log.WriteLine(strStatus);
            m_sw.Restart();

            // Append the newest loss value and re-render the graph at the picture box size.
            m_plotsInputLoss.Add(e.Iteration, e.SmoothedLoss);

            int nWidth  = pbImage.Width;
            int nHeight = pbImage.Height;
            Image imgLoss = SimpleGraphingControl.QuickRender(m_plotsInputLoss, nWidth, nHeight, false, null, null, true, m_rgZeroLine);

            m_bw.ReportProgress(1, imgLoss);
        }
Exemplo n.º 4
0
        /// <summary>
        /// Train the model: creates the solver and model descriptions, loads them (with any
        /// previously saved weights) into the training instance, trains for the configured number
        /// of iterations, then saves the LSTM state, shows the training graph and saves the weights.
        /// </summary>
        /// <remarks>
        /// Returns immediately, doing nothing, when the training instance has not been created.
        /// </remarks>
        /// <param name="bNewWts">Specifies whether to use new weights or load existing ones (if they exist).</param>
        public void Train(bool bNewWts)
        {
            if (m_mycaffeTrain == null)
            {
                return;
            }

            // Only attempt to load saved weights when the caller did not ask for new ones.
            byte[] rgWts = null;

            if (!bNewWts)
            {
                rgWts = loadWeights();
            }

            if (rgWts == null)
            {
                Console.WriteLine("Starting with new weights...");
            }

            SolverParameter solver = createSolver();
            NetParameter    model  = createModel();

            // Serialize the model once so the text printed below is exactly what gets loaded
            // (previously the model was converted to its proto string a second time for LoadLite).
            string strModel = model.ToProto("root").ToString();

            Console.WriteLine("Using Train Model:");
            Console.WriteLine(strModel);
            Console.WriteLine("Starting training...");

            string strSolver = solver.ToProto("root").ToString();

            m_mycaffeTrain.LoadLite(Phase.TRAIN, strSolver, strModel, rgWts, false, false);
            m_mycaffeTrain.SetOnTrainingStartOverride(new EventHandler(onTrainingStart));
            m_mycaffeTrain.SetOnTestingStartOverride(new EventHandler(onTestingStart));

            // Set clockwork weights (skipped when the CUDNN LSTM engine is used).
            // NOTE(review): parameters[2] is presumed to be the first LSTM layer's blob, as the
            // original local name 'lstm1' suggests - confirm against the model built by createModel().
            if (m_param.LstmEngine != EngineParameter.Engine.CUDNN)
            {
                Net<float>  net   = m_mycaffeTrain.GetInternalNet(Phase.TRAIN);
                Blob<float> lstm1 = net.parameters[2];
                lstm1.SetData(1, m_param.Hidden, m_param.Hidden);
            }

            m_mycaffeTrain.Train(m_param.Iterations);
            saveLstmState(m_mycaffeTrain);

            // Render the plots gathered in m_plots into a 1000x600 image and show it.
            Image img = SimpleGraphingControl.QuickRender(m_plots, 1000, 600);

            showImage(img, "training.png");
            saveWeights(m_mycaffeTrain.GetWeights());
        }
Exemplo n.º 5
0
        /// <summary>
        /// Handles each training iteration of the sequence model used to encode each detected hand written
        /// character and then decode the encoding into the proper section of the Sin curve.
        /// At most once per second it logs the status, folds the smoothed loss into a running average,
        /// plots that average and reports the rendered loss graph through the background worker.
        /// </summary>
        /// <param name="sender">Specifies the event sender.</param>
        /// <param name="e">Specifies the event args supplying the iteration number and smoothed loss.</param>
        private void m_mycaffe_OnTrainingIteration(object sender, TrainingIterationArgs<float> e)
        {
            // Throttle: do nothing until more than one second has passed since the last update.
            if (m_sw.Elapsed.TotalMilliseconds <= 1000)
            {
                return;
            }

            m_log.Progress = e.Iteration / (double)m_model.Iterations;

            string strStatus = "Seq2Seq Epoch " + m_nTotalEpochs.ToString() + " Sequence " + m_nTotalSequences.ToString() + " Iteration " + e.Iteration.ToString() + " of " + m_model.Iterations.ToString() + ", loss = " + e.SmoothedLoss.ToString();
            m_log.WriteLine(strStatus, true);
            m_sw.Restart();

            // Running average of the smoothed loss over all updates handled so far.
            m_fTotalCost += (float)e.SmoothedLoss;
            m_nTotalIter1++;
            float fAverageLoss = m_fTotalCost / m_nTotalIter1;

            // Plot the average, keeping at most 2000 points.
            m_plotsSequenceLoss.Add(m_nTotalSequences, fAverageLoss);
            if (m_plotsSequenceLoss.Count > 2000)
            {
                m_plotsSequenceLoss.RemoveAt(0);
            }

            int nWidth  = pbImageLoss.Width;
            int nHeight = pbImageLoss.Height;
            Image imgLoss = SimpleGraphingControl.QuickRender(m_plotsSequenceLoss, nWidth, nHeight, false, null, null, true, m_rgZeroLine);

            // The second tuple item (0) tags this image for the progress handler.
            m_bw.ReportProgress(1, Tuple.Create(imgLoss, 0));
        }
Exemplo n.º 6
0
        /// <summary>
        /// Run the trained model on the generated Sin curve, rendering one result image for each
        /// of three generated samples.
        /// </summary>
        /// <returns>Returns <i>false</i> if no trained model found, otherwise <i>true</i>.</returns>
        public bool Run()
        {
            // The run net needs the previously trained weights.
            byte[] rgWeights = loadWeights();
            if (rgWeights == null)
            {
                Console.WriteLine("You must first train the network!");
                return false;
            }

            // Create the model used to run indefinitely.
            NetParameter netParam    = createModelInfiniteInput();
            string       strRunModel = netParam.ToProto("root").ToString();

            Console.WriteLine("Using Run Model:");
            Console.WriteLine(strRunModel);

            // Load the model for running with the trained weights, using a 1 x 1 x 1 input shape.
            const int nInputCount = 1;
            List<int> rgInputShape = new List<int>() { nInputCount, 1, 1 };

            m_mycaffeRun.LoadToRun(strRunModel, rgWeights, new BlobShape(rgInputShape), null, null, false, false);

            // Restore the previously saved LSTM state (hy and cy) on top of the trained weights.
            loadLstmState(m_mycaffeRun);

            // Get the internal RUN net and the blobs used to feed and read it.
            Net<float>  runNet     = m_mycaffeRun.GetInternalNet(Phase.RUN);
            Blob<float> blobInput  = runNet.FindBlob("data");
            Blob<float> blobClip   = runNet.FindBlob("clip2");
            Blob<float> blobOutput = runNet.FindBlob("ip1");

            const int nBatchSize   = 1;
            const int nSampleCount = 3;

            // Run on 3 different Sin curves.
            for (int nSample = 0; nSample < nSampleCount; nSample++)
            {
                // Create the Sin data for this sample.
                Dictionary<string, float[]> sample = generateSample(nSample + 1.1337f, null, nBatchSize, m_param.Output, m_param.TimeSteps);
                List<float> rgPredicted = new List<float>();

                // Set the clip to 1, for we are continuing from the last training
                // session and want to start with the last cy and hy states.
                blobClip.SetData(1);

                float[] rgY  = sample["Y"];
                float[] rgFY = sample["FY"];

                // Feed the known values, one per time step, recording each prediction.
                for (int nStep = 0; nStep < m_param.TimeSteps; nStep++)
                {
                    blobInput.SetData(rgY[nStep]);
                    runNet.Forward();
                    rgPredicted.Add(blobOutput.GetData(0));
                }

                // Feed the most recent prediction back in for each predicted output step.
                // (Alternative left disabled in the original: feed rgFY[nStep] instead.)
                for (int nStep = 0; nStep < m_param.Output; nStep++)
                {
                    float fLastPrediction = rgPredicted[rgPredicted.Count - 1];
                    blobInput.SetData(fLastPrediction);
                    runNet.Forward();
                    rgPredicted.Add(blobOutput.GetData(0));
                }

                // Graph and show the results: the time axis is "T" followed by "FT".
                List<float> rgTime = new List<float>(sample["T"]);
                rgTime.AddRange(sample["FT"]);

                // Build the three plot collections.
                List<float[]> rgYData = new List<float[]>() { rgY, rgFY };
                PlotCollection plotsY = createPlots("Y", rgTime.ToArray(), rgYData, 0);

                List<float[]> rgTargetData = new List<float[]>() { rgY, rgFY };
                PlotCollection plotsTarget = createPlots("Target", rgTime.ToArray(), rgTargetData, 1);

                List<float[]> rgPredictedData = new List<float[]>() { rgPredicted.ToArray() };
                PlotCollection plotsPrediction = createPlots("Predicted", rgTime.ToArray(), rgPredictedData, 0);

                List<PlotCollection> rgPlots = new List<PlotCollection>() { plotsY, plotsTarget, plotsPrediction };
                PlotCollectionSet plotSet = new PlotCollectionSet(rgPlots);

                // Create the graph image and display it.
                Image imgResult = SimpleGraphingControl.QuickRender(plotSet, 2000, 600);
                showImage(imgResult, "result_" + nSample.ToString() + ".png");
            }

            return true;
        }
Exemplo n.º 7
0
        /// <summary>
        /// Run the trained model.  When run each hand-written image is fed in sequence (by label,
        /// e.g. 0,1,2,...,9) through the model, yet images within each label are selected at random.
        /// </summary>
        /// <remarks>
        /// Performs up to 100 steps, pausing one second after each, and returns early when the cancel
        /// event is set.  While the force-error event is set, the image is queried with a random label
        /// while the target curve still follows the expected label sequence.
        /// </remarks>
        /// <param name="mycaffe">Specifies the mycaffe instance running the sequence run model.  When <i>null</i>, the m_mycaffe field is used.</param>
        /// <param name="bw">Specifies the background worker that receives each rendered image.</param>
        private void runModel(MyCaffeControl<float> mycaffe, BackgroundWorker bw)
        {
            Random random = new Random((int)DateTime.Now.Ticks);

            // Use the instance handed in by the caller (it was previously ignored in favor of the
            // field); fall back to the field so a caller passing null keeps working.
            MyCaffeControl<float> sequenceCaffe = mycaffe ?? m_mycaffe;

            // Get the internal RUN net and associated blobs.
            Net<float>  net          = sequenceCaffe.GetInternalNet(Phase.RUN);
            Blob<float> blobData     = net.FindBlob("data");
            Blob<float> blobClip1    = net.FindBlob("clip1");
            Blob<float> blobIp1      = net.FindBlob("ip1");
            List<float> rgPrediction = new List<float>();
            List<float> rgTarget     = new List<float>();
            List<float> rgT          = new List<float>();

            m_mycaffeInput.UpdateRunWeights();
            blobClip1.SetData(0);

            for (int i = 0; i < 100; i++)
            {
                if (m_evtCancel.WaitOne(0))
                {
                    return;
                }

                // When an error is forced, a random label replaces the expected one for the image query only.
                int  nLabelSeq    = m_nLabelSeq;
                bool bForcedError = m_evtForceError.WaitOne(0);
                if (bForcedError)
                {
                    nLabelSeq = random.Next(10);
                }

                // Get images one number at a time, in order by label, but randomly selected.
                SimpleDatum      sd  = m_imgDb.QueryImage(m_ds.TrainingSource.ID, 0, null, IMGDB_IMAGE_SELECTION_METHOD.RANDOM, nLabelSeq);
                ResultCollection res = m_mycaffeInput.Run(sd);

                Net<float>  inputNet = m_mycaffeInput.GetInternalNet(Phase.RUN);
                Blob<float> input_ip = inputNet.FindBlob(m_strInputOutputBlobName);

                // The target curve is generated from the expected label (m_nLabelSeq), not the forced one.
                Dictionary<string, float[]> data = Signal.GenerateSample(1, m_nLabelSeq / 10.0f, 1, m_model.InputLabel, m_model.TimeSteps);

                // Run the model on the output of the input model.
                blobClip1.SetData(1);
                blobData.mutable_cpu_data = input_ip.mutable_cpu_data;
                net.Forward();
                rgPrediction.AddRange(blobIp1.mutable_cpu_data);

                // Collect the time and target values for this step.
                float[] rgFT = data["FT"];
                float[] rgFY = data["FY"];
                for (int j = 0; j < rgFT.Length; j++)
                {
                    rgT.Add(rgFT[j]);
                    rgTarget.Add(rgFY[j]);
                }

                // Limit the history to what fits the picture box (5 pixels per point).  The time
                // values are trimmed together with the data so the X values stay aligned with the
                // plotted points (previously rgT was never trimmed and drifted out of step).
                while (rgTarget.Count * 5 > pbImage.Width)
                {
                    rgT.RemoveAt(0);
                    rgTarget.RemoveAt(0);
                    rgPrediction.RemoveAt(0);
                }

                // Plot the graph.
                PlotCollection plotsTarget = createPlots("Target", rgT.ToArray(), new List<float[]>()
                {
                    rgTarget.ToArray()
                }, 0);
                PlotCollection plotsPrediction = createPlots("Predicted", rgT.ToArray(), new List<float[]>()
                {
                    rgPrediction.ToArray()
                }, 0);
                PlotCollection    plotsAvePrediction = createPlotsAve("Predicted SMA", plotsPrediction, 10);
                PlotCollectionSet set = new PlotCollectionSet(new List<PlotCollection>()
                {
                    plotsTarget, plotsPrediction, plotsAvePrediction
                });

                // Create the graph image, overlay the input image information and display it.
                Image img = SimpleGraphingControl.QuickRender(set, pbImage.Width, pbImage.Height);
                img = drawInput(img, sd, res.DetectedLabel, bForcedError);

                bw.ReportProgress(0, img);
                Thread.Sleep(1000);

                // Advance to the next label, wrapping from 9 back to 0.
                m_nLabelSeq++;
                if (m_nLabelSeq == 10)
                {
                    m_nLabelSeq = 0;
                }
            }
        }