Example #1
0
        /// <summary>
        /// Run one training pass with the solver, accumulate the resulting gradients and, when due, apply the accumulated gradients to the learnable parameters.
        /// </summary>
        /// <remarks>
        /// Logging is disabled for the duration of the call and is re-enabled on exit, even when one of the solver calls throws.
        /// </remarks>
        /// <param name="nIteration">Specifies the current iteration; the accumulated gradients are applied when it is a multiple of the mini-batch size.</param>
        /// <param name="step">Specifies the step mode passed to the solver; BACKWARD and BOTH cause the accumulated gradients to be applied on this call regardless of the iteration.</param>
        public void Train(int nIteration, TRAIN_STEP step)
        {
            m_mycaffe.Log.Enable = false;

            try
            {
                // Run data/clip groups > 1 in non batch mode.
                bool bRunOneAtATime = (m_nRecurrentSequenceLength != 1 && m_rgData != null && m_rgData.Count > 1 && m_rgClip != null);

                if (bRunOneAtATime)
                {
                    // NOTE(review): prepareBlob/copyBlob are defined elsewhere; presumably they size the working
                    // blob for a single item and copy item 'i' into it from the full-size '1' blob - confirm.
                    prepareBlob(m_blobActionOneHot1, m_blobActionOneHot);
                    prepareBlob(m_blobDiscountedR1, m_blobDiscountedR);
                    prepareBlob(m_blobPolicyGradient1, m_blobPolicyGradient);

                    for (int i = 0; i < m_rgData.Count; i++)
                    {
                        copyBlob(i, m_blobActionOneHot1, m_blobActionOneHot);
                        copyBlob(i, m_blobDiscountedR1, m_blobDiscountedR);
                        copyBlob(i, m_blobPolicyGradient1, m_blobPolicyGradient);

                        // Feed exactly one data/clip pair to the memory data layer, then step on it.
                        List<Datum> rgData1 = new List<Datum>() { m_rgData[i] };
                        List<Datum> rgClip1 = new List<Datum>() { m_rgClip[i] };

                        m_memData.AddDatumVector(rgData1, rgClip1, 1, true, true);

                        m_solver.Step(1, step, true, false, true, true);
                    }

                    // Give the working blobs the shape of their '1' counterparts again.
                    m_blobActionOneHot.ReshapeLike(m_blobActionOneHot1);
                    m_blobDiscountedR.ReshapeLike(m_blobDiscountedR1);
                    m_blobPolicyGradient.ReshapeLike(m_blobPolicyGradient1);

                    // The pending data/clip items have been consumed.
                    m_rgData = null;
                    m_rgClip = null;
                }
                else
                {
                    m_solver.Step(1, step, true, false, true, true);
                }

                m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_net.learnable_parameters, true);

                bool bApplyUpdate = (nIteration % m_nMiniBatch == 0 || step == TRAIN_STEP.BACKWARD || step == TRAIN_STEP.BOTH);

                if (bApplyUpdate)
                {
                    // Move the accumulated gradients into the net, reset the accumulator, then update the weights.
                    m_net.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
                    m_colAccumulatedGradients.SetDiff(0);
                    m_solver.ApplyUpdate(nIteration);
                    m_net.ClearParamDiffs();
                }
            }
            finally
            {
                // Always re-enable logging so that a failed step does not silence the log for the rest of the run.
                m_mycaffe.Log.Enable = true;
            }
        }
Example #2
0
        /// <summary>
        /// Run a single solver step, accumulate the resulting gradients and, on every mini-batch boundary, apply the accumulated gradients to the learnable parameters.
        /// </summary>
        /// <remarks>
        /// Logging is disabled for the duration of the call and is re-enabled on exit, even when one of the solver calls throws.
        /// </remarks>
        /// <param name="nIteration">Specifies the current iteration; the accumulated gradients are applied when it is a multiple of the mini-batch size.</param>
        public void Train(int nIteration)
        {
            m_mycaffe.Log.Enable = false;

            try
            {
                m_solver.Step(1, TRAIN_STEP.NONE, true, false, true);
                m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_net.learnable_parameters, true);

                bool bApplyUpdate = (nIteration % m_nMiniBatch == 0);

                if (bApplyUpdate)
                {
                    // Move the accumulated gradients into the net, reset the accumulator, then update the weights.
                    m_net.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
                    m_colAccumulatedGradients.SetDiff(0);
                    m_solver.ApplyUpdate(nIteration);
                    m_net.ClearParamDiffs();
                }
            }
            finally
            {
                // Always re-enable logging so that a failed step does not silence the log for the rest of the run.
                m_mycaffe.Log.Enable = true;
            }
        }
Example #3
0
        /// <summary>
        /// Train the model at the current iteration.
        /// </summary>
        /// <remarks>
        /// Logging is disabled and the TD-loss handler is attached only for the duration of the training step; both are
        /// restored on exit even when the step throws, so a failed call cannot leave the handler subscribed twice on the next call.
        /// </remarks>
        /// <param name="nIteration">Specifies the current iteration.</param>
        /// <param name="rgSamples">Contains the samples to train the model with along with the priorities associated with the samples.</param>
        /// <param name="nActionCount">Specifies the number of actions in the action set.</param>
        /// <exception cref="Exception">Thrown when <paramref name="nActionCount"/> does not match the action count the model was set up with.</exception>
        public void Train(int nIteration, MemoryCollection rgSamples, int nActionCount)
        {
            m_rgSamples = rgSamples;

            if (m_nActionCount != nActionCount)
            {
                throw new Exception("The logit output of '" + m_nActionCount.ToString() + "' does not match the action count of '" + nActionCount.ToString() + "'!");
            }

            m_mycaffe.Log.Enable = false;

            try
            {
                // Get next_q_values: load the next states into the target net and run it forward
                // from the first layer up to (and including) the second-to-last layer.
                setNextStateData(m_netTarget, rgSamples);
                m_netTarget.ForwardFromTo(0, m_netTarget.layers.Count - 2);

                setCurrentStateData(m_netOutput, rgSamples);
                m_memLoss.OnGetLoss += m_memLoss_ComputeTdLoss;

                try
                {
                    if (m_nMiniBatch == 1)
                    {
                        // No accumulation needed; let the solver step and update directly.
                        m_solver.Step(1);
                    }
                    else
                    {
                        m_solver.Step(1, TRAIN_STEP.NONE, true, m_bUseAcceleratedTraining, true, true);
                        m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_netOutput.learnable_parameters, true);

                        if (nIteration % m_nMiniBatch == 0)
                        {
                            // Move the accumulated gradients into the net, reset the accumulator, then update the weights.
                            m_netOutput.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
                            m_colAccumulatedGradients.SetDiff(0);
                            m_dfLearningRate = m_solver.ApplyUpdate(nIteration);
                            m_netOutput.ClearParamDiffs();
                        }
                    }
                }
                finally
                {
                    // Always detach the handler; otherwise a failed step would leave it attached and
                    // the next call would subscribe it a second time.
                    m_memLoss.OnGetLoss -= m_memLoss_ComputeTdLoss;
                }
            }
            finally
            {
                // Always re-enable logging so that a failed step does not silence the log for the rest of the run.
                m_mycaffe.Log.Enable = true;
            }

            resetNoise(m_netOutput);
            resetNoise(m_netTarget);
        }