示例#1
0
        /// <summary>
        /// Run a single solver step, accumulate the resulting gradients and, when due, apply them to the learnable parameters.
        /// </summary>
        /// <remarks>
        /// Logging is disabled while the step runs and is always re-enabled on exit, even when the solver throws.
        /// </remarks>
        /// <param name="nIteration">Specifies the current iteration; it is passed to ApplyUpdate, and the accumulated gradients are applied whenever it is a multiple of the mini-batch size.</param>
        /// <param name="step">Specifies the training step passed to the solver; BACKWARD and BOTH force the accumulated gradients to be applied on this call.</param>
        public void Train(int nIteration, TRAIN_STEP step)
        {
            m_mycaffe.Log.Enable = false;

            try
            {
                // Run data/clip groups > 1 in non batch mode.
                if (m_nRecurrentSequenceLength != 1 && m_rgData != null && m_rgData.Count > 1 && m_rgClip != null)
                {
                    // NOTE(review): prepareBlob/copyBlob are defined elsewhere; they appear to stage
                    // one item at a time between each blob and its '1' counterpart - confirm.
                    prepareBlob(m_blobActionOneHot1, m_blobActionOneHot);
                    prepareBlob(m_blobDiscountedR1, m_blobDiscountedR);
                    prepareBlob(m_blobPolicyGradient1, m_blobPolicyGradient);

                    for (int i = 0; i < m_rgData.Count; i++)
                    {
                        copyBlob(i, m_blobActionOneHot1, m_blobActionOneHot);
                        copyBlob(i, m_blobDiscountedR1, m_blobDiscountedR);
                        copyBlob(i, m_blobPolicyGradient1, m_blobPolicyGradient);

                        // Feed exactly one data/clip pair per solver step.
                        // NOTE(review): assumes m_rgClip has at least as many items as m_rgData - confirm.
                        List <Datum> rgData1 = new List <Datum>()
                        {
                            m_rgData[i]
                        };
                        List <Datum> rgClip1 = new List <Datum>()
                        {
                            m_rgClip[i]
                        };

                        m_memData.AddDatumVector(rgData1, rgClip1, 1, true, true);

                        m_solver.Step(1, step, true, false, true, true);
                    }

                    // Give the working blobs the same shape as their '1' counterparts again.
                    m_blobActionOneHot.ReshapeLike(m_blobActionOneHot1);
                    m_blobDiscountedR.ReshapeLike(m_blobDiscountedR1);
                    m_blobPolicyGradient.ReshapeLike(m_blobPolicyGradient1);

                    // The queued data has been consumed.
                    m_rgData = null;
                    m_rgClip = null;
                }
                else
                {
                    m_solver.Step(1, step, true, false, true, true);
                }

                m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_net.learnable_parameters, true);

                if (nIteration % m_nMiniBatch == 0 || step == TRAIN_STEP.BACKWARD || step == TRAIN_STEP.BOTH)
                {
                    // Move the accumulated gradients into the net, reset the accumulator and apply the update.
                    m_net.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
                    m_colAccumulatedGradients.SetDiff(0);
                    m_solver.ApplyUpdate(nIteration);
                    m_net.ClearParamDiffs();
                }
            }
            finally
            {
                // Never leave logging switched off, even if the solver step failed.
                m_mycaffe.Log.Enable = true;
            }
        }
示例#2
0
        /// <summary>
        /// Run a single solver step and apply the update once every mini-batch of iterations.
        /// </summary>
        /// <remarks>
        /// Logging is disabled while the step runs and is always re-enabled on exit, even when the solver throws.
        /// </remarks>
        /// <param name="nIteration">Specifies the current iteration; the update is applied whenever it is a multiple of the mini-batch size.</param>
        public void Train(int nIteration)
        {
            m_mycaffe.Log.Enable = false;

            try
            {
                m_solver.Step(1, TRAIN_STEP.NONE, false, false, true);  // accumulate grad over batch

                if (nIteration % m_nMiniBatch == 0)
                {
                    m_solver.ApplyUpdate(nIteration);
                    m_net.ClearParamDiffs();
                }
            }
            finally
            {
                // Never leave logging switched off, even if the solver step failed.
                m_mycaffe.Log.Enable = true;
            }
        }
