예제 #1
0
        /// <summary>
        /// Verify that the training model created by the model builder survives a round trip
        /// through its text form, then create the solver descriptor, save both descriptors
        /// and construct (and dispose) a MyCaffeControl.
        /// </summary>
        public void TestCreateTrainingModel()
        {
            ModelBuilder builder = create();

            // Round trip the model: parameter -> proto text -> parameter.
            NetParameter net_param = builder.CreateModel();
            RawProto     proto     = net_param.ToProto("root");
            string       strNet    = proto.ToString();

            RawProto     proto2     = RawProto.Parse(strNet);
            NetParameter net_param2 = NetParameter.FromProto(proto2);

            m_log.CHECK(net_param2.Compare(net_param), "The two net parameters should be the same!");

            // verify creating the model.
            SolverParameter solver      = builder.CreateSolver();
            RawProto        protoSolver = solver.ToProto("root");
            string          strSolver   = protoSolver.ToString();

            SettingsCaffe      settings  = new SettingsCaffe();
            CancelEvent        evtCancel = new CancelEvent();
            MyCaffeControl <T> mycaffe   = new MyCaffeControl <T>(settings, m_log, evtCancel);

            // Dispose the control even when saving the descriptors throws.
            try
            {
                save(strNet, strSolver, false);

                //            mycaffe.LoadLite(Phase.TRAIN, strSolver, strNet, null);
            }
            finally
            {
                mycaffe.Dispose();
            }
        }
예제 #2
0
        /// <summary>
        /// Override the testing interval of a solver descriptor and turn off its test_initialization flag.
        /// </summary>
        /// <param name="strSolver">Specifies the solver descriptor text to parse.</param>
        /// <param name="nInterval">Specifies the value assigned to the solver's test_interval.</param>
        /// <returns>The updated solver descriptor is returned as text.</returns>
        private string fixup_solver(string strSolver, int nInterval)
        {
            SolverParameter param = SolverParameter.FromProto(RawProto.Parse(strSolver));

            // Set the testing interval used during training.
            param.test_interval       = nInterval;
            param.test_initialization = false;

            // Serialize the modified solver back to its text form.
            RawProto protoUpdated = param.ToProto("root");

            return protoUpdated.ToString();
        }
예제 #3
0
        /// <summary>
        /// Round trip the solver created by the model builder through its text form and
        /// check that the parsed copy compares equal to the original.
        /// </summary>
        public void TestCreateSolver()
        {
            SolverParameter original = create().CreateSolver();

            // Serialize the solver, then rebuild it from the serialized text.
            string          strText  = original.ToProto("root").ToString();
            SolverParameter restored = SolverParameter.FromProto(RawProto.Parse(strText));

            bool bSame = restored.Compare(original);

            m_log.CHECK(bSame, "The two solver parameters should be the same!");
        }
예제 #4
0
        /// <summary>
        /// Train the model.
        /// </summary>
        /// <remarks>
        /// Nothing is done when the training MyCaffe instance has not been created.  After training,
        /// the LSTM state is saved, the plots are rendered to an image and the trained weights are saved.
        /// </remarks>
        /// <param name="bNewWts">Specifies whether to use new weights or load existing ones (if they exist).</param>
        public void Train(bool bNewWts)
        {
            if (m_mycaffeTrain == null)
            {
                return;
            }

            byte[] rgWts = null;

            if (!bNewWts)
            {
                rgWts = loadWeights();
            }

            // Also reached when existing weights were requested but loadWeights returned null.
            if (rgWts == null)
            {
                Console.WriteLine("Starting with new weights...");
            }

            SolverParameter solver = createSolver();
            NetParameter    model  = createModel();

            // Serialize the model once; the same text is displayed and then loaded below.
            string strModel = model.ToProto("root").ToString();

            Console.WriteLine("Using Train Model:");
            Console.WriteLine(strModel);
            Console.WriteLine("Starting training...");

            m_mycaffeTrain.LoadLite(Phase.TRAIN, solver.ToProto("root").ToString(), strModel, rgWts, false, false);
            m_mycaffeTrain.SetOnTrainingStartOverride(new EventHandler(onTrainingStart));
            m_mycaffeTrain.SetOnTestingStartOverride(new EventHandler(onTestingStart));

            // Set clockwork weights.
            if (m_param.LstmEngine != EngineParameter.Engine.CUDNN)
            {
                Net <float>  net   = m_mycaffeTrain.GetInternalNet(Phase.TRAIN);
                // NOTE(review): index 2 presumably selects a learnable blob of the first LSTM layer
                // and depends on the layer order produced by createModel - confirm.
                Blob <float> lstm1 = net.parameters[2];
                // NOTE(review): presumably writes the value 1 into a range of m_param.Hidden items
                // starting at offset m_param.Hidden - confirm against Blob.SetData.
                lstm1.SetData(1, m_param.Hidden, m_param.Hidden);
            }

            m_mycaffeTrain.Train(m_param.Iterations);
            saveLstmState(m_mycaffeTrain);

            // Render the collected plots into a 1000 x 600 image and show it.
            Image img = SimpleGraphingControl.QuickRender(m_plots, 1000, 600);

            showImage(img, "training.png");
            saveWeights(m_mycaffeTrain.GetWeights());
        }
