/// <summary>
/// Verify that the training model produced by the builder survives a
/// round-trip through its raw proto text form, and that a matching solver
/// description can be created and saved alongside it.
/// </summary>
public void TestCreateTrainingModel()
{
    ModelBuilder builder = create();

    // Round-trip the net: model -> proto text -> parsed model, then compare.
    NetParameter netParam = builder.CreateModel();
    string strNet = netParam.ToProto("root").ToString();
    NetParameter netParamRoundTrip = NetParameter.FromProto(RawProto.Parse(strNet));
    m_log.CHECK(netParamRoundTrip.Compare(netParam), "The two net parameters should be the same!");

    // verify creating the model.
    SolverParameter solverParam = builder.CreateSolver();
    string strSolver = solverParam.ToProto("root").ToString();

    SettingsCaffe settings = new SettingsCaffe();
    CancelEvent evtCancel = new CancelEvent();
    MyCaffeControl<T> mycaffe = new MyCaffeControl<T>(settings, m_log, evtCancel);

    save(strNet, strSolver, false);

    // Loading the model for training is currently disabled in this test.
    //mycaffe.LoadLite(Phase.TRAIN, strSolver, strNet, null);
    mycaffe.Dispose();
}
/// <summary>
/// Set the solver testing interval.
/// </summary>
/// <param name="strSolver">Specifies the solver parameter description text.</param>
/// <param name="nInterval">Specifies the testing interval (in iterations) to apply.</param>
/// <returns>The solver description is returned.</returns>
private string fixup_solver(string strSolver, int nInterval)
{
    // Parse the text description back into a solver parameter.
    SolverParameter solverParam = SolverParameter.FromProto(RawProto.Parse(strSolver));

    // Set the testing interval used during training and skip the initial test pass.
    solverParam.test_interval = nInterval;
    solverParam.test_initialization = false;

    return solverParam.ToProto("root").ToString();
}
/// <summary>
/// Verify that the solver created by the builder survives a round-trip
/// through its raw proto text representation.
/// </summary>
public void TestCreateSolver()
{
    ModelBuilder builder = create();

    // Solver -> proto text -> parsed solver, then compare the two.
    SolverParameter solverParamOriginal = builder.CreateSolver();
    string strSolver = solverParamOriginal.ToProto("root").ToString();
    SolverParameter solverParamRoundTrip = SolverParameter.FromProto(RawProto.Parse(strSolver));

    m_log.CHECK(solverParamRoundTrip.Compare(solverParamOriginal), "The two solver parameters should be the same!");
}
/// <summary>
/// Train the model.
/// </summary>
/// <param name="bNewWts">Specifies whether to use new weights or load existing ones (if they exist).</param>
public void Train(bool bNewWts)
{
    if (m_mycaffeTrain == null)
        return;

    byte[] rgWts = null;

    if (!bNewWts)
        rgWts = loadWeights();

    if (rgWts == null)
        Console.WriteLine("Starting with new weights...");

    SolverParameter solver = createSolver();
    NetParameter model = createModel();
    // Serialize once and reuse - the original re-serialized the model for
    // the LoadLite call even though strModel already held the same text.
    string strModel = model.ToProto("root").ToString();
    string strSolver = solver.ToProto("root").ToString();

    Console.WriteLine("Using Train Model:");
    Console.WriteLine(strModel);
    Console.WriteLine("Starting training...");

    m_mycaffeTrain.LoadLite(Phase.TRAIN, strSolver, strModel, rgWts, false, false);
    m_mycaffeTrain.SetOnTrainingStartOverride(new EventHandler(onTrainingStart));
    m_mycaffeTrain.SetOnTestingStartOverride(new EventHandler(onTestingStart));

    // Set clockwork weights.
    if (m_param.LstmEngine != EngineParameter.Engine.CUDNN)
    {
        Net<float> net = m_mycaffeTrain.GetInternalNet(Phase.TRAIN);
        // parameters[2] is assumed to be the LSTM recurrent weight blob of
        // size Hidden x Hidden - TODO confirm the index against the model layout.
        Blob<float> lstm1 = net.parameters[2];
        lstm1.SetData(1, m_param.Hidden, m_param.Hidden);
    }

    m_mycaffeTrain.Train(m_param.Iterations);
    saveLstmState(m_mycaffeTrain);

    Image img = SimpleGraphingControl.QuickRender(m_plots, 1000, 600);
    showImage(img, "training.png");
    saveWeights(m_mycaffeTrain.GetWeights());
}
/// <summary>
/// Create the LeNet solver prototxt programmatically.
/// </summary>
/// <param name="nIterations">Specifies the number of iterations to train.</param>
/// <returns>The solver descriptor is returned as text.</returns>
private string create_solver_descriptor_programmatically(int nIterations)
{
    // Build the solver settings with an object initializer; values mirror
    // the classic LeNet solver (inv learning-rate policy with momentum SGD).
    SolverParameter solverParam = new SolverParameter
    {
        max_iter = nIterations,
        test_iter = new List<int> { 100 },
        test_initialization = false,
        test_interval = 500,
        base_lr = 0.01,
        momentum = 0.9,
        weight_decay = 0.0005,
        LearningRatePolicy = SolverParameter.LearningRatePolicyType.INV,
        gamma = 0.0001,
        power = 0.75,
        display = 100,
        snapshot = 5000
    };

    // Convert solver to text descriptor.
    return solverParam.ToProto("root").ToString();
}
/// <summary>
/// The DoWork thread is the main thread used to train or run the model depending on the operation selected.
/// </summary>
/// <param name="sender">Specifies the sender (the BackgroundWorker).</param>
/// <param name="e">Specifies the arguments; e.Argument carries the InputData.</param>
private void m_bw_DoWork(object sender, DoWorkEventArgs e)
{
    BackgroundWorker bw = sender as BackgroundWorker;
    m_input = e.Argument as InputData;
    SettingsCaffe s = new SettingsCaffe();

    s.ImageDbLoadMethod = IMAGEDB_LOAD_METHOD.LOAD_ALL;

    // NOTE: the original wrapped this in 'catch (Exception excpt) { throw excpt; }'
    // which resets the stack trace; the catch added nothing over the finally,
    // so it is removed and exceptions propagate with their original trace.
    try
    {
        m_model.Batch = m_input.Batch;
        m_mycaffe = new MyCaffeControl<float>(s, m_log, m_evtCancel);

        // Train the model.
        if (m_input.Operation == InputData.OPERATION.TRAIN)
        {
            m_model.Iterations = (int)((m_input.Epochs * 7000) / m_model.Batch);
            m_log.WriteLine("Training for " + m_input.Epochs.ToString() + " epochs (" + m_model.Iterations.ToString("N0") + " iterations).", true);
            m_log.WriteLine("INFO: " + m_model.Iterations.ToString("N0") + " iterations.", true);
            m_log.WriteLine("Using hidden = " + m_input.HiddenSize.ToString() + ", and word size = " + m_input.WordSize.ToString() + ".", true);

            // Load the Seq2Seq training model.
            NetParameter netParam = m_model.CreateModel(m_input.InputFileName, m_input.TargetFileName, m_input.HiddenSize, m_input.WordSize, m_input.UseSoftmax, m_input.UseExternalIp);
            string strModel = netParam.ToProto("root").ToString();
            SolverParameter solverParam = m_model.CreateSolver(m_input.LearningRate);
            string strSolver = solverParam.ToProto("root").ToString();
            byte[] rgWts = loadWeights("sequence");

            m_strModel = strModel;
            m_strSolver = strSolver;

            m_mycaffe.OnTrainingIteration += m_mycaffe_OnTrainingIteration;
            m_mycaffe.OnTestingIteration += m_mycaffe_OnTestingIteration;
            m_mycaffe.LoadLite(Phase.TRAIN, strSolver, strModel, rgWts, false, false);

            // When not using softmax, the loss is computed by MemoryLoss layers
            // whose OnGetLoss callbacks must be hooked up for TRAIN and TEST.
            if (!m_input.UseSoftmax)
            {
                MemoryLossLayer<float> lossLayerTraining = m_mycaffe.GetInternalNet(Phase.TRAIN).FindLayer(LayerParameter.LayerType.MEMORY_LOSS, "loss") as MemoryLossLayer<float>;
                if (lossLayerTraining != null)
                    lossLayerTraining.OnGetLoss += LossLayer_OnGetLossTraining;

                MemoryLossLayer<float> lossLayerTesting = m_mycaffe.GetInternalNet(Phase.TEST).FindLayer(LayerParameter.LayerType.MEMORY_LOSS, "loss") as MemoryLossLayer<float>;
                if (lossLayerTesting != null)
                    lossLayerTesting.OnGetLoss += LossLayer_OnGetLossTesting;
            }

            m_blobProbs = new Blob<float>(m_mycaffe.Cuda, m_mycaffe.Log);
            m_blobScale = new Blob<float>(m_mycaffe.Cuda, m_mycaffe.Log);

            TextDataLayer<float> dataLayerTraining = m_mycaffe.GetInternalNet(Phase.TRAIN).FindLayer(LayerParameter.LayerType.TEXT_DATA, "data") as TextDataLayer<float>;
            if (dataLayerTraining != null)
                dataLayerTraining.OnGetData += DataLayerTraining_OnGetDataTraining;

            // Train the Seq2Seq model.
            m_plotsSequenceLoss = new PlotCollection("Sequence Loss");
            m_plotsSequenceAccuracyTest = new PlotCollection("Sequence Accuracy Test");
            m_plotsSequenceAccuracyTrain = new PlotCollection("Sequence Accuracy Train");
            m_mycaffe.Train(m_model.Iterations);
            saveWeights("sequence", m_mycaffe);
        }
        // Run a trained model.
        else
        {
            NetParameter netParam = m_model.CreateModel(m_input.InputFileName, m_input.TargetFileName, m_input.HiddenSize, m_input.WordSize, m_input.UseSoftmax, m_input.UseExternalIp, Phase.RUN);
            string strModel = netParam.ToProto("root").ToString();
            byte[] rgWts = loadWeights("sequence");

            strModel = m_model.PrependInput(strModel);
            m_strModelRun = strModel;

            int nN = m_model.TimeSteps;
            m_mycaffe.LoadToRun(strModel, rgWts, new BlobShape(new List<int>() { nN, 1, 1, 1 }), null, null, false, false);

            m_blobProbs = new Blob<float>(m_mycaffe.Cuda, m_mycaffe.Log);
            m_blobScale = new Blob<float>(m_mycaffe.Cuda, m_mycaffe.Log);

            runModel(m_mycaffe, bw, m_input.InputText);
        }
    }
    finally
    {
        // Cleanup.
        if (m_mycaffe != null)
        {
            m_mycaffe.Dispose();
            m_mycaffe = null;
        }
    }
}
/// <summary>
/// The worker thread used to either train or run the models.
/// </summary>
/// <remarks>
/// When training, first the input hand-written image model is trained
/// using the LeNet model.
///
/// This input mode is then run in the onTrainingStart event to get the
/// detected hand written character representation.  The outputs of layer
/// 'ip1' from the input model are then fed as input to the sequence
/// model which is then trained to encode the 'ip1' input data with one
/// lstm and then decoded with another which is then trained to detect
/// a section of the Sin curve data.
///
/// When running, the first input model is run to get its 'ip1' representation,
/// which is then fed into the sequence model to detect the section of the
/// Sin curve.
/// </remarks>
/// <param name="sender">Specifies the sender of the event (e.g. the BackgroundWorker)</param>
/// <param name="e">Specifies the event args.</param>
private void m_bw_DoWork(object sender, DoWorkEventArgs e)
{
    BackgroundWorker bw = sender as BackgroundWorker;
    OPERATION op = (OPERATION)e.Argument;
    SettingsCaffe s = new SettingsCaffe();

    s.ImageDbLoadMethod = IMAGEDB_LOAD_METHOD.LOAD_ALL;

    m_operation = op;
    m_mycaffe = new MyCaffeControl<float>(s, m_log, m_evtCancel);
    m_mycaffeInput = new MyCaffeControl<float>(s, m_log, m_evtCancel);
    m_imgDb = new MyCaffeImageDatabase2(m_log);

    // The cleanup below originally ran unconditionally at the end of the
    // method, so any exception leaked both MyCaffeControl instances; it now
    // runs in a finally block.
    try
    {
        // Load the image database.
        m_imgDb.InitializeWithDsName1(s, "MNIST");
        m_ds = m_imgDb.GetDatasetByName("MNIST");

        // Create the MNIST image detection model
        NetParameter netParamMnist = m_model.CreateMnistModel(m_ds);
        SolverParameter solverParamMnist = m_model.CreateMnistSolver();

        byte[] rgWts = loadWeights("input");
        m_mycaffeInput.Load(Phase.TRAIN, solverParamMnist.ToProto("root").ToString(), netParamMnist.ToProto("root").ToString(), rgWts, null, null, false, m_imgDb);

        Net<float> netTrain = m_mycaffeInput.GetInternalNet(Phase.TRAIN);
        Blob<float> input_ip = netTrain.FindBlob(m_strInputOutputBlobName); // input model's second to last output (includes relu)

        // Run the train or run operation.
        if (op == OPERATION.TRAIN)
        {
            // Train the MNIST model first.
            m_mycaffeInput.OnTrainingIteration += m_mycaffeInput_OnTrainingIteration;
            m_plotsInputLoss = new PlotCollection("Input Loss");
            m_mycaffeInput.Train(2000);
            saveWeights("input", m_mycaffeInput.GetWeights());

            // Load the Seq2Seq training model (serialize once and reuse).
            NetParameter netParam = m_model.CreateModel(input_ip.channels, 10);
            string strModel = netParam.ToProto("root").ToString();
            SolverParameter solverParam = m_model.CreateSolver();
            string strSolver = solverParam.ToProto("root").ToString();

            rgWts = loadWeights("sequence");
            m_mycaffe.OnTrainingIteration += m_mycaffe_OnTrainingIteration;
            m_mycaffe.LoadLite(Phase.TRAIN, strSolver, strModel, rgWts, false, false);
            m_mycaffe.SetOnTrainingStartOverride(new EventHandler(onTrainingStart));

            // Train the Seq2Seq model.
            m_plotsSequenceLoss = new PlotCollection("Sequence Loss");
            m_mycaffe.Train(m_model.Iterations);
            saveWeights("sequence", m_mycaffe.GetWeights());
        }
        else
        {
            NetParameter netParam = m_model.CreateModel(input_ip.channels, 10, 1, 1);
            string strModel = netParam.ToProto("root").ToString();
            rgWts = loadWeights("sequence");

            int nN = 1;
            m_mycaffe.LoadToRun(strModel, rgWts, new BlobShape(new List<int>() { nN, 1, 1, 1 }), null, null, false, false);
            runModel(m_mycaffe, bw);
        }
    }
    finally
    {
        // Cleanup.
        if (m_mycaffe != null)
        {
            m_mycaffe.Dispose();
            m_mycaffe = null;
        }

        if (m_mycaffeInput != null)
        {
            m_mycaffeInput.Dispose();
            m_mycaffeInput = null;
        }
    }
}