示例#3
0
        /// <summary>
        /// Run a single solver step, accumulate the resulting gradients and apply them once every mini-batch of iterations.
        /// </summary>
        /// <remarks>
        /// Logging is disabled while the step runs and is always re-enabled on exit, even when the solver throws.
        /// </remarks>
        /// <param name="nIteration">Specifies the current iteration; the accumulated gradients are applied whenever it is a multiple of the mini-batch size.</param>
        public void Train(int nIteration)
        {
            m_mycaffe.Log.Enable = false;

            try
            {
                m_solver.Step(1, TRAIN_STEP.NONE, true, false, true);
                m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_net.learnable_parameters, true);

                if (nIteration % m_nMiniBatch == 0)
                {
                    // Move the accumulated gradients into the net, reset the accumulator and apply the update.
                    m_net.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
                    m_colAccumulatedGradients.SetDiff(0);
                    m_solver.ApplyUpdate(nIteration);
                    m_net.ClearParamDiffs();
                }
            }
            finally
            {
                // Never leave logging switched off, even if the solver step failed.
                m_mycaffe.Log.Enable = true;
            }
        }
示例#4
0
        /// <summary>
        /// Train the model at the current iteration.
        /// </summary>
        /// <remarks>
        /// The loss handler is attached only for the duration of the solver step and logging is disabled while
        /// the networks run; both are restored on exit even when a step throws, so that a failed call cannot
        /// leave the handler subscribed (and subscribed a second time on the next call) or logging switched off.
        /// </remarks>
        /// <param name="nIteration">Specifies the current iteration.</param>
        /// <param name="rgSamples">Contains the samples to train the model with along with the priorities associated with the samples.</param>
        /// <param name="nActionCount">Specifies the number of actions in the action set.</param>
        public void Train(int nIteration, MemoryCollection rgSamples, int nActionCount)
        {
            // Kept on the instance - presumably read by the loss handler; verify against m_memLoss_ComputeTdLoss.
            m_rgSamples = rgSamples;

            if (m_nActionCount != nActionCount)
            {
                throw new Exception("The logit output of '" + m_nActionCount.ToString() + "' does not match the action count of '" + nActionCount.ToString() + "'!");
            }

            m_mycaffe.Log.Enable = false;

            try
            {
                // Get next_q_values
                setNextStateData(m_netTarget, rgSamples);
                m_netTarget.ForwardFromTo(0, m_netTarget.layers.Count - 2);

                setCurrentStateData(m_netOutput, rgSamples);
                m_memLoss.OnGetLoss += m_memLoss_ComputeTdLoss;

                try
                {
                    if (m_nMiniBatch == 1)
                    {
                        m_solver.Step(1);
                    }
                    else
                    {
                        m_solver.Step(1, TRAIN_STEP.NONE, true, m_bUseAcceleratedTraining, true, true);
                        m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_netOutput.learnable_parameters, true);

                        if (nIteration % m_nMiniBatch == 0)
                        {
                            // Move the accumulated gradients into the net, reset the accumulator and apply the update.
                            m_netOutput.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
                            m_colAccumulatedGradients.SetDiff(0);
                            m_dfLearningRate = m_solver.ApplyUpdate(nIteration);
                            m_netOutput.ClearParamDiffs();
                        }
                    }
                }
                finally
                {
                    // Always detach, otherwise a failed step would leave the handler attached
                    // and the next call would attach it again.
                    m_memLoss.OnGetLoss -= m_memLoss_ComputeTdLoss;
                }
            }
            finally
            {
                // Never leave logging switched off, even if a step failed.
                m_mycaffe.Log.Enable = true;
            }

            resetNoise(m_netOutput);
            resetNoise(m_netTarget);
        }