예제 #5
0
        /// <summary>
        /// Build the LeNet solver settings in code and return them as a prototxt style text descriptor.
        /// </summary>
        /// <param name="nIterations">Specifies the number of iterations to train (assigned to max_iter).</param>
        /// <returns>The solver descriptor is returned as text.</returns>
        private string create_solver_descriptor_programmatically(int nIterations)
        {
            // Fill in all solver settings in a single initializer (assigned in the order listed).
            SolverParameter param = new SolverParameter()
            {
                max_iter            = nIterations,
                test_iter           = new List <int>() { 100 },
                test_initialization = false,
                test_interval       = 500,
                base_lr             = 0.01,
                momentum            = 0.9,
                weight_decay        = 0.0005,
                LearningRatePolicy  = SolverParameter.LearningRatePolicyType.INV,
                gamma               = 0.0001,
                power               = 0.75,
                display             = 100,
                snapshot            = 5000
            };

            // Serialize the solver settings to their text form.
            return param.ToProto("root").ToString();
        }
예제 #6
0
        /// <summary>
        /// The DoWork thread is the main thread used to train or run the model depending on the operation selected.
        /// </summary>
        /// <remarks>
        /// The MyCaffe instance created here is always disposed before returning.  Any exception raised
        /// while training or running propagates to the caller with its original stack trace.
        /// </remarks>
        /// <param name="sender">Specifies the sender, which is cast to a BackgroundWorker.</param>
        /// <param name="e">Specifies the arguments, where e.Argument is cast to the InputData describing the operation.</param>
        private void m_bw_DoWork(object sender, DoWorkEventArgs e)
        {
            BackgroundWorker bw = sender as BackgroundWorker;

            m_input = e.Argument as InputData;
            SettingsCaffe s = new SettingsCaffe();

            s.ImageDbLoadMethod = IMAGEDB_LOAD_METHOD.LOAD_ALL;

            // No catch block: a catch that only re-throws the caught variable would reset the
            // stack trace, so exceptions are left to propagate while 'finally' does the cleanup.
            try
            {
                m_model.Batch = m_input.Batch;
                m_mycaffe     = new MyCaffeControl <float>(s, m_log, m_evtCancel);

                // Train the model.
                if (m_input.Operation == InputData.OPERATION.TRAIN)
                {
                    // NOTE(review): 7000 is presumably the number of training items per epoch - confirm.
                    m_model.Iterations = (int)((m_input.Epochs * 7000) / m_model.Batch);
                    m_log.WriteLine("Training for " + m_input.Epochs.ToString() + " epochs (" + m_model.Iterations.ToString("N0") + " iterations).", true);
                    m_log.WriteLine("INFO: " + m_model.Iterations.ToString("N0") + " iterations.", true);
                    m_log.WriteLine("Using hidden = " + m_input.HiddenSize.ToString() + ", and word size = " + m_input.WordSize.ToString() + ".", true);

                    // Load the Seq2Seq training model.
                    NetParameter    netParam    = m_model.CreateModel(m_input.InputFileName, m_input.TargetFileName, m_input.HiddenSize, m_input.WordSize, m_input.UseSoftmax, m_input.UseExternalIp);
                    string          strModel    = netParam.ToProto("root").ToString();
                    SolverParameter solverParam = m_model.CreateSolver(m_input.LearningRate);
                    string          strSolver   = solverParam.ToProto("root").ToString();
                    byte[]          rgWts       = loadWeights("sequence");

                    // Keep the descriptors in the member fields.
                    m_strModel  = strModel;
                    m_strSolver = strSolver;

                    m_mycaffe.OnTrainingIteration += m_mycaffe_OnTrainingIteration;
                    m_mycaffe.OnTestingIteration  += m_mycaffe_OnTestingIteration;
                    m_mycaffe.LoadLite(Phase.TRAIN, strSolver, strModel, rgWts, false, false);

                    // Without softmax, hook the 'loss' MEMORY_LOSS layers (when found) of the
                    // train and test nets so that the loss is supplied by this class.
                    if (!m_input.UseSoftmax)
                    {
                        MemoryLossLayer <float> lossLayerTraining = m_mycaffe.GetInternalNet(Phase.TRAIN).FindLayer(LayerParameter.LayerType.MEMORY_LOSS, "loss") as MemoryLossLayer <float>;
                        if (lossLayerTraining != null)
                        {
                            lossLayerTraining.OnGetLoss += LossLayer_OnGetLossTraining;
                        }
                        MemoryLossLayer <float> lossLayerTesting = m_mycaffe.GetInternalNet(Phase.TEST).FindLayer(LayerParameter.LayerType.MEMORY_LOSS, "loss") as MemoryLossLayer <float>;
                        if (lossLayerTesting != null)
                        {
                            lossLayerTesting.OnGetLoss += LossLayer_OnGetLossTesting;
                        }
                    }

                    m_blobProbs = new Blob <float>(m_mycaffe.Cuda, m_mycaffe.Log);
                    m_blobScale = new Blob <float>(m_mycaffe.Cuda, m_mycaffe.Log);

                    // Hook the 'data' TEXT_DATA layer (when found) of the training net.
                    TextDataLayer <float> dataLayerTraining = m_mycaffe.GetInternalNet(Phase.TRAIN).FindLayer(LayerParameter.LayerType.TEXT_DATA, "data") as TextDataLayer <float>;
                    if (dataLayerTraining != null)
                    {
                        dataLayerTraining.OnGetData += DataLayerTraining_OnGetDataTraining;
                    }

                    // Train the Seq2Seq model.
                    m_plotsSequenceLoss          = new PlotCollection("Sequence Loss");
                    m_plotsSequenceAccuracyTest  = new PlotCollection("Sequence Accuracy Test");
                    m_plotsSequenceAccuracyTrain = new PlotCollection("Sequence Accuracy Train");
                    m_mycaffe.Train(m_model.Iterations);
                    saveWeights("sequence", m_mycaffe);
                }

                // Run a trained model.
                else
                {
                    NetParameter netParam = m_model.CreateModel(m_input.InputFileName, m_input.TargetFileName, m_input.HiddenSize, m_input.WordSize, m_input.UseSoftmax, m_input.UseExternalIp, Phase.RUN);
                    string       strModel = netParam.ToProto("root").ToString();
                    byte[]       rgWts    = loadWeights("sequence");

                    strModel = m_model.PrependInput(strModel);

                    m_strModelRun = strModel;

                    // The run net is loaded with an input shape of (TimeSteps, 1, 1, 1).
                    int nN = m_model.TimeSteps;
                    m_mycaffe.LoadToRun(strModel, rgWts, new BlobShape(new List <int>()
                    {
                        nN, 1, 1, 1
                    }), null, null, false, false);

                    m_blobProbs = new Blob <float>(m_mycaffe.Cuda, m_mycaffe.Log);
                    m_blobScale = new Blob <float>(m_mycaffe.Cuda, m_mycaffe.Log);

                    runModel(m_mycaffe, bw, m_input.InputText);
                }
            }
            finally
            {
                // Cleanup.
                if (m_mycaffe != null)
                {
                    m_mycaffe.Dispose();
                    m_mycaffe = null;
                }
            }
        }
