/// <summary>
/// Run one training pass through the solver, accumulate the resulting gradients and,
/// when due, apply the accumulated gradients as a weight update.
/// </summary>
/// <remarks>
/// Logging via <c>m_mycaffe.Log</c> is disabled while the pass runs and is re-enabled before
/// the method exits, including when one of the calls below throws.
/// </remarks>
/// <param name="nIteration">Specifies the current iteration; an update is applied when it is a multiple of <c>m_nMiniBatch</c>.</param>
/// <param name="step">Specifies the step mode handed to the solver; BACKWARD and BOTH also force an update on this call.</param>
public void Train(int nIteration, TRAIN_STEP step)
{
    m_mycaffe.Log.Enable = false;

    try
    {
        // Run data/clip groups > 1 in non batch mode.
        if (m_nRecurrentSequenceLength != 1 && m_rgData != null && m_rgData.Count > 1 && m_rgClip != null)
        {
            prepareBlob(m_blobActionOneHot1, m_blobActionOneHot);
            prepareBlob(m_blobDiscountedR1, m_blobDiscountedR);
            prepareBlob(m_blobPolicyGradient1, m_blobPolicyGradient);

            // Feed each data/clip pair to the solver on its own, one step per item.
            for (int i = 0; i < m_rgData.Count; i++)
            {
                copyBlob(i, m_blobActionOneHot1, m_blobActionOneHot);
                copyBlob(i, m_blobDiscountedR1, m_blobDiscountedR);
                copyBlob(i, m_blobPolicyGradient1, m_blobPolicyGradient);

                List<Datum> rgData1 = new List<Datum>() { m_rgData[i] };
                List<Datum> rgClip1 = new List<Datum>() { m_rgClip[i] };

                m_memData.AddDatumVector(rgData1, rgClip1, 1, true, true);
                m_solver.Step(1, step, true, false, true, true);
            }

            // Give the working blobs the shape of their '1' counterparts again
            // (presumably the full-batch shape saved by prepareBlob - confirm against prepareBlob).
            m_blobActionOneHot.ReshapeLike(m_blobActionOneHot1);
            m_blobDiscountedR.ReshapeLike(m_blobDiscountedR1);
            m_blobPolicyGradient.ReshapeLike(m_blobPolicyGradient1);

            // The queued data/clip items have been consumed.
            m_rgData = null;
            m_rgClip = null;
        }
        else
        {
            m_solver.Step(1, step, true, false, true, true);
        }

        m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_net.learnable_parameters, true);

        if (nIteration % m_nMiniBatch == 0 || step == TRAIN_STEP.BACKWARD || step == TRAIN_STEP.BOTH)
        {
            // Move the accumulated gradients into the net, clear the accumulator,
            // apply the update and then clear the net's parameter diffs.
            m_net.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
            m_colAccumulatedGradients.SetDiff(0);
            m_solver.ApplyUpdate(nIteration);
            m_net.ClearParamDiffs();
        }
    }
    finally
    {
        // Always restore logging so a failed step does not leave the log silenced.
        m_mycaffe.Log.Enable = true;
    }
}
/// <summary>
/// Run a single solver step, accumulate the resulting gradients and, once every
/// <c>m_nMiniBatch</c> iterations, apply the accumulated gradients as a weight update.
/// </summary>
/// <remarks>
/// Logging via <c>m_mycaffe.Log</c> is disabled while the step runs and is re-enabled before
/// the method exits, including when the step or the update throws.
/// </remarks>
/// <param name="nIteration">Specifies the current iteration; an update is applied when it is a multiple of <c>m_nMiniBatch</c>.</param>
public void Train(int nIteration)
{
    m_mycaffe.Log.Enable = false;

    try
    {
        m_solver.Step(1, TRAIN_STEP.NONE, true, false, true);
        m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_net.learnable_parameters, true);

        if (nIteration % m_nMiniBatch == 0)
        {
            // Move the accumulated gradients into the net, clear the accumulator,
            // apply the update and then clear the net's parameter diffs.
            m_net.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
            m_colAccumulatedGradients.SetDiff(0);
            m_solver.ApplyUpdate(nIteration);
            m_net.ClearParamDiffs();
        }
    }
    finally
    {
        // Always restore logging so a failed step does not leave the log silenced.
        m_mycaffe.Log.Enable = true;
    }
}
/// <summary>
/// Train the model at the current iteration.
/// </summary>
/// <remarks>
/// The TD-loss handler is attached to <c>m_memLoss.OnGetLoss</c> only for the duration of the
/// solver step, and logging via <c>m_mycaffe.Log</c> is disabled only while the training work runs;
/// both are restored even when the step throws.
/// </remarks>
/// <param name="nIteration">Specifies the current iteration.</param>
/// <param name="rgSamples">Contains the samples to train the model with along with the priorities associated with the samples.</param>
/// <param name="nActionCount">Specifies the number of actions in the action set.</param>
/// <exception cref="Exception">Thrown when <paramref name="nActionCount"/> differs from <c>m_nActionCount</c>.</exception>
public void Train(int nIteration, MemoryCollection rgSamples, int nActionCount)
{
    m_rgSamples = rgSamples;

    if (m_nActionCount != nActionCount)
    {
        throw new Exception("The logit output of '" + m_nActionCount.ToString() + "' does not match the action count of '" + nActionCount.ToString() + "'!");
    }

    m_mycaffe.Log.Enable = false;

    try
    {
        // Get next_q_values
        setNextStateData(m_netTarget, rgSamples);
        m_netTarget.ForwardFromTo(0, m_netTarget.layers.Count - 2);

        setCurrentStateData(m_netOutput, rgSamples);
        m_memLoss.OnGetLoss += m_memLoss_ComputeTdLoss;

        try
        {
            if (m_nMiniBatch == 1)
            {
                m_solver.Step(1);
            }
            else
            {
                m_solver.Step(1, TRAIN_STEP.NONE, true, m_bUseAcceleratedTraining, true, true);
                m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_netOutput.learnable_parameters, true);

                if (nIteration % m_nMiniBatch == 0)
                {
                    // Move the accumulated gradients into the net, clear the accumulator,
                    // apply the update (recording the value it returns as the learning rate)
                    // and then clear the net's parameter diffs.
                    m_netOutput.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
                    m_colAccumulatedGradients.SetDiff(0);
                    m_dfLearningRate = m_solver.ApplyUpdate(nIteration);
                    m_netOutput.ClearParamDiffs();
                }
            }
        }
        finally
        {
            // Always detach the handler; otherwise a failed step would leave it subscribed
            // and the next call would subscribe it a second time.
            m_memLoss.OnGetLoss -= m_memLoss_ComputeTdLoss;
        }
    }
    finally
    {
        // Always restore logging so a failed step does not leave the log silenced.
        m_mycaffe.Log.Enable = true;
    }

    resetNoise(m_netOutput);
    resetNoise(m_netTarget);
}