示例#5
0
        //-----------------------------------------------------------------------------------------
        //  Simple Classification (using direct net surgery)
        //-----------------------------------------------------------------------------------------

        /// <summary>
        /// The SimpleClassification sample is designed to show how to manually train the MNIST dataset using raw image data stored
        /// in the \ProgramData\MyCaffe\test_data\images\mnist\training directory (previously loaded with the 'Export Images'
        /// sample above).
        /// </summary>
        /// <remarks>
        /// IMPORTANT: This sample is for demonstration only; the Simplest Classification method is the fastest recommended
        /// method that uses the Image Database.
        ///
        /// This sample requires that you have already loaded the MNIST dataset into SQL (or SQLEXPRESS) using the MyCaffe
        /// Test Application by selecting its 'Database | Load MNIST...' menu item.
        ///
        /// The MyCaffe control created by this handler is disposed before the handler returns, including when
        /// loading, training or testing throws.
        /// </remarks>
        /// <param name="sender">Specifies the event sender.</param>
        /// <param name="e">Specifies the event argument.</param>
        private void btnSimpleClassification_Click(object sender, EventArgs e)
        {
            Stopwatch     sw         = new Stopwatch();
            int           nBatchSize = 32;
            SettingsCaffe settings   = new SettingsCaffe();

            settings.GpuIds = "0";

            if (!Directory.Exists(m_strImageDirTesting) || !Directory.Exists(m_strImageDirTraining))
            {
                MessageBox.Show("You must first export the MNIST images by pressing the Export button!", "Export Needed", MessageBoxButtons.OK, MessageBoxIcon.Error);
                return;
            }

            m_rgstrTrainingFiles = Directory.GetFiles(m_strImageDirTraining);
            m_rgstrTrainingFiles = m_rgstrTrainingFiles.Where(p => p.Contains(".png")).ToArray();
            m_rgstrTestingFiles  = Directory.GetFiles(m_strImageDirTesting);
            m_rgstrTestingFiles  = m_rgstrTestingFiles.Where(p => p.Contains(".png")).ToArray();

            string strSolver;
            string strModel;

            load_descriptors("mnist", out strSolver, out strModel); // Load the descriptors from their respective files (installed by MyCaffe Test Application install)
            strModel = fixup_model(strModel, nBatchSize);

            MyCaffeControl <float> mycaffe = new MyCaffeControl <float>(settings, m_log, m_evtCancel);

            try
            {
                mycaffe.Load(Phase.TRAIN,                                  // using the training phase.
                             strSolver,                                    // solver descriptor, that specifies to use the SGD solver.
                             strModel,                                     // simple LENET model descriptor.
                             null, null, null, false, null, false, false); // no weights are loaded and no image database is used.

                // Perform your own training
                Solver <float> solver    = mycaffe.GetInternalSolver();
                Net <float>    net       = mycaffe.GetInternalNet(Phase.TRAIN);
                Blob <float>   dataBlob  = net.blob_by_name("data");
                Blob <float>   labelBlob = net.blob_by_name("label");

                sw.Start();

                int nIterations = 5000;

                for (int i = 0; i < nIterations; i++)
                {
                    // Load the data into the data and label blobs.
                    loadData(m_rgstrTrainingFiles, nBatchSize, dataBlob, labelBlob);

                    // Run the forward and backward passes.
                    double dfLoss;
                    net.Forward(out dfLoss);
                    net.ClearParamDiffs();
                    net.Backward();

                    // Apply the gradients calculated during Backward.
                    solver.ApplyUpdate(i);

                    // Output the loss (at most once per second).
                    if (sw.Elapsed.TotalMilliseconds > 1000)
                    {
                        m_log.Progress = (double)i / (double)nIterations;
                        m_log.WriteLine("Loss = " + dfLoss.ToString());
                        sw.Restart();
                    }
                }

                // Run testing using the MyCaffe control (whose internal test net is already updated,
                // for it shares its weight memory with the training net).
                net       = mycaffe.GetInternalNet(Phase.TEST);
                dataBlob  = net.blob_by_name("data");
                labelBlob = net.blob_by_name("label");

                float fTotalAccuracy = 0;

                nIterations = 100;
                for (int i = 0; i < nIterations; i++)
                {
                    // Load the data into the data and label blobs.
                    loadData(m_rgstrTestingFiles, nBatchSize, dataBlob, labelBlob);

                    // Run the forward pass and sum the first value of the first result blob as the accuracy.
                    double dfLoss;
                    BlobCollection <float> res = net.Forward(out dfLoss);
                    fTotalAccuracy += res[0].GetData(0);

                    // Output the testing progress (at most once per second).
                    if (sw.Elapsed.TotalMilliseconds > 1000)
                    {
                        m_log.Progress = (double)i / (double)nIterations;
                        m_log.WriteLine("testing...");
                        sw.Restart();
                    }
                }

                double dfAccuracy = (double)fTotalAccuracy / (double)nIterations;

                m_log.WriteLine("Accuracy = " + dfAccuracy.ToString("P"));

                MessageBox.Show("Average Accuracy = " + dfAccuracy.ToString("P"), "Traing/Test on MNIST Completed", MessageBoxButtons.OK, MessageBoxIcon.Information);
            }
            finally
            {
                // Release the resources held by the control, even when loading, training or testing failed.
                mycaffe.Dispose();
            }
        }