예제 #7
0
        /// <summary>
        /// The worker thread used to either train or run the models.
        /// </summary>
        /// <remarks>
        /// When training, first the input hand-written image model is trained
        /// using the LeNet model.
        ///
        /// This input model is then run in the onTrainingStart event to get the
        /// detected hand written character representation.  The outputs of layer
        /// 'ip1' from the input model are then fed as input to the sequence
        /// model which is then trained to encode the 'ip1' input data with one
        /// lstm and then decoded with another which is then trained to detect
        /// a section of the Sin curve data.
        ///
        /// When running, the first input model is run to get its 'ip1' representation,
        /// which is then fed into the sequence model to detect the section of the
        /// Sin curve.
        ///
        /// Both MyCaffe instances are disposed before returning, including when an
        /// exception is thrown while loading, training or running.
        /// </remarks>
        /// <param name="sender">Specifies the sender of the event (e.g. the BackgroundWorker)</param>
        /// <param name="e">Specifies the event args, where e.Argument is cast to the OPERATION to perform.</param>
        private void m_bw_DoWork(object sender, DoWorkEventArgs e)
        {
            BackgroundWorker bw = sender as BackgroundWorker;
            OPERATION        op = (OPERATION)e.Argument;
            SettingsCaffe    s  = new SettingsCaffe();

            s.ImageDbLoadMethod = IMAGEDB_LOAD_METHOD.LOAD_ALL;

            m_operation = op;

            try
            {
                m_mycaffe      = new MyCaffeControl <float>(s, m_log, m_evtCancel);
                m_mycaffeInput = new MyCaffeControl <float>(s, m_log, m_evtCancel);
                m_imgDb        = new MyCaffeImageDatabase2(m_log);

                // Load the image database.
                m_imgDb.InitializeWithDsName1(s, "MNIST");
                m_ds = m_imgDb.GetDatasetByName("MNIST");

                // Create the MNIST image detection model
                NetParameter    netParamMnist    = m_model.CreateMnistModel(m_ds);
                SolverParameter solverParamMnist = m_model.CreateMnistSolver();

                byte[] rgWts = loadWeights("input");
                m_mycaffeInput.Load(Phase.TRAIN, solverParamMnist.ToProto("root").ToString(), netParamMnist.ToProto("root").ToString(), rgWts, null, null, false, m_imgDb);
                Net <float>  netTrain = m_mycaffeInput.GetInternalNet(Phase.TRAIN);
                Blob <float> input_ip = netTrain.FindBlob(m_strInputOutputBlobName); // input model's second to last output (includes relu)

                // Run the train or run operation.
                if (op == OPERATION.TRAIN)
                {
                    // Train the MNIST model first.
                    m_mycaffeInput.OnTrainingIteration += m_mycaffeInput_OnTrainingIteration;
                    m_plotsInputLoss = new PlotCollection("Input Loss");
                    m_mycaffeInput.Train(2000);
                    saveWeights("input", m_mycaffeInput.GetWeights());

                    // Load the Seq2Seq training model.
                    NetParameter    netParam    = m_model.CreateModel(input_ip.channels, 10);
                    string          strModel    = netParam.ToProto("root").ToString();
                    SolverParameter solverParam = m_model.CreateSolver();
                    rgWts = loadWeights("sequence");

                    m_mycaffe.OnTrainingIteration += m_mycaffe_OnTrainingIteration;
                    m_mycaffe.LoadLite(Phase.TRAIN, solverParam.ToProto("root").ToString(), strModel, rgWts, false, false);
                    m_mycaffe.SetOnTrainingStartOverride(new EventHandler(onTrainingStart));

                    // Train the Seq2Seq model.
                    m_plotsSequenceLoss = new PlotCollection("Sequence Loss");
                    m_mycaffe.Train(m_model.Iterations);
                    saveWeights("sequence", m_mycaffe.GetWeights());
                }
                else
                {
                    NetParameter netParam = m_model.CreateModel(input_ip.channels, 10, 1, 1);
                    string       strModel = netParam.ToProto("root").ToString();
                    rgWts = loadWeights("sequence");

                    // The run net is loaded with an input shape of (1, 1, 1, 1).
                    int nN = 1;
                    m_mycaffe.LoadToRun(strModel, rgWts, new BlobShape(new List <int>()
                    {
                        nN, 1, 1, 1
                    }), null, null, false, false);
                    runModel(m_mycaffe, bw);
                }
            }
            finally
            {
                // Cleanup, also performed when an exception is propagating.
                if (m_mycaffe != null)
                {
                    m_mycaffe.Dispose();
                    m_mycaffe = null;
                }

                if (m_mycaffeInput != null)
                {
                    m_mycaffeInput.Dispose();
                    m_mycaffeInput = null;
                }
            }
        }