Ejemplo n.º 1
0
        /// <summary>
        /// The constructor.
        /// </summary>
        /// <param name="mycaffe">Specifies the instance of MyCaffe associated with the open project - when using more than one Brain, this is the master project.</param>
        /// <param name="properties">Specifies the properties passed into the trainer.</param>
        /// <param name="random">Specifies the random number generator used.</param>
        /// <param name="phase">Specifies the phase under which to run.</param>
        /// <remarks>
        /// The output net is the internal net for <paramref name="phase"/>; a second 'target' net is built from the same
        /// net parameter.  The output net must provide a 'data' blob (its num sets the batch size), a 'logits' blob (its
        /// channels set the action count) and a MEMORY_LOSS layer.  The 'MiniBatch' property must be at least 1; values
        /// greater than 1 allocate a gradient accumulator.
        /// </remarks>
        public Brain(MyCaffeControl <T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
        {
            m_mycaffe    = mycaffe;
            m_solver     = mycaffe.GetInternalSolver();
            m_netOutput  = mycaffe.GetInternalNet(phase);
            m_netTarget  = new Net <T>(m_mycaffe.Cuda, m_mycaffe.Log, m_netOutput.net_param, m_mycaffe.CancelEvent, null, phase);
            m_properties = properties;
            m_random     = random;

            // NOTE(review): the code below dereferences 'data'/'logits' right after Log.FAIL, so it relies on FAIL not
            // returning (presumably it throws) - confirm against the Log implementation.
            Blob <T> data = m_netOutput.blob_by_name("data");

            if (data == null)
            {
                m_mycaffe.Log.FAIL("Missing the expected input 'data' blob!");
            }

            m_nBatchSize = data.num;

            Blob <T> logits = m_netOutput.blob_by_name("logits");

            if (logits == null)
            {
                m_mycaffe.Log.FAIL("Missing the expected input 'logits' blob!");
            }

            m_nActionCount = logits.channels;

            m_transformer        = m_mycaffe.DataTransformer;
            m_blobActions        = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
            m_blobQValue         = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log);
            m_blobNextQValue     = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log);
            m_blobExpectedQValue = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log);
            m_blobDone           = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
            m_blobLoss           = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log);
            m_blobWeights        = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log, false);

            m_fGamma = (float)properties.GetPropertyAsDouble("Gamma", m_fGamma);

            m_memLoss = m_netOutput.FindLastLayer(LayerParameter.LayerType.MEMORY_LOSS) as MemoryLossLayer <T>;
            if (m_memLoss == null)
            {
                m_mycaffe.Log.FAIL("Missing the expected MEMORY_LOSS layer!");
            }

            // Seed the learning rate from the project's solver settings when it is available.
            double?dfRate = mycaffe.CurrentProject.GetSolverSettingAsNumeric("base_lr");

            if (dfRate.HasValue)
            {
                m_dfLearningRate = dfRate.Value;
            }

            m_nMiniBatch = m_properties.GetPropertyAsInt("MiniBatch", m_nMiniBatch);

            // A mini-batch below 1 is meaningless: the accumulator below is only allocated for values > 1, and the
            // value is also used as a modulus (accumulation period) during training.
            if (m_nMiniBatch < 1)
            {
                m_mycaffe.Log.FAIL("The 'MiniBatch' property must be >= 1, but was '" + m_nMiniBatch.ToString() + "'!");
            }

            m_bUseAcceleratedTraining = properties.GetPropertyAsBool("UseAcceleratedTraining", false);

            // Gradients are accumulated across iterations only when more than one iteration makes up a mini-batch.
            if (m_nMiniBatch > 1)
            {
                m_colAccumulatedGradients = m_netOutput.learnable_parameters.Clone();
                m_colAccumulatedGradients.SetDiff(0);
            }
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Train the model at the current iteration, accumulating the gradients and applying the accumulated update
        /// once per mini-batch (or immediately when <paramref name="step"/> is BACKWARD or BOTH).
        /// </summary>
        /// <param name="nIteration">Specifies the current iteration.</param>
        /// <param name="step">Specifies the training step passed to the solver.</param>
        /// <remarks>
        /// Logging is disabled while training and is always re-enabled afterwards, even if the solver throws.
        /// </remarks>
        public void Train(int nIteration, TRAIN_STEP step)
        {
            m_mycaffe.Log.Enable = false;

            try
            {
                // Run data/clip groups > 1 in non batch mode.
                if (m_nRecurrentSequenceLength != 1 && m_rgData != null && m_rgData.Count > 1 && m_rgClip != null)
                {
                    prepareBlob(m_blobActionOneHot1, m_blobActionOneHot);
                    prepareBlob(m_blobDiscountedR1, m_blobDiscountedR);
                    prepareBlob(m_blobPolicyGradient1, m_blobPolicyGradient);

                    for (int i = 0; i < m_rgData.Count; i++)
                    {
                        copyBlob(i, m_blobActionOneHot1, m_blobActionOneHot);
                        copyBlob(i, m_blobDiscountedR1, m_blobDiscountedR);
                        copyBlob(i, m_blobPolicyGradient1, m_blobPolicyGradient);

                        // Feed one data/clip pair at a time into the memory data layer.
                        List <Datum> rgData1 = new List <Datum>()
                        {
                            m_rgData[i]
                        };
                        List <Datum> rgClip1 = new List <Datum>()
                        {
                            m_rgClip[i]
                        };

                        m_memData.AddDatumVector(rgData1, rgClip1, 1, true, true);

                        m_solver.Step(1, step, true, false, true, true);
                    }

                    // Restore the working blobs to the shapes of their full-batch counterparts.
                    m_blobActionOneHot.ReshapeLike(m_blobActionOneHot1);
                    m_blobDiscountedR.ReshapeLike(m_blobDiscountedR1);
                    m_blobPolicyGradient.ReshapeLike(m_blobPolicyGradient1);

                    m_rgData = null;
                    m_rgClip = null;
                }
                else
                {
                    m_solver.Step(1, step, true, false, true, true);
                }

                m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_net.learnable_parameters, true);

                // Apply the accumulated gradients at the end of each mini-batch, or whenever a backward pass was requested.
                if (nIteration % m_nMiniBatch == 0 || step == TRAIN_STEP.BACKWARD || step == TRAIN_STEP.BOTH)
                {
                    m_net.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
                    m_colAccumulatedGradients.SetDiff(0);
                    m_solver.ApplyUpdate(nIteration);
                    m_net.ClearParamDiffs();
                }
            }
            finally
            {
                // Never leave the shared log disabled, even when a step throws.
                m_mycaffe.Log.Enable = true;
            }
        }