示例#6
0
        //-----------------------------------------------------------------------------------------
        //  Simple Classification (using direct net surgery)
        //-----------------------------------------------------------------------------------------

        /// <summary>
        /// The SimpleClassification sample is designed to show how to manually train the MNIST dataset using raw image data stored
        /// in the \ProgramData\MyCaffe\test_data\images\mnist\training directory (previously loaded with the 'Export Images'
        /// sample above).
        /// </summary>
        /// <remarks>
        /// This sample requires that you have already loaded the MNIST dataset into SQL (or SQLEXPRESS) using the MyCaffe
        /// Test Application by selecting its 'Database | Load MNIST...' menu item.
        ///
        /// The MyCaffe controls and the bitmap created by this handler are disposed before the handler returns,
        /// including when loading, training, testing or running throws.
        /// </remarks>
        /// <param name="sender">Specifies the event sender.</param>
        /// <param name="e">Specifies the event argument.</param>
        private void btnSimpleClassification_Click(object sender, EventArgs e)
        {
            Stopwatch     sw         = new Stopwatch();
            int           nBatchSize = 32;
            SettingsCaffe settings   = new SettingsCaffe();

            settings.GpuIds = "0";

            if (!Directory.Exists(m_strImageDirTraining) || !Directory.Exists(m_strImageDirTesting))
            {
                string strMsg = "You must first expand the MNIST dataset into the following directories:" + Environment.NewLine;
                strMsg += "Training Images: '" + m_strImageDirTraining + "'" + Environment.NewLine;
                strMsg += "Testing Images: '" + m_strImageDirTesting + "'" + Environment.NewLine + Environment.NewLine;

                strMsg += "If you have Microsoft SQL or SQL Express installed, selecting the 'Export' button from the 'ImageClassification' project will export these images for you." + Environment.NewLine + Environment.NewLine;

                strMsg += "If you DO NOT have Microsoft SQL or SQL Express, running the MyCaffe Test Application and selecting the 'Database | Load MNIST...' menu item with the 'Export to file only' check box checked, will export the images for you without SQL." + Environment.NewLine + Environment.NewLine;

                strMsg += "To get the MNIST *.gz data files, please see http://yann.lecun.com/exdb/mnist/";

                MessageBox.Show(strMsg, "Images Not Found", MessageBoxButtons.OK, MessageBoxIcon.Error);
                return;
            }

            m_rgstrTrainingFiles = Directory.GetFiles(m_strImageDirTraining);
            m_rgstrTrainingFiles = m_rgstrTrainingFiles.Where(p => p.Contains(".png")).ToArray();
            m_rgstrTestingFiles  = Directory.GetFiles(m_strImageDirTesting);
            m_rgstrTestingFiles  = m_rgstrTestingFiles.Where(p => p.Contains(".png")).ToArray();

            string strSolver;
            string strModel;

            load_descriptors("mnist", out strSolver, out strModel); // Load the descriptors from their respective files (installed by MyCaffe Test Application install)
            strModel = fixup_model(strModel, nBatchSize);

            MyCaffeControl <float> mycaffe  = new MyCaffeControl <float>(settings, m_log, m_evtCancel);
            MyCaffeControl <float> mycaffe2 = null;

            try
            {
                mycaffe.LoadLite(Phase.TRAIN, // using the training phase.
                                 strSolver,   // solver descriptor, that specifies to use the SGD solver.
                                 strModel,    // simple LENET model descriptor.
                                 null);       // no weights are loaded.

                // Perform your own training
                Solver <float> solver    = mycaffe.GetInternalSolver();
                Net <float>    net       = mycaffe.GetInternalNet(Phase.TRAIN);
                Blob <float>   dataBlob  = net.blob_by_name("data");
                Blob <float>   labelBlob = net.blob_by_name("label");

                sw.Start();

                int nIterations = 5000;

                for (int i = 0; i < nIterations; i++)
                {
                    // Load the data into the data and label blobs.
                    loadData(m_rgstrTrainingFiles, nBatchSize, dataBlob, labelBlob);

                    // Run the forward and backward passes.
                    double dfLoss;
                    net.Forward(out dfLoss);
                    net.ClearParamDiffs();
                    net.Backward();

                    // Apply the gradients calculated during Backward.
                    solver.ApplyUpdate(i);

                    // Output the loss (at most once per second).
                    if (sw.Elapsed.TotalMilliseconds > 1000)
                    {
                        m_log.Progress = (double)i / (double)nIterations;
                        m_log.WriteLine("Loss = " + dfLoss.ToString());
                        sw.Restart();
                    }
                }

                // Run testing using the MyCaffe control (whose internal test net is already updated,
                // for it shares its weight memory with the training net).
                net       = mycaffe.GetInternalNet(Phase.TEST);
                dataBlob  = net.blob_by_name("data");
                labelBlob = net.blob_by_name("label");

                float fTotalAccuracy = 0;

                nIterations = 100;
                for (int i = 0; i < nIterations; i++)
                {
                    // Load the data into the data and label blobs.
                    loadData(m_rgstrTestingFiles, nBatchSize, dataBlob, labelBlob);

                    // Run the forward pass and sum the first value of the first result blob as the accuracy.
                    double dfLoss;
                    BlobCollection <float> res = net.Forward(out dfLoss);
                    fTotalAccuracy += res[0].GetData(0);

                    // Output the testing progress (at most once per second).
                    if (sw.Elapsed.TotalMilliseconds > 1000)
                    {
                        m_log.Progress = (double)i / (double)nIterations;
                        m_log.WriteLine("testing...");
                        sw.Restart();
                    }
                }

                double dfAccuracy = (double)fTotalAccuracy / (double)nIterations;

                m_log.WriteLine("Accuracy = " + dfAccuracy.ToString("P"));

                MessageBox.Show("Average Accuracy = " + dfAccuracy.ToString("P"), "Traing/Test on MNIST Completed", MessageBoxButtons.OK, MessageBoxIcon.Information);

                // Save the trained weights for use later.
                saveWeights(mycaffe, "my_weights");

                // Run the first testing image through the trained control and through a clone of it.
                // The bitmap is disposed as soon as both runs complete.
                // NOTE(review): assumes Run() does not retain the bitmap after it returns - confirm.
                using (Bitmap bmp = new Bitmap(m_rgstrTestingFiles[0]))
                {
                    ResultCollection results = mycaffe.Run(bmp);

                    mycaffe2 = mycaffe.Clone(0);
                    ResultCollection results2 = mycaffe2.Run(bmp);
                }
            }
            finally
            {
                // Release resources used, even when one of the steps above failed.
                mycaffe.Dispose();

                if (mycaffe2 != null)
                {
                    mycaffe2.Dispose();
                }
            }
        }