/// <summary>
/// Constructs the DQN agent: creates its Brain, then loads the agent settings from the property set
/// and validates the epsilon (exploration) range.
/// </summary>
/// <param name="icallback">Specifies the callback used for update notifications sent to the parent.</param>
/// <param name="mycaffe">Specifies the instance of MyCaffe with the open project.</param>
/// <param name="properties">Specifies the properties passed into the trainer.</param>
/// <param name="random">Specifies the random number generator used.</param>
/// <param name="phase">Specifies the phase of the internal network to use.</param>
/// <remarks>
/// Properties read: 'Gamma', 'UseRawInput', 'MaxMemory', 'TrainingUpdateFreq', 'ExplorationNum',
/// 'EpsSteps', 'EpsStart' and 'EpsEnd'.  Each read passes the current field value as the second
/// argument (presumably the default used when the property is absent - confirm against PropertySet).
/// </remarks>
/// <exception cref="Exception">Thrown when 'EpsStart' or 'EpsEnd' is outside [0,1], or when 'EpsEnd' is greater than 'EpsStart'.</exception>
public DqnAgent(IxTrainerCallback icallback, MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
{
    m_icallback = icallback;
    m_brain = new Brain<T>(mycaffe, properties, random, phase);
    m_properties = properties;
    m_random = random;

    // General agent settings.
    m_fGamma = (float)properties.GetPropertyAsDouble("Gamma", m_fGamma);
    m_bUseRawInput = properties.GetPropertyAsBool("UseRawInput", m_bUseRawInput);
    m_nMaxMemory = properties.GetPropertyAsInt("MaxMemory", m_nMaxMemory);
    m_nTrainingUpdateFreq = properties.GetPropertyAsInt("TrainingUpdateFreq", m_nTrainingUpdateFreq);
    m_nExplorationNum = properties.GetPropertyAsInt("ExplorationNum", m_nExplorationNum);

    // Epsilon schedule settings.
    m_nEpsSteps = properties.GetPropertyAsInt("EpsSteps", m_nEpsSteps);
    m_dfEpsStart = properties.GetPropertyAsDouble("EpsStart", m_dfEpsStart);
    m_dfEpsEnd = properties.GetPropertyAsDouble("EpsEnd", m_dfEpsEnd);

    // NOTE(review): 'EpsSteps' is not validated; a value of 0 makes this division produce Infinity or NaN.
    m_dfEpsDelta = (m_dfEpsStart - m_dfEpsEnd) / m_nEpsSteps;
    m_dfExplorationRate = m_dfEpsStart;

    // Pick the first failing rule (same precedence as checking them one after another).
    // NOTE(review): the last message says 'less than', yet equal values are accepted.
    string strError = null;

    if (m_dfEpsStart < 0 || m_dfEpsStart > 1)
    {
        strError = "The 'EpsStart' is out of range - please specify a real number in the range [0,1]";
    }
    else if (m_dfEpsEnd < 0 || m_dfEpsEnd > 1)
    {
        strError = "The 'EpsEnd' is out of range - please specify a real number in the range [0,1]";
    }
    else if (m_dfEpsEnd > m_dfEpsStart)
    {
        strError = "The 'EpsEnd' must be less than the 'EpsStart' value.";
    }

    if (strError != null)
    {
        throw new Exception(strError);
    }
}
/// <summary>
/// Constructs the Brain: binds to the solver and output network of the given MyCaffe instance, builds a
/// second (target) network from the same network parameter, and allocates the working blobs.
/// </summary>
/// <param name="mycaffe">Specifies the instance of MyCaffe associated with the open project - when using more than one Brain, this is the master project.</param>
/// <param name="properties">Specifies the properties passed into the trainer ('Gamma', 'MiniBatch' and 'UseAcceleratedTraining' are read here).</param>
/// <param name="random">Specifies the random number generator used.</param>
/// <param name="phase">Specifies the phase under which to run.</param>
/// <remarks>
/// The output network must expose blobs named 'data' and 'logits' and contain a MEMORY_LOSS layer;
/// a missing item is reported through Log.FAIL.
/// </remarks>
public Brain(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
{
    m_mycaffe = mycaffe;
    m_solver = mycaffe.GetInternalSolver();
    m_netOutput = mycaffe.GetInternalNet(phase);
    m_netTarget = new Net<T>(m_mycaffe.Cuda, m_mycaffe.Log, m_netOutput.net_param, m_mycaffe.CancelEvent, null, phase);
    m_properties = properties;
    m_random = random;

    // The 'data' blob supplies the batch size.
    Blob<T> blobData = m_netOutput.blob_by_name("data");
    if (blobData == null)
    {
        m_mycaffe.Log.FAIL("Missing the expected input 'data' blob!");
    }

    m_nBatchSize = blobData.num;

    // The 'logits' blob supplies the action count (one channel per action).
    Blob<T> blobLogits = m_netOutput.blob_by_name("logits");
    if (blobLogits == null)
    {
        m_mycaffe.Log.FAIL("Missing the expected input 'logits' blob!");
    }

    m_nActionCount = blobLogits.channels;

    m_transformer = m_mycaffe.DataTransformer;

    // Working blobs; those created with the extra 'false' argument are the actions, done and weights blobs.
    m_blobActions = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
    m_blobQValue = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
    m_blobNextQValue = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
    m_blobExpectedQValue = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
    m_blobDone = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
    m_blobLoss = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
    m_blobWeights = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);

    m_fGamma = (float)properties.GetPropertyAsDouble("Gamma", m_fGamma);

    m_memLoss = m_netOutput.FindLastLayer(LayerParameter.LayerType.MEMORY_LOSS) as MemoryLossLayer<T>;
    if (m_memLoss == null)
    {
        m_mycaffe.Log.FAIL("Missing the expected MEMORY_LOSS layer!");
    }

    // Use the solver's base learning rate when the project defines one.
    double? dfBaseLr = mycaffe.CurrentProject.GetSolverSettingAsNumeric("base_lr");
    if (dfBaseLr.HasValue)
    {
        m_dfLearningRate = dfBaseLr.Value;
    }

    m_nMiniBatch = m_properties.GetPropertyAsInt("MiniBatch", m_nMiniBatch);
    m_bUseAcceleratedTraining = properties.GetPropertyAsBool("UseAcceleratedTraining", false);

    // Gradient accumulation storage is only created when more than one mini-batch is used.
    if (m_nMiniBatch > 1)
    {
        m_colAccumulatedGradients = m_netOutput.learnable_parameters.Clone();
        m_colAccumulatedGradients.SetDiff(0);
    }
}
/// <summary>
/// The Initialize method initializes the streaming database component, preparing it for data queries.
/// </summary>
/// <param name="qt">Specifies the query type to use (see remarks).</param>
/// <param name="strSchema">Specifies the query schema to use; for the synchronized query type it also carries the query settings.</param>
/// <remarks>
/// When qt = QUERY_TYPE.SYNCHRONIZED a time based query manager (MgrQueryTime) is created and the following
/// key=value settings are read from 'strSchema':
///   'QueryCount' - Specifies the number of items to include in each query (read with a fallback of 0).
///   'Start' - Specifies the start date of the stream.
///   'TimeSpanInMs' - Specifies the time increment between data items in the stream in milliseconds.
///   'SegmentSize' - Specifies the segment size of data queried from the database.
///   'MaxCount' - Specifies the maximum number of items to load into memory for each custom query.
///
/// For every other query type a general query manager (MgrQueryGeneral) is created and no additional
/// settings are read here.
///
/// The database schema defines the number of custom queries to use along with their names.  A simple key=value; list
/// defines the streaming database schema using the following format:
///
///   "ConnectionCount=2;
///    Connection0_CustomQueryName=Test1;
///    Connection0_CustomQueryParam=param_string1;
///    Connection1_CustomQueryName=Test2;
///    Connection1_CustomQueryParam=param_string2"
///
/// Each param_string specifies the parameters of the custom query and may include the database connection string, database
/// table, and database fields to query.
/// </remarks>
public void Initialize(QUERY_TYPE qt, string strSchema)
{
    // Anything other than the synchronized type is served by the general query manager.
    if (qt != QUERY_TYPE.SYNCHRONIZED)
    {
        m_iquery = new MgrQueryGeneral(strSchema, m_rgCustomQueryToAdd);
        return;
    }

    PropertySet settings = new PropertySet(strSchema);
    int nQueryCount = settings.GetPropertyAsInt("QueryCount", 0);
    DateTime dtStart = settings.GetPropertyAsDateTime("Start");
    int nTimeSpanInMs = settings.GetPropertyAsInt("TimeSpanInMs");
    int nSegmentSize = settings.GetPropertyAsInt("SegmentSize");
    int nMaxCount = settings.GetPropertyAsInt("MaxCount");

    m_iquery = new MgrQueryTime(nQueryCount, dtStart, nTimeSpanInMs, nSegmentSize, nMaxCount, strSchema, m_rgCustomQueryToAdd);
}
/// <summary>
/// Initialize the Data Processor.
/// </summary>
/// <param name="imycaffe">Specifies the instance of MyCaffe to use.</param>
/// <param name="idb">Specifies the instance of the streaming database to use.</param>
/// <param name="strPreProcessorDLLPath">Specifies the path to the preprocessing DLL to use.</param>
/// <param name="properties">Specifies the properties supplying the 'Fields' and 'Depth' values, both of which must be greater than 0.</param>
/// <exception cref="Exception">Thrown when 'Fields' or 'Depth' is missing, zero or negative.</exception>
public void Initialize(IXMyCaffe<T> imycaffe, IXStreamDatabase idb, string strPreProcessorDLLPath, PropertySet properties)
{
    m_mgrPreprocessor = new MgrPreprocessor<T>(imycaffe, idb);

    // Both values are read with a fallback of 0 so that a missing property fails the checks below.
    int nFields = properties.GetPropertyAsInt("Fields", 0);
    int nDepth = properties.GetPropertyAsInt("Depth", 0);

    // Reject negative values as well as zero - the messages require a value greater than 0.
    if (nFields <= 0)
    {
        throw new Exception("You must specify the 'Fields' property with a value greater than 0.");
    }

    if (nDepth <= 0)
    {
        throw new Exception("You must specify the 'Depth' property with a value greater than 0.");
    }

    m_mgrPreprocessor.Initialize(strPreProcessorDLLPath, nFields, nDepth);
}
/// <summary>
/// Constructs the Brain: binds to the internal net and solver, locates the MemoryData, MemoryLoss and
/// (optional) Softmax layers, hooks the data-pack and loss events and allocates the working blobs.
/// </summary>
/// <param name="mycaffe">Specifies the instance of MyCaffe whose internal net and solver are used.</param>
/// <param name="properties">Specifies the properties passed into the trainer ('MiniBatch' is read here).</param>
/// <param name="random">Specifies the random number generator used.</param>
/// <param name="phase">Specifies the phase of the internal net to use.</param>
/// <exception cref="Exception">Thrown when the MemoryData or MemoryLoss layer cannot be found.</exception>
public Brain(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
{
    m_mycaffe = mycaffe;
    m_net = mycaffe.GetInternalNet(phase);
    m_solver = mycaffe.GetInternalSolver();
    m_properties = properties;
    m_random = random;

    m_memData = m_net.FindLayer(LayerParameter.LayerType.MEMORYDATA, null) as MemoryDataLayer<T>;
    m_memLoss = m_net.FindLayer(LayerParameter.LayerType.MEMORY_LOSS, null) as MemoryLossLayer<T>;
    m_softmax = m_net.FindLayer(LayerParameter.LayerType.SOFTMAX, null) as SoftmaxLayer<T>;

    // The data and loss layers are mandatory; the softmax layer is optional.
    if (m_memData == null)
    {
        throw new Exception("Could not find the MemoryData Layer!");
    }

    if (m_memLoss == null)
    {
        throw new Exception("Could not find the MemoryLoss Layer!");
    }

    m_memData.OnDataPack += memData_OnDataPack;
    m_memLoss.OnGetLoss += memLoss_OnGetLoss;

    m_blobDiscountedR = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
    m_blobPolicyGradient = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
    m_blobActionOneHot = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
    m_blobDiscountedR1 = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
    m_blobPolicyGradient1 = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
    m_blobActionOneHot1 = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
    m_blobLoss = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
    m_blobAprobLogit = new Blob<T>(mycaffe.Cuda, mycaffe.Log);

    // A softmax cross-entropy loss layer is only created when the net has a Softmax layer.
    if (m_softmax != null)
    {
        LayerParameter lossParam = new LayerParameter(LayerParameter.LayerType.SOFTMAXCROSSENTROPY_LOSS);
        lossParam.loss_weight.Add(1);
        lossParam.loss_weight.Add(0);
        lossParam.loss_param.normalization = LossParameter.NormalizationMode.NONE;
        m_softmaxCe = new SoftmaxCrossEntropyLossLayer<T>(mycaffe.Cuda, mycaffe.Log, lossParam);
    }

    m_colAccumulatedGradients = m_net.learnable_parameters.Clone();
    m_colAccumulatedGradients.SetDiff(0);

    // A non-zero project batch size replaces the current mini-batch value before the 'MiniBatch' property is read.
    int nProjectBatch = mycaffe.CurrentProject.GetBatchSize(phase);
    int nFallback = (nProjectBatch != 0) ? nProjectBatch : m_nMiniBatch;
    m_nMiniBatch = m_properties.GetPropertyAsInt("MiniBatch", nFallback);
}
/// <summary>
/// Initializes a new custom trainer by loading the key-value pair of properties into the property set.
/// </summary>
/// <param name="strProperties">Specifies the key-value pair of properties each separated by ';'.  For example the expected
/// format is 'key1'='value1';'key2'='value2';...</param>
/// <param name="icallback">Specifies the parent callback.</param>
/// <remarks>
/// Properties read: 'Threads' (read with 1 as the second argument), 'RewardType' ("VAL" is used when the value is null;
/// "VAL"/"VALUE" or "AVE"/"AVERAGE" after upper-casing - any other value leaves the reward type unchanged) and
/// 'TrainerType' ("PG.SIMPLE", "PG.ST", "PG", "PG.MT" or "RNN.SIMPLE").
/// </remarks>
/// <exception cref="Exception">Thrown when the 'TrainerType' is not one of the supported values.</exception>
public void Initialize(string strProperties, IXMyCaffeCustomTrainerCallback icallback)
{
    m_icallback = icallback;
    m_properties = new PropertySet(strProperties);
    m_nThreads = m_properties.GetPropertyAsInt("Threads", 1);

    string strRewardType = m_properties.GetProperty("RewardType", false);
    strRewardType = (strRewardType == null) ? "VAL" : strRewardType.ToUpper();

    switch (strRewardType)
    {
        case "VAL":
        case "VALUE":
            m_rewardType = REWARD_TYPE.VALUE;
            break;

        case "AVE":
        case "AVERAGE":
            m_rewardType = REWARD_TYPE.AVERAGE;
            break;
    }

    string strTrainerType = m_properties.GetProperty("TrainerType");

    if (strTrainerType == "PG.SIMPLE")
    {
        // bare bones model (Sigmoid only)
        m_trainerType = TRAINER_TYPE.PG_SIMPLE;
        m_stage = Stage.RL;
    }
    else if (strTrainerType == "PG.ST")
    {
        // single thread (Sigmoid and Softmax)
        m_trainerType = TRAINER_TYPE.PG_ST;
        m_stage = Stage.RL;
    }
    else if (strTrainerType == "PG" || strTrainerType == "PG.MT")
    {
        // multi-thread (Sigmoid and Softmax)
        m_trainerType = TRAINER_TYPE.PG_MT;
        m_stage = Stage.RL;
    }
    else if (strTrainerType == "RNN.SIMPLE")
    {
        m_trainerType = TRAINER_TYPE.RNN_SIMPLE;
        m_stage = Stage.RNN;
    }
    else
    {
        throw new Exception("Unknown trainer type '" + strTrainerType + "'!");
    }
}
/// <summary>
/// The constructor.  When a parameter string is supplied it is unpacked and the 'Connection', 'Table',
/// 'Field' and 'EndIdx' settings are loaded from it.
/// </summary>
/// <param name="strParam">Optionally, specifies the packed parameter string; when null no settings are loaded.</param>
public CustomQuery2(string strParam = null)
{
    // Nothing to load without a parameter string.
    if (strParam == null)
    {
        return;
    }

    string strUnpacked = ParamPacker.UnPack(strParam);
    PropertySet settings = new PropertySet(strUnpacked);

    m_strConnection = settings.GetProperty("Connection");
    m_strTable = settings.GetProperty("Table");
    m_strField = settings.GetProperty("Field");
    m_nEndIdx = settings.GetPropertyAsInt("EndIdx", int.MaxValue);
}
/// <summary>
/// Initialize the gym with the specified properties.
/// </summary>
/// <param name="log">Specifies the output log to use.</param>
/// <param name="properties">Specifies the properties containing Gym specific initialization parameters.</param>
/// <remarks>
/// The ModelGym uses the following initialization properties.
///
///   'GpuID' - the GPU to run on.
///   'ModelDescription' - the model description of the model to use.
///   'Dataset' - the name of the dataset to use.
///   'Weights' - the model trained weights.
///   'CudaPath' - the path of the CudaDnnDLL to use.
///   'BatchSize' - the batch size used when running images through the model (default = 16).
///   'RecreateData' - when 'True' the data is re-run through the model, otherwise if already run the data is loaded from file (faster).
///   'ProjectName' - the project name; "default" is used when the value is null or empty.
///
/// The first image of the dataset's training source is queried to size the input shape used when loading the model.
/// </remarks>
public void Initialize(Log log, PropertySet properties)
{
    // Gym settings.
    m_nGpuID = properties.GetPropertyAsInt("GpuID");
    m_strModelDesc = properties.GetProperty("ModelDescription");
    m_strDataset = properties.GetProperty("Dataset");
    m_rgWeights = properties.GetPropertyBlob("Weights");
    m_nBatchSize = properties.GetPropertyAsInt("BatchSize", 16);
    m_bRecreateData = properties.GetPropertyAsBool("RecreateData", false);

    m_strProject = properties.GetProperty("ProjectName");
    if (string.IsNullOrEmpty(m_strProject))
    {
        m_strProject = "default";
    }

    string strCudaPath = properties.GetProperty("CudaPath");

    SettingsCaffe settings = new SettingsCaffe();
    settings.GpuIds = m_nGpuID.ToString();
    settings.ImageDbLoadMethod = IMAGEDB_LOAD_METHOD.LOAD_ON_DEMAND_BACKGROUND;

    // Open the dataset and use its first training image to determine the model input shape.
    m_imgdb = new MyCaffeImageDatabase2(log);
    m_imgdb.InitializeWithDsName1(settings, m_strDataset);
    m_ds = m_imgdb.GetDatasetByName(m_strDataset);

    SimpleDatum sdFirst = m_imgdb.QueryImage(m_ds.TrainingSource.ID, 0, IMGDB_LABEL_SELECTION_METHOD.NONE, IMGDB_IMAGE_SELECTION_METHOD.NONE);
    BlobShape inputShape = new BlobShape(1, sdFirst.Channels, sdFirst.Height, sdFirst.Width);

    // Re-use an existing cancel event when one is already present.
    if (m_evtCancel == null)
    {
        m_evtCancel = new CancelEvent();
    }

    m_mycaffe = new MyCaffeControl<float>(settings, log, m_evtCancel, null, null, null, null, strCudaPath);
    m_mycaffe.LoadToRun(m_strModelDesc, m_rgWeights, inputShape);

    m_log = log;
}
/// <summary>
/// The constructor.  Loads the custom queries, builds one data query per schema connection, then starts
/// the data queries and the consolidation task.
/// </summary>
/// <param name="nQueryCount">Specifies the size of each query.</param>
/// <param name="dtStart">Specifies the start date used for data collection.</param>
/// <param name="nTimeSpanInMs">Specifies the time increment used between each data item.</param>
/// <param name="nSegmentSize">Specifies the amount of data to query on the back-end from each custom query.</param>
/// <param name="nMaxCount">Specifies the maximum number of items to allow in memory.</param>
/// <param name="strSchema">Specifies the database schema.</param>
/// <param name="rgCustomQueries">Optionally, specifies any custom queries to add directly (may be null).</param>
/// <remarks>
/// The database schema defines the number of custom queries to use along with their names.  A simple key=value; list
/// defines the streaming database schema using the following format:
/// \code{.cpp}
///  "ConnectionCount=2;
///   Connection0_CustomQueryName=Test1;
///   Connection0_CustomQueryParam=param_string1;
///   Connection1_CustomQueryName=Test2;
///   Connection1_CustomQueryParam=param_string2"
/// \endcode
/// Each param_string specifies the parameters of the custom query and may include the database connection string, database
/// table, and database fields to query.
/// </remarks>
/// <exception cref="Exception">Thrown when a custom query named in the schema cannot be found, or does not support CUSTOM_QUERY_TYPE.TIME.</exception>
public MgrQueryTime(int nQueryCount, DateTime dtStart, int nTimeSpanInMs, int nSegmentSize, int nMaxCount, string strSchema, List<IXCustomQuery> rgCustomQueries)
{
    m_colCustomQuery.Load();
    m_colData = new DataItemCollection(nQueryCount);
    m_nQueryCount = nQueryCount;
    m_nSegmentSize = nSegmentSize;
    m_schema = new PropertySet(strSchema);

    // The directly added queries are optional, so tolerate a null list.
    if (rgCustomQueries != null)
    {
        foreach (IXCustomQuery icustomquery in rgCustomQueries)
        {
            m_colCustomQuery.Add(icustomquery);
        }
    }

    int nConnections = m_schema.GetPropertyAsInt("ConnectionCount");

    for (int i = 0; i < nConnections; i++)
    {
        string strConTag = "Connection" + i.ToString();
        string strCustomQuery = m_schema.GetProperty(strConTag + "_CustomQueryName");
        string strCustomQueryParam = m_schema.GetProperty(strConTag + "_CustomQueryParam");

        IXCustomQuery iqry = m_colCustomQuery.Find(strCustomQuery);
        if (iqry == null)
        {
            throw new Exception("Could not find the custom query '" + strCustomQuery + "'!");
        }

        if (iqry.QueryType != CUSTOM_QUERY_TYPE.TIME)
        {
            throw new Exception("The custom query '" + iqry.Name + "' does not support the 'CUSTOM_QUERY_TYPE.TIME'!");
        }

        // Each connection works on its own clone of the custom query.
        DataQuery dq = new DataQuery(iqry.Clone(strCustomQueryParam), dtStart, TimeSpan.FromMilliseconds(nTimeSpanInMs), nSegmentSize, nMaxCount);
        m_colDataQuery.Add(dq);

        m_nFieldCount += (dq.FieldCount - 1);  // subtract each sync field.
    }

    m_nFieldCount += 1; // add the sync field

    // Start the data queries and the consolidation task, then wait for initial data.
    m_colDataQuery.Start();

    m_evtCancel.Reset();
    m_taskConsolidate = Task.Factory.StartNew(new Action(consolidateThread));
    m_evtEnabled.Set();

    // NOTE(review): 10000 is presumably a timeout in milliseconds; the result of the wait is not checked.
    m_colData.WaitData(10000);
}
/// <summary>
/// The constructor.  Loads the custom queries (including the standard text-file and WAV-file queries),
/// then clones and opens the single query named by the schema.
/// </summary>
/// <param name="strSchema">Specifies the database schema.</param>
/// <param name="rgCustomQueries">Optionally, specifies any custom queries to directly add (may be null).</param>
/// <remarks>
/// The database schema defines the number of custom queries to use along with their names.  A simple key=value; list
/// defines the streaming database schema using the following format:
/// \code{.cpp}
///  "ConnectionCount=1;
///   Connection0_CustomQueryName=Test1;
///   Connection0_CustomQueryParam=param_string1"
/// \endcode
/// Each param_string specifies the parameters of the custom query and may include the database connection string, database
/// table, and database fields to query.  The parameter string is unpacked with ParamPacker.UnPack before it is handed to the query.
/// </remarks>
/// <exception cref="Exception">Thrown when 'ConnectionCount' is not 1, when the named custom query cannot be found, or when
/// the query type is not BYTE, REAL_FLOAT or REAL_DOUBLE.</exception>
public MgrQueryGeneral(string strSchema, List<IXCustomQuery> rgCustomQueries)
{
    m_colCustomQuery.Load();
    m_schema = new PropertySet(strSchema);

    m_colCustomQuery.Add(new StandardQueryTextFile());
    m_colCustomQuery.Add(new StandardQueryWAVFile());

    // The directly added queries are optional, so tolerate a null list.
    if (rgCustomQueries != null)
    {
        foreach (IXCustomQuery icustomquery in rgCustomQueries)
        {
            m_colCustomQuery.Add(icustomquery);
        }
    }

    int nConnections = m_schema.GetPropertyAsInt("ConnectionCount");
    if (nConnections != 1)
    {
        throw new Exception("Currently, the general query type only supports 1 connection.");
    }

    string strConTag = "Connection0";
    string strCustomQuery = m_schema.GetProperty(strConTag + "_CustomQueryName");
    string strCustomQueryParam = m_schema.GetProperty(strConTag + "_CustomQueryParam");

    IXCustomQuery iqry = m_colCustomQuery.Find(strCustomQuery);
    if (iqry == null)
    {
        throw new Exception("Could not find the custom query '" + strCustomQuery + "'!");
    }

    if (iqry.QueryType != CUSTOM_QUERY_TYPE.BYTE && iqry.QueryType != CUSTOM_QUERY_TYPE.REAL_FLOAT && iqry.QueryType != CUSTOM_QUERY_TYPE.REAL_DOUBLE)
    {
        throw new Exception("The custom query '" + iqry.Name + "' must support the 'CUSTOM_QUERY_TYPE.BYTE' or 'CUSTOM_QUERY_TYPE.REAL_FLOAT' or 'CUSTOM_QUERY_TYPE.REAL_DOUBLE' query type!");
    }

    // Work on a private clone of the query, configured with the unpacked parameters.
    string strParam = ParamPacker.UnPack(strCustomQueryParam);
    m_iquery = iqry.Clone(strParam);
    m_iquery.Open();
}
/// <summary>
/// Constructs the Brain used by the PG.SIMPLE trainer: binds to the internal net and solver, verifies the
/// layer layout (MemoryData and MemoryLoss required, Softmax not allowed), hooks the loss event and
/// allocates the working blobs.
/// </summary>
/// <param name="mycaffe">Specifies the instance of MyCaffe whose internal net and solver are used.</param>
/// <param name="properties">Specifies the properties passed into the trainer ('MiniBatch' is read here).</param>
/// <param name="random">Specifies the random number generator used.</param>
/// <param name="phase">Specifies the phase of the internal net to use.</param>
/// <exception cref="Exception">Thrown when a Softmax layer is present, or when the MemoryData or MemoryLoss layer is missing.</exception>
public Brain(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
{
    m_mycaffe = mycaffe;
    m_net = mycaffe.GetInternalNet(phase);
    m_solver = mycaffe.GetInternalSolver();
    m_properties = properties;
    m_random = random;

    m_memData = m_net.FindLayer(LayerParameter.LayerType.MEMORYDATA, null) as MemoryDataLayer<T>;
    m_memLoss = m_net.FindLayer(LayerParameter.LayerType.MEMORY_LOSS, null) as MemoryLossLayer<T>;

    // This trainer rejects nets containing a Softmax layer.
    SoftmaxLayer<T> softmaxLayer = m_net.FindLayer(LayerParameter.LayerType.SOFTMAX, null) as SoftmaxLayer<T>;
    if (softmaxLayer != null)
    {
        throw new Exception("The PG.SIMPLE trainer does not support the Softmax layer, use the 'PG.ST' or 'PG.MT' trainer instead.");
    }

    if (m_memData == null)
    {
        throw new Exception("Could not find the MemoryData Layer!");
    }

    if (m_memLoss == null)
    {
        throw new Exception("Could not find the MemoryLoss Layer!");
    }

    m_memLoss.OnGetLoss += memLoss_OnGetLoss;

    m_blobDiscountedR = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
    m_blobPolicyGradient = new Blob<T>(mycaffe.Cuda, mycaffe.Log);

    // A non-zero project batch size replaces the current mini-batch value before the 'MiniBatch' property is read.
    int nProjectBatch = mycaffe.CurrentProject.GetBatchSize(phase);
    int nFallback = (nProjectBatch != 0) ? nProjectBatch : m_nMiniBatch;
    m_nMiniBatch = m_properties.GetPropertyAsInt("MiniBatch", nFallback);
}
/// <summary>
/// The constructor.  Selects the network for the requested phase (optionally redirected by the 'PhaseOnRun' run property),
/// locates the first LSTM or LSTM_SIMPLE layer and the 'data' and 'clip' blobs, derives the sequence length and batch size,
/// fills the clip blob and, when the phase is not RUN, hooks the solver events and sizes the label buffer.
/// </summary>
/// <param name="mycaffe">Specifies the instance of MyCaffe whose internal net and solver are used.</param>
/// <param name="properties">Specifies the trainer properties; 'SequenceLength', 'DisableVocabulary', 'Threads' and 'Scale' are read here.</param>
/// <param name="random">Specifies the random number generator (stored for later use).</param>
/// <param name="icallback">Specifies the callback; it is stored and also passed to the data pool when 'Threads' is greater than 1.</param>
/// <param name="phase">Specifies the phase of the internal net to use; replaced by 'PhaseOnRun' when running and that property names a phase.</param>
/// <param name="rgVocabulary">Specifies the vocabulary; when not null (and the vocabulary is not disabled) its count is checked against the num_output of the last INNERPRODUCT layer.</param>
/// <param name="bUsePreloadData">Specifies whether preloaded data is used (stored only; not otherwise used by the constructor).</param>
/// <param name="strRunProperties">Optionally, specifies the run properties; 'Temperature' (absolute value, capped at 1.0), 'PhaseOnRun' and 'OutputBlob' are read here.</param>
/// <remarks>
/// When the net for the selected phase does not exist a warning is written and the TRAIN net is used instead.
/// In an 'OutputBlob' value every '~' is replaced with ';'.
/// Clip values: with an LSTM layer the first batch-size items are zero and all others one; with an LSTM_SIMPLE layer every
/// item whose index is a multiple of the sequence length is zero and all others one.
/// With an LSTM_SIMPLE layer in the RUN phase the batch size is forced to 1 and the data/clip blobs and net are reshaped.
/// An Exception is thrown when no LSTM/LSTM_SIMPLE layer is found, when the 'data', 'clip', 'label' (non-RUN) or requested
/// output blob is missing, or when 'OutputBlob' is not given while running with a phase other than RUN; the remaining
/// consistency checks are reported through Log.CHECK / Log.CHECK_EQ.
/// </remarks>
public Brain(MyCaffeControl <T> mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallbackRNN icallback, Phase phase, BucketCollection rgVocabulary, bool bUsePreloadData, string strRunProperties = null) { string strOutputBlob = null; if (strRunProperties != null) { m_runProperties = new PropertySet(strRunProperties); } m_icallback = icallback; m_mycaffe = mycaffe; m_properties = properties; m_random = random; m_rgVocabulary = rgVocabulary; m_bUsePreloadData = bUsePreloadData; m_nSolverSequenceLength = m_properties.GetPropertyAsInt("SequenceLength", -1); m_bDisableVocabulary = m_properties.GetPropertyAsBool("DisableVocabulary", false); m_nThreads = m_properties.GetPropertyAsInt("Threads", 1); m_dfScale = m_properties.GetPropertyAsDouble("Scale", 1.0); if (m_nThreads > 1) { m_dataPool.Initialize(m_nThreads, icallback); } if (m_runProperties != null) { m_dfTemperature = Math.Abs(m_runProperties.GetPropertyAsDouble("Temperature", 0)); if (m_dfTemperature > 1.0) { m_dfTemperature = 1.0; } string strPhaseOnRun = m_runProperties.GetProperty("PhaseOnRun", false); switch (strPhaseOnRun) { case "RUN": m_phaseOnRun = Phase.RUN; break; case "TEST": m_phaseOnRun = Phase.TEST; break; case "TRAIN": m_phaseOnRun = Phase.TRAIN; break; } if (phase == Phase.RUN && m_phaseOnRun != Phase.NONE) { if (m_phaseOnRun != Phase.RUN) { m_mycaffe.Log.WriteLine("Warning: Running on the '" + m_phaseOnRun.ToString() + "' network."); } strOutputBlob = m_runProperties.GetProperty("OutputBlob", false); if (strOutputBlob == null) { throw new Exception("You must specify the 'OutputBlob' when Running with a phase other than RUN."); } strOutputBlob = Utility.Replace(strOutputBlob, '~', ';'); phase = m_phaseOnRun; } } m_net = mycaffe.GetInternalNet(phase); if (m_net == null) { mycaffe.Log.WriteLine("WARNING: Test net does not exist, set test_iteration > 0. 
Using TRAIN phase instead."); m_net = mycaffe.GetInternalNet(Phase.TRAIN); } // Find the first LSTM layer to determine how to load the data. // NOTE: Only LSTM has a special loading order, other layers use the standard N, C, H, W ordering. LSTMLayer <T> lstmLayer = null; LSTMSimpleLayer <T> lstmSimpleLayer = null; foreach (Layer <T> layer1 in m_net.layers) { if (layer1.layer_param.type == LayerParameter.LayerType.LSTM) { lstmLayer = layer1 as LSTMLayer <T>; m_lstmType = LayerParameter.LayerType.LSTM; break; } else if (layer1.layer_param.type == LayerParameter.LayerType.LSTM_SIMPLE) { lstmSimpleLayer = layer1 as LSTMSimpleLayer <T>; m_lstmType = LayerParameter.LayerType.LSTM_SIMPLE; break; } } if (lstmLayer == null && lstmSimpleLayer == null) { throw new Exception("Could not find the required LSTM or LSTM_SIMPLE layer!"); } if (m_phaseOnRun != Phase.NONE && m_phaseOnRun != Phase.RUN && strOutputBlob != null) { if ((m_blobOutput = m_net.FindBlob(strOutputBlob)) == null) { throw new Exception("Could not find the 'Output' layer top named '" + strOutputBlob + "'!"); } } if ((m_blobData = m_net.FindBlob("data")) == null) { throw new Exception("Could not find the 'Input' layer top named 'data'!"); } if ((m_blobClip = m_net.FindBlob("clip")) == null) { throw new Exception("Could not find the 'Input' layer top named 'clip'!"); } Layer <T> layer = m_net.FindLastLayer(LayerParameter.LayerType.INNERPRODUCT); m_mycaffe.Log.CHECK(layer != null, "Could not find an ending INNERPRODUCT layer!"); if (!m_bDisableVocabulary) { m_nVocabSize = (int)layer.layer_param.inner_product_param.num_output; if (rgVocabulary != null) { m_mycaffe.Log.CHECK_EQ(m_nVocabSize, rgVocabulary.Count, "The vocabulary count = '" + rgVocabulary.Count.ToString() + "' and last inner product output count = '" + m_nVocabSize.ToString() + "' - these do not match but they should!"); } } if (m_lstmType == LayerParameter.LayerType.LSTM) { m_nSequenceLength = m_blobData.shape(0); m_nBatchSize = m_blobData.shape(1); } 
else { m_nBatchSize = (int)lstmSimpleLayer.layer_param.lstm_simple_param.batch_size; m_nSequenceLength = m_blobData.shape(0) / m_nBatchSize; if (phase == Phase.RUN) { m_nBatchSize = 1; List <int> rgNewShape = new List <int>() { m_nSequenceLength, 1 }; m_blobData.Reshape(rgNewShape); m_blobClip.Reshape(rgNewShape); m_net.Reshape(); } } m_mycaffe.Log.CHECK_EQ(m_nSequenceLength, m_blobData.num, "The data num must equal the sequence lengh of " + m_nSequenceLength.ToString()); m_rgDataInput = new T[m_nSequenceLength * m_nBatchSize]; T[] rgClipInput = new T[m_nSequenceLength * m_nBatchSize]; m_mycaffe.Log.CHECK_EQ(rgClipInput.Length, m_blobClip.count(), "The clip count must equal the sequence length * batch size: " + rgClipInput.Length.ToString()); m_tZero = (T)Convert.ChangeType(0, typeof(T)); m_tOne = (T)Convert.ChangeType(1, typeof(T)); for (int i = 0; i < rgClipInput.Length; i++) { if (m_lstmType == LayerParameter.LayerType.LSTM) { rgClipInput[i] = (i < m_nBatchSize) ? m_tZero : m_tOne; } else { rgClipInput[i] = (i % m_nSequenceLength == 0) ? 
m_tZero : m_tOne; } } m_blobClip.mutable_cpu_data = rgClipInput; if (phase != Phase.RUN) { m_solver = mycaffe.GetInternalSolver(); m_solver.OnStart += m_solver_OnStart; m_solver.OnTestStart += m_solver_OnTestStart; m_solver.OnTestingIteration += m_solver_OnTestingIteration; m_solver.OnTrainingIteration += m_solver_OnTrainingIteration; if ((m_blobLabel = m_net.FindBlob("label")) == null) { throw new Exception("Could not find the 'Input' layer top named 'label'!"); } m_nSequenceLengthLabel = m_blobLabel.count(0, 2); m_rgLabelInput = new T[m_nSequenceLengthLabel]; m_mycaffe.Log.CHECK_EQ(m_rgLabelInput.Length, m_blobLabel.count(), "The label count must equal the label sequence length * batch size: " + m_rgLabelInput.Length.ToString()); m_mycaffe.Log.CHECK(m_nSequenceLengthLabel == m_nSequenceLength * m_nBatchSize || m_nSequenceLengthLabel == 1, "The label sqeuence length must be 1 or equal the length of the sequence: " + m_nSequenceLength.ToString()); } }
/// <summary>
/// Initialize the gym with the specified properties.
/// </summary>
/// <param name="log">Specifies the output log to use.</param>
/// <param name="properties">Specifies the properties containing Gym specific initialization parameters (required).</param>
/// <remarks>
/// The AtariGym uses the following initialization properties.
///   'GameROM' - path to the .rom file; a '~' (or, when no '~' is present, the text '[sp]') stands in for a space.
///   'EnableColor', 'AllowNegativeRewards', 'TerminateOnRallyEnd', 'UseGrayscale', 'ActionForceGray',
///   'EnableBinaryActions' - boolean switches read with false as the fallback argument.
///   'Preprocess', 'EnableNumSkip' - boolean switches read with true as the fallback argument.
///   'FrameSkip' - the frame skip count; when negative (fallback -1) the skip values 2, 3 and 4 are used.
///
/// Any existing emulator instance is shut down and replaced, and the gym is reset at the end.
/// </remarks>
/// <exception cref="Exception">Thrown when the properties are null or the game ROM file cannot be found.</exception>
public void Initialize(Log log, PropertySet properties)
{
    // Validate before touching any state: the properties are dereferenced throughout this method, so
    // the check must come first to be effective (and a bad call no longer tears down a running emulator).
    if (properties == null)
    {
        throw new Exception("The properties must be specified with the 'GameROM' set the the Game ROM file path.");
    }

    m_log = log;

    // Replace any previously created emulator.
    if (m_ale != null)
    {
        m_ale.Shutdown();
        m_ale = null;
    }

    m_ale = new ALE();
    m_ale.Initialize();
    m_ale.EnableDisplayScreen = false;
    m_ale.EnableSound = false;
    m_ale.EnableColorData = properties.GetPropertyAsBool("EnableColor", false);
    m_ale.EnableRestrictedActionSet = true;
    m_ale.EnableColorAveraging = true;
    m_ale.AllowNegativeRewards = properties.GetPropertyAsBool("AllowNegativeRewards", false);
    m_ale.EnableTerminateOnRallyEnd = properties.GetPropertyAsBool("TerminateOnRallyEnd", false);
    m_ale.RandomSeed = (int)DateTime.Now.Ticks;
    m_ale.RepeatActionProbability = 0.0f;  // disable action repeatability

    // Spaces in the ROM path are encoded as '~' or, alternatively, as '[sp]'.
    string strROM = properties.GetProperty("GameROM");
    if (strROM.Contains('~'))
    {
        strROM = Utility.Replace(strROM, '~', ' ');
    }
    else
    {
        strROM = Utility.Replace(strROM, "[sp]", ' ');
    }

    if (!File.Exists(strROM))
    {
        throw new Exception("Could not find the game ROM file specified '" + strROM + "'!");
    }

    if (properties.GetPropertyAsBool("UseGrayscale", false))
    {
        m_ct = COLORTYPE.CT_GRAYSCALE;
    }

    m_bPreprocess = properties.GetPropertyAsBool("Preprocess", true);
    m_bForceGray = properties.GetPropertyAsBool("ActionForceGray", false);
    m_bEnableNumSkip = properties.GetPropertyAsBool("EnableNumSkip", true);
    m_nFrameSkip = properties.GetPropertyAsInt("FrameSkip", -1);

    m_ale.Load(strROM);
    m_rgActionsRaw = m_ale.ActionSpace;
    m_random = new CryptoRandom();
    m_rgFrameSkip = new List<int>();

    // A negative frame skip selects the set { 2, 3, 4 }; otherwise the single configured value is used.
    if (m_nFrameSkip < 0)
    {
        for (int i = 2; i < 5; i++)
        {
            m_rgFrameSkip.Add(i);
        }
    }
    else
    {
        m_rgFrameSkip.Add(m_nFrameSkip);
    }

    // LEFT and RIGHT are always available; FIRE is omitted when binary actions are enabled.
    m_rgActions.Add(ACTION.ACT_PLAYER_A_LEFT.ToString(), (int)ACTION.ACT_PLAYER_A_LEFT);
    m_rgActions.Add(ACTION.ACT_PLAYER_A_RIGHT.ToString(), (int)ACTION.ACT_PLAYER_A_RIGHT);

    if (!properties.GetPropertyAsBool("EnableBinaryActions", false))
    {
        m_rgActions.Add(ACTION.ACT_PLAYER_A_FIRE.ToString(), (int)ACTION.ACT_PLAYER_A_FIRE);
    }

    m_rgActionSet = m_rgActions.ToList();

    Reset(false);
}
/// <summary>
/// Constructs the Brain: binds to the solver and output network of the given MyCaffe instance, builds a
/// second (target) network from the same network parameter, configures the data transformer to center
/// and scale each frame, and allocates the working blobs.
/// </summary>
/// <param name="mycaffe">Specifies the instance of MyCaffe associated with the open project - when using more than one Brain, this is the master project.</param>
/// <param name="properties">Specifies the properties passed into the trainer ('Gamma', 'MiniBatch' and 'UseAcceleratedTraining' are read here).</param>
/// <param name="random">Specifies the random number generator used.</param>
/// <param name="phase">Specifies the phase under which to run.</param>
/// <remarks>
/// The output network must expose blobs named 'data' and 'logits' and contain a MEMORY_LOSS layer;
/// a missing item is reported through Log.FAIL.  When MyCaffe has no data transformer, one is created
/// from the channel, height and width of the project's training source.
/// </remarks>
public Brain(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
{
    m_mycaffe = mycaffe;
    m_solver = mycaffe.GetInternalSolver();
    m_netOutput = mycaffe.GetInternalNet(phase);
    m_netTarget = new Net<T>(m_mycaffe.Cuda, m_mycaffe.Log, m_netOutput.net_param, m_mycaffe.CancelEvent, null, phase);
    m_properties = properties;
    m_random = random;

    // The 'data' blob supplies the frames per input (channels) and the batch size (num).
    Blob<T> blobData = m_netOutput.blob_by_name("data");
    if (blobData == null)
    {
        m_mycaffe.Log.FAIL("Missing the expected input 'data' blob!");
    }

    m_nFramesPerX = blobData.channels;
    m_nBatchSize = blobData.num;

    // The 'logits' blob supplies the action count (one channel per action).
    Blob<T> blobLogits = m_netOutput.blob_by_name("logits");
    if (blobLogits == null)
    {
        m_mycaffe.Log.FAIL("Missing the expected input 'logits' blob!");
    }

    m_nActionCount = blobLogits.channels;

    // Use MyCaffe's transformer when available, otherwise build one sized from the training source.
    m_transformer = m_mycaffe.DataTransformer;
    if (m_transformer == null)
    {
        TransformationParameter transParam = new TransformationParameter();
        int nChannels = m_mycaffe.CurrentProject.Dataset.TrainingSource.ImageChannels;
        int nHeight = m_mycaffe.CurrentProject.Dataset.TrainingSource.ImageHeight;
        int nWidth = m_mycaffe.CurrentProject.Dataset.TrainingSource.ImageWidth;
        m_transformer = new DataTransformer<T>(m_mycaffe.Cuda, m_mycaffe.Log, transParam, phase, nChannels, nHeight, nWidth);
    }

    // NOTE(review): '255 / 2' is integer division and evaluates to 127 - confirm that 127.5 was not intended.
    for (int nFrame = 0; nFrame < m_nFramesPerX; nFrame++)
    {
        m_transformer.param.mean_value.Add(255 / 2); // center each frame
    }

    m_transformer.param.scale = 1.0 / 255; // normalize
    m_transformer.Update();

    // Working blobs; those created with the extra 'false' argument are the actions, done and weights blobs.
    m_blobActions = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
    m_blobQValue = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
    m_blobNextQValue = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
    m_blobExpectedQValue = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
    m_blobDone = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
    m_blobLoss = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log);
    m_blobWeights = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);

    m_fGamma = (float)properties.GetPropertyAsDouble("Gamma", m_fGamma);

    m_memLoss = m_netOutput.FindLastLayer(LayerParameter.LayerType.MEMORY_LOSS) as MemoryLossLayer<T>;
    if (m_memLoss == null)
    {
        m_mycaffe.Log.FAIL("Missing the expected MEMORY_LOSS layer!");
    }

    // Use the solver's base learning rate when the project defines one.
    double? dfBaseLr = mycaffe.CurrentProject.GetSolverSettingAsNumeric("base_lr");
    if (dfBaseLr.HasValue)
    {
        m_dfLearningRate = dfBaseLr.Value;
    }

    m_nMiniBatch = m_properties.GetPropertyAsInt("MiniBatch", m_nMiniBatch);
    m_bUseAcceleratedTraining = properties.GetPropertyAsBool("UseAcceleratedTraining", false);

    // Gradient accumulation storage is only created when more than one mini-batch is used.
    if (m_nMiniBatch > 1)
    {
        m_colAccumulatedGradients = m_netOutput.learnable_parameters.Clone();
        m_colAccumulatedGradients.SetDiff(0);
    }
}
/// <summary>
/// Step the gym one step in the data.
/// </summary>
/// <param name="nAction">Specifies the action to run on the gym (not used by this implementation).</param>
/// <param name="bGetLabel">Not used.</param>
/// <param name="extraProp">Optionally, specifies extra properties; required in the RUN phase where 'DataCountRequested'
/// and 'SeedTime' are read from it.</param>
/// <returns>A tuple containing the state data, a reward of 0 and a done state of false is returned, or null when the
/// operation is cancelled while the model is being run over the dataset.</returns>
/// <remarks>
/// RUN phase: a window of 'DataCountRequested' scores is copied from the existing score collection - the last
/// scores by default, or the window located from 'SeedTime' when that value parses as a date.
///
/// Other phases: the scores are loaded from file; when 'RecreateData' is set or the loaded count differs from the
/// training source image count, the model is re-run over all training images in batches and the scores are saved.
/// </remarks>
/// <exception cref="Exception">Thrown when the extra properties are missing during the RUN phase.</exception>
public Tuple<State, double, bool> Step(int nAction, bool bGetLabel = false, PropertySet extraProp = null)
{
    DataState data = new DataState();
    ScoreCollection scores = null;

    if (ActivePhase == Phase.RUN)
    {
        if (extraProp == null)
        {
            throw new Exception("The extra properties are needed when querying data during the RUN phase.");
        }

        int nDataCount = extraProp.GetPropertyAsInt("DataCountRequested");
        string strStartTime = extraProp.GetProperty("SeedTime");

        // Default to the last nDataCount scores unless a parsable seed time selects another window.
        int nStartIdx = m_scores.Count - nDataCount;
        DateTime dt;
        if (DateTime.TryParse(strStartTime, out dt))
        {
            nStartIdx = m_scores.FindIndexAt(dt, nDataCount);
        }

        scores = m_scores.CopyFrom(nStartIdx, nDataCount);
    }
    else
    {
        int nCount = 0;

        m_scores = load(out m_nDim, out m_nWidth);

        if (m_bRecreateData || m_scores.Count != m_ds.TrainingSource.ImageCount)
        {
            Stopwatch sw = new Stopwatch();
            sw.Start();

            m_scores = new ScoreCollection();

            // A rebuild always starts from the first image; this keeps the index in step with nCount
            // even when a previous call was cancelled part-way through.
            m_nCurrentIdx = 0;

            while (m_nCurrentIdx < m_ds.TrainingSource.ImageCount)
            {
                // Query images sequentially by index in batches
                List<SimpleDatum> rgSd = new List<SimpleDatum>();

                for (int i = 0; i < m_nBatchSize; i++)
                {
                    SimpleDatum sd = m_imgdb.QueryImage(m_ds.TrainingSource.ID, m_nCurrentIdx + i, IMGDB_LABEL_SELECTION_METHOD.NONE, IMGDB_IMAGE_SELECTION_METHOD.NONE);
                    rgSd.Add(sd);
                    nCount++;

                    // Stop filling the (final, partial) batch once every image has been queried.
                    if (nCount == m_ds.TrainingSource.ImageCount)
                    {
                        break;
                    }
                }

                List<ResultCollection> rgRes = m_mycaffe.Run(rgSd, ref m_blobWork);

                // The first result set establishes the width and dimension of the score data.
                if (m_nWidth == 0)
                {
                    m_nWidth = rgRes[0].ResultsOriginal.Count;
                    m_nDim = rgRes[0].ResultsOriginal.Count * 2;
                }

                // Fill SimpleDatum with the ordered label,score pairs starting with the detected label.
                for (int i = 0; i < rgRes.Count; i++)
                {
                    m_scores.Add(new Score(rgSd[i].TimeStamp, rgSd[i].Index, rgRes[i]));
                    m_nCurrentIdx++;
                }

                // Report progress at most once per second; the stopwatch is restarted after each report
                // so that the log is not written on every batch once the first second has passed.
                if (sw.Elapsed.TotalMilliseconds > 1000)
                {
                    m_log.Progress = (double)m_nCurrentIdx / (double)m_ds.TrainingSource.ImageCount;
                    m_log.WriteLine("Running model on image " + m_nCurrentIdx.ToString() + " of " + m_ds.TrainingSource.ImageCount.ToString() + " of '" + m_strDataset + "' dataset.");
                    sw.Restart();
                }

                // Poll for cancellation after every batch so throttling the output does not delay a cancel.
                if (m_evtCancel.WaitOne(0))
                {
                    return null;
                }
            }

            save(m_nDim, m_nWidth, m_scores);
        }
        else
        {
            m_nCurrentIdx = m_scores.Count;
        }

        scores = m_scores;
    }

    float[] rgfRes = scores.Data;
    SimpleDatum sdRes = new SimpleDatum(scores.Count, m_nWidth, 2, rgfRes, 0, rgfRes.Length);
    data.SetData(sdRes);
    m_nCurrentIdx = 0;

    return new Tuple<State, double, bool>(data, 0, false);
}