Ejemplo n.º 3
0
        /// <summary>
        /// The constructor.
        /// </summary>
        /// <param name="mycaffe">Specifies the MyCaffe instance whose internal net (for <paramref name="phase"/>) and solver are used.</param>
        /// <param name="properties">Specifies the properties passed into the trainer.</param>
        /// <param name="random">Specifies the random number generator used.</param>
        /// <param name="phase">Specifies the phase whose internal net is used.</param>
        /// <exception cref="Exception">Thrown when the net lacks a MEMORYDATA layer or a MEMORY_LOSS layer.</exception>
        public Brain(MyCaffeControl <T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
        {
            var cuda = mycaffe.Cuda;
            var log = mycaffe.Log;

            m_mycaffe    = mycaffe;
            m_net        = mycaffe.GetInternalNet(phase);
            m_solver     = mycaffe.GetInternalSolver();
            m_properties = properties;
            m_random     = random;

            // Locate the layers this brain talks to; only the softmax layer is optional.
            m_memData = m_net.FindLayer(LayerParameter.LayerType.MEMORYDATA, null) as MemoryDataLayer <T>;
            m_memLoss = m_net.FindLayer(LayerParameter.LayerType.MEMORY_LOSS, null) as MemoryLossLayer <T>;
            m_softmax = m_net.FindLayer(LayerParameter.LayerType.SOFTMAX, null) as SoftmaxLayer <T>;

            if (m_memData == null) { throw new Exception("Could not find the MemoryData Layer!"); }
            if (m_memLoss == null) { throw new Exception("Could not find the MemoryLoss Layer!"); }

            m_memData.OnDataPack += memData_OnDataPack;
            m_memLoss.OnGetLoss  += memLoss_OnGetLoss;

            m_blobDiscountedR     = new Blob <T>(cuda, log);
            m_blobPolicyGradient  = new Blob <T>(cuda, log);
            m_blobActionOneHot    = new Blob <T>(cuda, log);
            m_blobDiscountedR1    = new Blob <T>(cuda, log);
            m_blobPolicyGradient1 = new Blob <T>(cuda, log);
            m_blobActionOneHot1   = new Blob <T>(cuda, log);
            m_blobLoss            = new Blob <T>(cuda, log);
            m_blobAprobLogit      = new Blob <T>(cuda, log);

            // When the net ends in a softmax, build a separate un-normalized softmax cross-entropy loss layer
            // with loss weights (1, 0).
            if (m_softmax != null)
            {
                LayerParameter ceParam = new LayerParameter(LayerParameter.LayerType.SOFTMAXCROSSENTROPY_LOSS);
                ceParam.loss_weight.Add(1);
                ceParam.loss_weight.Add(0);
                ceParam.loss_param.normalization = LossParameter.NormalizationMode.NONE;
                m_softmaxCe = new SoftmaxCrossEntropyLossLayer <T>(cuda, log, ceParam);
            }

            m_colAccumulatedGradients = m_net.learnable_parameters.Clone();
            m_colAccumulatedGradients.SetDiff(0);

            // A non-zero project batch size replaces the current default; the 'MiniBatch' property is read last,
            // falling back to whatever value is current at that point.
            int nProjectBatch = mycaffe.CurrentProject.GetBatchSize(phase);
            m_nMiniBatch = (nProjectBatch != 0) ? nProjectBatch : m_nMiniBatch;
            m_nMiniBatch = m_properties.GetPropertyAsInt("MiniBatch", m_nMiniBatch);
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Run one solver step for the current iteration, accumulating the gradients and applying the accumulated
        /// update every <c>m_nMiniBatch</c> iterations.
        /// </summary>
        /// <param name="nIteration">Specifies the current iteration.</param>
        /// <remarks>
        /// Logging is disabled during the step and is always re-enabled afterwards, even if the solver throws.
        /// </remarks>
        public void Train(int nIteration)
        {
            m_mycaffe.Log.Enable = false;

            try
            {
                m_solver.Step(1, TRAIN_STEP.NONE, true, false, true);
                m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_net.learnable_parameters, true);

                // At the end of each mini-batch, copy the accumulated values back into the net, reset the
                // accumulator, apply the solver update and clear the net's parameter diffs.
                if (nIteration % m_nMiniBatch == 0)
                {
                    m_net.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
                    m_colAccumulatedGradients.SetDiff(0);
                    m_solver.ApplyUpdate(nIteration);
                    m_net.ClearParamDiffs();
                }
            }
            finally
            {
                // Never leave the shared log disabled, even when the step throws.
                m_mycaffe.Log.Enable = true;
            }
        }
Ejemplo n.º 5
0
        /// <summary>
        /// Train the model at the current iteration.
        /// </summary>
        /// <param name="nIteration">Specifies the current iteration.</param>
        /// <param name="rgSamples">Contains the samples to train the model with along with the priorities associated with the samples.</param>
        /// <param name="nActionCount">Specifies the number of actions in the action set.</param>
        /// <exception cref="ArgumentException">Thrown when <paramref name="nActionCount"/> does not match the action count taken from the 'logits' blob.</exception>
        /// <remarks>
        /// The TD-loss handler is attached to the MEMORY_LOSS layer only for the duration of the solver step.  It is always
        /// detached again, and logging re-enabled, even if a step throws.  Otherwise a failed step would leave the handler
        /// subscribed and the next call would attach it a second time.
        /// </remarks>
        public void Train(int nIteration, MemoryCollection rgSamples, int nActionCount)
        {
            m_rgSamples = rgSamples;

            if (m_nActionCount != nActionCount)
            {
                throw new ArgumentException("The logit output of '" + m_nActionCount.ToString() + "' does not match the action count of '" + nActionCount.ToString() + "'!");
            }

            m_mycaffe.Log.Enable = false;

            try
            {
                // Get next_q_values: run the target net forward up to layer Count - 2, so its last layer is not run.
                setNextStateData(m_netTarget, rgSamples);
                m_netTarget.ForwardFromTo(0, m_netTarget.layers.Count - 2);

                setCurrentStateData(m_netOutput, rgSamples);
                m_memLoss.OnGetLoss += m_memLoss_ComputeTdLoss;

                try
                {
                    if (m_nMiniBatch == 1)
                    {
                        m_solver.Step(1);
                    }
                    else
                    {
                        m_solver.Step(1, TRAIN_STEP.NONE, true, m_bUseAcceleratedTraining, true, true);
                        m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_netOutput.learnable_parameters, true);

                        // Apply the accumulated gradients once per mini-batch and record the learning rate the solver used.
                        if (nIteration % m_nMiniBatch == 0)
                        {
                            m_netOutput.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
                            m_colAccumulatedGradients.SetDiff(0);
                            m_dfLearningRate = m_solver.ApplyUpdate(nIteration);
                            m_netOutput.ClearParamDiffs();
                        }
                    }
                }
                finally
                {
                    // Always detach, so the handler never accumulates across calls.
                    m_memLoss.OnGetLoss -= m_memLoss_ComputeTdLoss;
                }
            }
            finally
            {
                m_mycaffe.Log.Enable = true;
            }

            resetNoise(m_netOutput);
            resetNoise(m_netTarget);
        }
