Ejemplo n.º 1
0
        /// <summary>
        /// Constructs the Brain, binding it to the internal net and solver of the MyCaffe control,
        /// locating the memory data/loss layers it drives, and allocating its working blobs.
        /// </summary>
        /// <param name="mycaffe">Specifies the MyCaffe control supplying the net, solver, CUDA connection, log and current project.</param>
        /// <param name="properties">Specifies the settings; the 'MiniBatch' property is read to size the mini-batch.</param>
        /// <param name="random">Specifies the random number generator (stored for later use by this instance).</param>
        /// <param name="phase">Specifies the phase used to select the internal net and the project batch size.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="mycaffe"/> or <paramref name="properties"/> is null.</exception>
        /// <exception cref="InvalidOperationException">Thrown when no internal net exists for <paramref name="phase"/>, or the net has no MemoryData or MemoryLoss layer.</exception>
        public Brain(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
        {
            if (mycaffe == null)
            {
                throw new ArgumentNullException(nameof(mycaffe));
            }

            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }

            m_mycaffe    = mycaffe;
            m_net        = mycaffe.GetInternalNet(phase);

            if (m_net == null)
            {
                throw new InvalidOperationException($"Could not find the internal net for the '{phase}' phase!");
            }

            m_solver     = mycaffe.GetInternalSolver();
            m_properties = properties;
            m_random     = random;

            // NOTE(review): the null name argument presumably selects the first layer of the given type - confirm against Net.FindLayer.
            m_memData = m_net.FindLayer(LayerParameter.LayerType.MEMORYDATA, null) as MemoryDataLayer<T>;
            m_memLoss = m_net.FindLayer(LayerParameter.LayerType.MEMORY_LOSS, null) as MemoryLossLayer<T>;
            // The softmax layer is optional; when present, a cross-entropy loss layer is created below.
            m_softmax = m_net.FindLayer(LayerParameter.LayerType.SOFTMAX, null) as SoftmaxLayer<T>;

            if (m_memData == null)
            {
                throw new InvalidOperationException("Could not find the MemoryData Layer!");
            }

            if (m_memLoss == null)
            {
                throw new InvalidOperationException("Could not find the MemoryLoss Layer!");
            }

            // Resolve the mini-batch size before allocating or subscribing anything, so a failure
            // here leaves no partially-initialized state behind.
            int nMiniBatch = mycaffe.CurrentProject.GetBatchSize(phase);

            // A project batch size of 0 leaves the existing default in place.
            if (nMiniBatch != 0)
            {
                m_nMiniBatch = nMiniBatch;
            }

            // The 'MiniBatch' property takes precedence; the current value is passed as the fallback.
            m_nMiniBatch = m_properties.GetPropertyAsInt("MiniBatch", m_nMiniBatch);

            m_blobDiscountedR     = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
            m_blobPolicyGradient  = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
            m_blobActionOneHot    = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
            m_blobDiscountedR1    = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
            m_blobPolicyGradient1 = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
            m_blobActionOneHot1   = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
            m_blobLoss            = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
            m_blobAprobLogit      = new Blob<T>(mycaffe.Cuda, mycaffe.Log);

            if (m_softmax != null)
            {
                LayerParameter p = new LayerParameter(LayerParameter.LayerType.SOFTMAXCROSSENTROPY_LOSS);
                // NOTE(review): weights 1 and 0 presumably weight only the first top of the loss layer - confirm.
                p.loss_weight.Add(1);
                p.loss_weight.Add(0);
                p.loss_param.normalization = LossParameter.NormalizationMode.NONE;
                m_softmaxCe = new SoftmaxCrossEntropyLossLayer<T>(mycaffe.Cuda, mycaffe.Log, p);
            }

            // Gradient accumulator shaped like the net's learnable parameters, with its diff cleared to zero.
            m_colAccumulatedGradients = m_net.learnable_parameters.Clone();
            m_colAccumulatedGradients.SetDiff(0);

            // Subscribe last: these handlers make the layers hold a reference to this instance, so
            // attaching them only after every step that can throw prevents a failed construction
            // from leaving handlers wired to a half-built Brain.
            m_memData.OnDataPack += memData_OnDataPack;
            m_memLoss.OnGetLoss  += memLoss_OnGetLoss;
        }