Ejemplo n.º 6
0
        /// <summary>
        /// The constructor.
        /// </summary>
        /// <param name="mycaffe">Specifies the instance of MyCaffe associated with the open project - when using more than one Brain, this is the master project.</param>
        /// <param name="properties">Specifies the properties passed into the trainer.</param>
        /// <param name="random">Specifies the random number generator used.</param>
        /// <param name="phase">Specifies the phase under which to run.</param>
        /// <remarks>
        /// The output net must provide a 'data' blob (its channels give the frames per input and its num the batch size),
        /// a 'logits' blob (its channels give the action count) and a MEMORY_LOSS layer.
        /// The data transformer comes from MyCaffe, or is created from the training source dimensions when MyCaffe has none.
        /// Its mean values and scale are then replaced so that each frame is centered and scaled by 1/255.
        /// The 'MiniBatch' property must be at least 1.
        /// </remarks>
        public Brain(MyCaffeControl <T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
        {
            m_mycaffe    = mycaffe;
            m_solver     = mycaffe.GetInternalSolver();
            m_netOutput  = mycaffe.GetInternalNet(phase);
            m_netTarget  = new Net <T>(m_mycaffe.Cuda, m_mycaffe.Log, m_netOutput.net_param, m_mycaffe.CancelEvent, null, phase);
            m_properties = properties;
            m_random     = random;

            // NOTE(review): the code below dereferences 'data'/'logits' right after Log.FAIL, so it relies on FAIL not
            // returning (presumably it throws) - confirm against the Log implementation.
            Blob <T> data = m_netOutput.blob_by_name("data");

            if (data == null)
            {
                m_mycaffe.Log.FAIL("Missing the expected input 'data' blob!");
            }

            m_nFramesPerX = data.channels;
            m_nBatchSize  = data.num;

            Blob <T> logits = m_netOutput.blob_by_name("logits");

            if (logits == null)
            {
                m_mycaffe.Log.FAIL("Missing the expected input 'logits' blob!");
            }

            m_nActionCount = logits.channels;

            m_transformer = m_mycaffe.DataTransformer;
            if (m_transformer == null)
            {
                TransformationParameter trans_param = new TransformationParameter();
                int nC = m_mycaffe.CurrentProject.Dataset.TrainingSource.ImageChannels;
                int nH = m_mycaffe.CurrentProject.Dataset.TrainingSource.ImageHeight;
                int nW = m_mycaffe.CurrentProject.Dataset.TrainingSource.ImageWidth;
                m_transformer = new DataTransformer <T>(m_mycaffe.Cuda, m_mycaffe.Log, trans_param, phase, nC, nH, nW);
            }

            // The transformer may be the shared MyCaffe instance (e.g. when several Brains are created on the same
            // project).  Replace rather than append the mean values, so repeated construction does not keep growing
            // the list; this matches the unconditional overwrite of 'scale' below.
            m_transformer.param.mean_value.Clear();

            for (int i = 0; i < m_nFramesPerX; i++)
            {
                m_transformer.param.mean_value.Add(255 / 2); // center each frame (integer division: the mean is 127)
            }

            m_transformer.param.scale = 1.0 / 255; // normalize
            m_transformer.Update();

            m_blobActions        = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
            m_blobQValue         = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log);
            m_blobNextQValue     = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log);
            m_blobExpectedQValue = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log);
            m_blobDone           = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
            m_blobLoss           = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log);
            m_blobWeights        = new Blob <T>(m_mycaffe.Cuda, m_mycaffe.Log, false);

            m_fGamma = (float)properties.GetPropertyAsDouble("Gamma", m_fGamma);

            m_memLoss = m_netOutput.FindLastLayer(LayerParameter.LayerType.MEMORY_LOSS) as MemoryLossLayer <T>;
            if (m_memLoss == null)
            {
                m_mycaffe.Log.FAIL("Missing the expected MEMORY_LOSS layer!");
            }

            // Seed the learning rate from the project's solver settings when it is available.
            double?dfRate = mycaffe.CurrentProject.GetSolverSettingAsNumeric("base_lr");

            if (dfRate.HasValue)
            {
                m_dfLearningRate = dfRate.Value;
            }

            m_nMiniBatch = m_properties.GetPropertyAsInt("MiniBatch", m_nMiniBatch);

            // A mini-batch below 1 is meaningless: the accumulator below is only allocated for values > 1, and the
            // value is also used as a modulus (accumulation period) during training.
            if (m_nMiniBatch < 1)
            {
                m_mycaffe.Log.FAIL("The 'MiniBatch' property must be >= 1, but was '" + m_nMiniBatch.ToString() + "'!");
            }

            m_bUseAcceleratedTraining = properties.GetPropertyAsBool("UseAcceleratedTraining", false);

            // Gradients are accumulated across iterations only when more than one iteration makes up a mini-batch.
            if (m_nMiniBatch > 1)
            {
                m_colAccumulatedGradients = m_netOutput.learnable_parameters.Clone();
                m_colAccumulatedGradients.SetDiff(0);
            }
        }