/// <summary>
/// The DataTransformer constructor.
/// </summary>
/// <param name="log">Specifies the Log used for output.</param>
/// <param name="p">Specifies the TransformationParameter used to create the DataTransformer.</param>
/// <param name="phase">Specifies the Phase under which the DataTransformer is run.</param>
/// <param name="nC">Specifies the channels.</param>
/// <param name="nH">Specifies the height.</param>
/// <param name="nW">Specifies the width.</param>
/// <param name="imgMean">Optionally, specifies the image mean to use.</param>
public DataTransformer(Log log, TransformationParameter p, Phase phase, int nC, int nH, int nW, SimpleDatum imgMean = null)
{
    m_log = log;

    if (p.mean_file != null)
        m_protoMean = loadProtoMean(p.mean_file);

    // When an image mean is supplied, its dimensions define the data size,
    // otherwise the explicit channel/height/width values are used.
    int nDataSize = (imgMean != null) ? imgMean.Channels * imgMean.Height * imgMean.Width : nC * nH * nW;

    m_rgTransformedData = new T[nDataSize];
    m_param = p;
    m_phase = phase;

    InitRand();

    if (p.use_imagedb_mean)
    {
        if (m_protoMean == null)
        {
            // Use the image mean from the image database (if any).
            m_imgMean = imgMean;

            if (m_imgMean != null)
                m_rgMeanData = m_imgMean.GetData<double>();
        }
        else if (m_protoMean.data.Count > 0)
        {
            // Use the float data from the mean proto file.
            m_rgMeanData = new double[m_protoMean.data.Count];
            Array.Copy(m_protoMean.data.ToArray(), m_rgMeanData, m_rgMeanData.Length);
        }
        else
        {
            // Use the double data from the mean proto file.
            m_rgMeanData = m_protoMean.double_data.ToArray();
        }
    }

    if (p.mean_value.Count > 0)
    {
        m_log.CHECK(p.use_imagedb_mean == false, "Cannot specify use_image_mean and mean_value at the same time.");

        for (int i = 0; i < p.mean_value.Count; i++)
        {
            m_rgMeanValues.Add(p.mean_value[i]);
        }
    }
}
/// <summary>
/// Set the image mean.
/// </summary>
/// <param name="d">Specifies the image mean.</param>
/// <param name="bSave">Optionally, specifies whether or not to save the image mean in the database (default = false).</param>
public void SetImageMean(SimpleDatum d, bool bSave = false)
{
    m_imgMean = d;

    if (!bSave)
        return;

    m_factory.SaveImageMean(d, true);
}
/// <summary>
/// Preprocesses the data.
/// </summary>
/// <param name="s">Specifies the state and data to use.</param>
/// <param name="bUseRawInput">Specifies whether or not to use the raw data <i>true</i>, or a difference of the current and previous data <i>false</i> (default = <i>false</i>).</param>
/// <param name="bDifferent">Returns whether or not the current state data is different from the previous - note this is only set when NOT using raw input, otherwise <i>true</i> is always returned.</param>
/// <param name="bReset">Optionally, specifies to reset the last sd to null.</param>
/// <returns>The preprocessed data is returned.</returns>
public SimpleDatum Preprocess(StateBase s, bool bUseRawInput, out bool bDifferent, bool bReset = false)
{
    SimpleDatum sd = new SimpleDatum(s.Data, true);

    if (bUseRawInput)
    {
        bDifferent = true;
    }
    else
    {
        // Work with the difference between the current and previous data.
        if (bReset)
            m_sdLast = null;

        if (m_sdLast == null)
        {
            bDifferent = false;
            sd.Zero();
        }
        else
        {
            bDifferent = sd.Sub(m_sdLast);
        }

        // Keep a copy of the current raw data for the next difference.
        m_sdLast = new SimpleDatum(s.Data, true);
    }

    sd.Tag = bReset;

    if (bReset)
    {
        // Fill the entire frame stack with the current item.
        int nTotal = m_nFramesPerX * m_nStackPerX;

        m_rgX = new List<SimpleDatum>();
        while (m_rgX.Count < nTotal)
        {
            m_rgX.Add(sd);
        }
    }
    else
    {
        // Slide the frame window forward by one.
        m_rgX.Add(sd);
        m_rgX.RemoveAt(0);
    }

    // Sample one frame per stack slot, stepping by the frames-per-item stride.
    SimpleDatum[] rgSd = new SimpleDatum[m_nStackPerX];

    for (int i = 0; i < m_nStackPerX; i++)
    {
        int nIdx = (m_nStackPerX - i) * m_nFramesPerX - 1;
        rgSd[i] = m_rgX[nIdx];
    }

    return new SimpleDatum(rgSd.ToList(), true);
}
/// <summary>
/// Step the gym one step in the data.
/// </summary>
/// <param name="nAction">Specifies the action to run on the gym.</param>
/// <returns>A tuple containing state data, the reward, and the done state is returned.</returns>
public Tuple<State, double, bool> Step(int nAction)
{
    DataState data = new DataState();

    SimpleDatum sd = m_db.Query(1000);
    data.SetData(sd);

    // A null query result means no more data exists - signal done.
    // (simplified the redundant '(sd == null) ? true : false' expression)
    return new Tuple<State, double, bool>(data, 0, sd == null);
}
/// <summary>
/// Saves the image mean in a SimpleDatum to the database.
/// </summary>
/// <param name="nSrcId">Specifies the ID of the data source to use.</param>
/// <param name="sd">Specifies the image mean data.</param>
/// <param name="bUpdate">Specifies whether or not to update the mean image.</param>
/// <returns>If saved successfully, this method returns <i>true</i>, otherwise <i>false</i> is returned.</returns>
public bool SaveImageMean(int nSrcId, SimpleDatum sd, bool bUpdate)
{
    // Only save for sources this dataset actually manages.
    bool bKnownSource = (m_TestingImages.SourceID == nSrcId || m_TrainingImages.SourceID == nSrcId);

    if (!bKnownSource)
        return false;

    return m_factory.SaveImageMean(sd, bUpdate, nSrcId);
}
/// <summary>
/// Resync the transformer with changes in its parameter.
/// </summary>
/// <param name="nDataSize">Optionally, specifies the new data size (default = 0; ignored when <paramref name="imgMean"/> is supplied).</param>
/// <param name="imgMean">Optionally, specifies a new image mean whose dimensions define the data size (default = null).</param>
public void Update(int nDataSize = 0, SimpleDatum imgMean = null)
{
    TransformationParameter p = m_param;

    if (imgMean != null)
        nDataSize = imgMean.Channels * imgMean.Height * imgMean.Width;

    // NOTE(review): the length guard checks m_rgfTransformedData but the
    // allocation replaces m_rgTransformedData - confirm this asymmetry is intended.
    if (nDataSize > 0 || (m_rgfTransformedData != null && nDataSize != m_rgfTransformedData.Length))
        m_rgTransformedData = new T[nDataSize];

    if (p.mean_file != null)
        m_protoMean = loadProtoMean(p.mean_file);

    if (p.use_imagedb_mean)
    {
        if (m_protoMean == null)
        {
            m_imgMean = imgMean;

            if (m_imgMean != null)
                m_rgMeanData = m_imgMean.GetData<double>();
        }
        else
        {
            if (m_protoMean.data.Count > 0)
            {
                m_rgMeanData = new double[m_protoMean.data.Count];
                Array.Copy(m_protoMean.data.ToArray(), m_rgMeanData, m_rgMeanData.Length);
            }
            else
            {
                m_rgMeanData = m_protoMean.double_data.ToArray();
            }
        }
    }

    if (p.mean_value.Count > 0)
    {
        m_log.CHECK(p.use_imagedb_mean == false, "Cannot specify use_image_mean and mean_value at the same time.");

        // BUGFIX: clear any previously loaded values so that repeated calls
        // to Update do not accumulate duplicate mean values.
        m_rgMeanValues.Clear();

        for (int c = 0; c < p.mean_value.Count; c++)
        {
            m_rgMeanValues.Add(p.mean_value[c]);
        }
    }
}
/// <summary>
/// Retrieve the secondary item of a pair: one whose label matches (or does not match)
/// the label of the given item, retrying a limited number of times.
/// </summary>
/// <param name="bMatching">Specifies whether the pair should have a matching label.</param>
/// <param name="d">Specifies the primary item of the pair.</param>
/// <param name="bLoadDataCriteria">Specifies whether or not to load the data criteria.</param>
/// <returns>The secondary item found is returned (a warning is logged if the pairing constraint was not met).</returns>
private SimpleDatum getNextPair(bool bMatching, SimpleDatum d, bool bLoadDataCriteria)
{
    const int nRetries = 10;
    SimpleDatum dNew;
    string strType = null;

    if (bMatching)
    {
        // Only constrain the query by label when the label has not been remapped.
        int? nLabel = (d.Label == d.OriginalLabel) ? (int?)d.Label : null;

        dNew = m_cursor.GetValue(nLabel, bLoadDataCriteria);

        for (int nTry = 0; dNew.Label != d.Label && nTry < nRetries; nTry++)
        {
            Next();
            dNew = m_cursor.GetValue(nLabel, bLoadDataCriteria);
        }

        if (dNew.Label != d.Label)
            strType = "match";
    }
    else
    {
        dNew = m_cursor.GetValue(null, bLoadDataCriteria);

        for (int nTry = 0; dNew.Label == d.Label && nTry < nRetries; nTry++)
        {
            Next();
            dNew = m_cursor.GetValue(null, bLoadDataCriteria);
        }

        if (dNew.Label == d.Label)
            strType = "non-match";
    }

    if (strType != null)
        m_log.WriteLine("WARNING: The secondary pairing " + strType + " could not be found after " + nRetries.ToString() + " retries!");

    return dNew;
}
/// <summary>
/// Saves the image mean in a SimpleDatum to the database for a data source.
/// </summary>
/// <param name="nSrcID">Specifies the ID of the data source.</param>
/// <param name="sd">Specifies the image mean data.</param>
/// <param name="bUpdate">Specifies whether or not to update the mean image.</param>
/// <returns>Returns <i>true</i> after a successful save, <i>false</i> otherwise.</returns>
public bool SaveImageMean(int nSrcID, SimpleDatum sd, bool bUpdate)
{
    bool bSaved = false;

    // Stop at the first dataset that accepts the save.
    foreach (DatasetEx ds in m_rgDatasets)
    {
        if (ds.SaveImageMean(nSrcID, sd, bUpdate))
        {
            bSaved = true;
            break;
        }
    }

    return bSaved;
}
/// <summary>
/// The constructor.
/// </summary>
/// <param name="currentState">Specifies the current state.</param>
/// <param name="currentData">Specifies the current data.</param>
/// <param name="nAction">Specifies the action.</param>
/// <param name="nextState">Specifies the next state.</param>
/// <param name="nextData">Specifies the next data.</param>
/// <param name="dfReward">Specifies the reward.</param>
/// <param name="bTerminated">Specifies whether or not this is a termination state or not.</param>
/// <param name="nIteration">Specifies the iteration.</param>
/// <param name="nEpisode">Specifies the episode.</param>
public MemoryItem(StateBase currentState, SimpleDatum currentData, int nAction, StateBase nextState, SimpleDatum nextData, double dfReward, bool bTerminated, int nIteration, int nEpisode)
{
    // Current transition inputs.
    m_state0 = currentState;
    m_x0 = currentData;
    m_nAction = nAction;

    // Resulting transition outputs.
    m_state1 = nextState;
    m_x1 = nextData;
    m_dfReward = dfReward;
    m_bTerminated = bTerminated;

    // Bookkeeping.
    m_nIteration = nIteration;
    m_nEpisode = nEpisode;
}
/// <summary>
/// Returns the image mean for the ImageSet.
/// </summary>
/// <param name="log">Specifies the Log used to output status.</param>
/// <param name="rgAbort">Specifies a set of wait handles for aborting the operation.</param>
/// <param name="bQueryOnly">Specifies whether or not to only query for the mean and not calculate if missing.</param>
/// <returns>The SimpleDatum with the image mean is returned.</returns>
public SimpleDatum GetImageMean(Log log, WaitHandle[] rgAbort, bool bQueryOnly)
{
    if (m_imgMean != null || bQueryOnly)
        return m_imgMean;

    // The mean can only be calculated once all images are loaded.
    int nLoadedCount = GetLoadedCount();
    int nTotalCount = GetTotalCount();

    if (nLoadedCount < nTotalCount)
    {
        double dfPct = (double)nLoadedCount / (double)nTotalCount;

        if (log != null)
            log.WriteLine("WARNING: Cannot create the image mean until all images have loaded - the data is currently " + dfPct.ToString("P") + " loaded.");

        return null;
    }

    // Allow the owner to calculate (or cancel) the mean calculation.
    if (OnCalculateImageMean != null)
    {
        CalculateImageMeanArgs args = new CalculateImageMeanArgs(m_rgImages);
        OnCalculateImageMean(this, args);

        if (args.Cancelled)
            return null;

        m_imgMean = args.ImageMean;
        return m_imgMean;
    }

    RawImageMean imgMean = m_factory.GetRawImageMean();

    // BUGFIX: previously this tested 'm_imgMean' which is always null at this
    // point (all earlier paths returned), so a mean stored in the database was
    // never loaded and the mean was always recalculated.
    if (imgMean != null)
    {
        m_imgMean = m_factory.LoadDatum(imgMean);
    }
    else
    {
        if (log != null)
            log.WriteLine("Calculating mean...");

        m_imgMean = SimpleDatum.CalculateMean(log, m_rgImages, rgAbort);
        m_factory.PutRawImageMean(m_imgMean, true);
    }

    m_imgMean.SetLabel(0);

    return m_imgMean;
}
/// <summary>
/// Transform the data into an array of transformed values.
/// </summary>
/// <param name="d">Data to transform.</param>
/// <param name="rgTransformedAnnoVec">Returns the list of transfomed annoations.</param>
/// <param name="bMirror">Returns whether or not a mirror occurred.</param>
/// <param name="bResize">Specifies to resize the data.</param>
/// <returns>Transformed data.</returns>
public T[] Transform(SimpleDatum d, out List<AnnotationGroup> rgTransformedAnnoVec, out bool bMirror, bool bResize = true)
{
    NormalizedBBox crop_bbox = new NormalizedBBox(0, 0, 0, 0);

    // Transform the datum first; the data transform fills in the crop
    // bounding box and mirror flag needed by the annotation transform.
    T[] rgTrans = Transform(d, out bMirror, crop_bbox);

    rgTransformedAnnoVec = TransformAnnotation(d, crop_bbox, bMirror, bResize);

    return rgTrans;
}
/// <summary>
/// Returns the image mean for a data source.
/// </summary>
/// <param name="nSrcID">Specifies the ID of the data source.</param>
/// <returns>The image mean queried is returned as a SimpleDatum.</returns>
public SimpleDatum QueryImageMean(int nSrcID)
{
    // Return the first mean found among the loaded datasets.
    foreach (DatasetEx ds in m_rgDatasets)
    {
        SimpleDatum sdMean = ds.QueryImageMean(nSrcID);

        if (sdMean != null)
            return sdMean;
    }

    return null;
}
/// <summary>
/// Record that the given item has been loaded by moving its index from the
/// not-loaded set to the loaded set for its label.
/// </summary>
/// <param name="sd">Specifies the item that was loaded.</param>
public void AddLoaded(SimpleDatum sd)
{
    // Use TryGetValue to avoid the ContainsKey + indexer double lookup.
    List<int> rgLoaded;
    if (!m_rgLoadedIdx.TryGetValue(sd.Label, out rgLoaded))
    {
        rgLoaded = new List<int>();
        m_rgLoadedIdx.Add(sd.Label, rgLoaded);
    }

    // First(predicate) replaces the Where(...).First() chain; throws if the
    // item is unknown, matching the original behavior.
    DbItem item = m_rgItemsByLabel[sd.Label].First(p => p.ID == sd.ImageID);
    int nIdx = (int)item.Tag;

    rgLoaded.Add(nIdx);
    m_rgNotLoadedIdx[sd.Label].Remove(nIdx);
}
/// <summary>
/// Load the data (and optional clip data) into the given net as single-item batches.
/// </summary>
/// <param name="net">Specifies the net to load the data into.</param>
/// <param name="sdData">Specifies the data item.</param>
/// <param name="sdClip">Specifies the clip item (or null when not used).</param>
private void setData(Net<T> net, SimpleDatum sdData, SimpleDatum sdClip)
{
    SimpleDatum[] rgData = new SimpleDatum[] { sdData };
    SimpleDatum[] rgClip = (sdClip != null) ? new SimpleDatum[] { sdClip } : null;

    setData(net, rgData, rgClip);
}
/// <summary>
/// Retrieve the Datum at the current cursor location within the data source.
/// </summary>
/// <param name="nLabel">Optionally, specifies a label for which the cursor should query from.</param>
/// <param name="bLoadDataCriteria">Specifies whether or not to load the data criteria.</param>
/// <param name="imgSel">Optionally, specifies the image selection method (default = null).</param>
/// <returns>The Datum retrieved is returned.</returns>
public SimpleDatum GetValue(int? nLabel = null, bool bLoadDataCriteria = false, IMGDB_IMAGE_SELECTION_METHOD? imgSel = null)
{
    SimpleDatum sd = m_db.QueryImage(m_nSrcID, m_nIdx, null, imgSel, nLabel, bLoadDataCriteria, false);

    if (m_log != null)
        m_log.WriteLine(m_strSrc + ": Idx = " + sd.Index.ToString() + " Label = " + sd.Label.ToString());

    // Apply any label mapping configured on the transformer.
    m_transformer.TransformLabel(sd);

    return sd;
}
/// <summary>
/// Test querying the streaming database with general WAV data: verifies the
/// reported query size and that a full query drains the data exactly once.
/// </summary>
public void TestQueryGeneralWAV()
{
    Log log = new Log("Test streaming database with general WAV data");
    log.EnableTrace = true;

    IXStreamDatabase db = new MyCaffeStreamDatabase(log);

    // Build the schema: a single connection using the standard WAV file query.
    string strSchema = "ConnectionCount=1;";
    string strDataPath = getTestPath("\\MyCaffe\\test_data\\data\\wav", true);
    string strParam = "FilePath=" + strDataPath + ";";

    strParam = ParamPacker.Pack(strParam);
    strSchema += "Connection0_CustomQueryName=StdWAVFileQuery;";
    strSchema += "Connection0_CustomQueryParam=" + strParam + ";";

    string strSettings = "";
    db.Initialize(QUERY_TYPE.GENERAL, strSchema + strSettings);

    int[] rgSize = db.QuerySize();
    log.CHECK(rgSize != null, "The Query size should not be null.");
    log.CHECK_EQ(rgSize.Length, 3, "The query size should have 3 items.");
    log.CHECK_EQ(rgSize[0], 1, "The query size item 1 should be 1 for the number of files.");
    // BUGFIX: message previously referenced item 0 and the value 1, while the
    // check itself verifies item 1 (the channel count) is >= 2.
    log.CHECK_GE(rgSize[1], 2, "The query size item 1 (the channel count) should be greater than or equal to 2.");
    log.CHECK_GE(rgSize[2], 1000, "The query size item 2 should be the number of samples in the query.");

    int nH = rgSize[1];
    int nW = rgSize[2];
    int nCount = nH * nW;

    Stopwatch sw = new Stopwatch();
    sw.Start();

    // The first query should return all data; the second should return null.
    SimpleDatum sd = db.Query(int.MaxValue);
    SimpleDatum sdEnd = db.Query(int.MaxValue);

    sw.Stop();
    double dfMs = sw.Elapsed.TotalMilliseconds;
    log.WriteLine("Total Time = " + dfMs.ToString() + " ms.");

    log.CHECK(sdEnd == null, "The last query should be null to show no more data exists.");
    log.CHECK_EQ(sd.ItemCount, nCount, "There should be the same number in the data as the size[1] * size[2] returned by QuerySize.");
    log.CHECK(sd.IsRealData, "The data should be real data, not byte.");

    db.Shutdown();
}
/// <summary>
/// Create a scalar Blob on the GPU from a single float value.
/// </summary>
/// <param name="cuda">Specifies the CudaDnn connection to use.</param>
/// <param name="log">Specifies the Log used for output.</param>
/// <param name="fInput">Specifies the scalar value to load.</param>
/// <returns>A 1x1x1 Blob containing the value is returned.</returns>
public static Blob<float> CuSca(CudaDnn<float> cuda, Log log, float fInput)
{
    // Wrap the single value in a 1x1x1 simple datum.
    float[] rgInput = { fInput };
    SimpleDatum myData = new SimpleDatum(1, 1, 1, rgInput, 0, 1);

    // Constructing the blob transfers the cpu data to the gpu.
    // NOTE: Alternatively, the data can be transferred to the gpu
    // by using a call to Reshape (which allocates the GPU memory)
    // and then calling blob.mutable_cpu_data = rgInput;
    return new Blob<float>(cuda, log, myData, true, true, false);
}
/// <summary>
/// Run a forward pass on the data and sample an action from the resulting
/// probability distribution.
/// </summary>
/// <param name="sd">Specifies the data to run the model on.</param>
/// <param name="rgfAprob">Returns the action probabilities output by the model.</param>
/// <returns>The sampled action index is returned.</returns>
public int act(SimpleDatum sd, out float[] rgfAprob)
{
    List<Datum> rgData = new List<Datum>();
    rgData.Add(new Datum(sd));

    float fRandom = (float)m_random.NextDouble(); // Roll the dice.

    m_memData.AddDatumVector(rgData, 1, true, true);

    // Run the forward pass with the loss layer skipped.
    double dfLoss;
    m_bSkipLoss = true;
    BlobCollection<T> res = m_net.Forward(out dfLoss);
    m_bSkipLoss = false;

    // Find the first non-loss output blob - the action probabilities.
    rgfAprob = null;

    for (int i = 0; i < res.Count; i++)
    {
        if (res[i].type == Blob<T>.BLOB_TYPE.LOSS)
            continue;

        rgfAprob = Utility.ConvertVecF<T>(res[i].update_cpu_data());
        break;
    }

    if (rgfAprob == null)
        throw new Exception("Could not find a non-loss output! Your model should output the loss and the action probabilities.");

    // Select the action from the probability distribution.
    float fSum = 0;

    for (int i = 0; i < rgfAprob.Length; i++)
    {
        fSum += rgfAprob[i];

        if (fRandom < fSum)
            return i;
    }

    return (rgfAprob.Length == 1) ? 1 : rgfAprob.Length - 1;
}
/// <summary>
/// Returns the action from running the model. The action returned is either randomly selected (when using Exploration),
/// or calculated via a forward pass (when using Exploitation).
/// </summary>
/// <param name="sd">Specifies the data to run the model on.</param>
/// <param name="sdClip">Specifies the clip data (if any exits).</param>
/// <param name="nActionCount">Returns the number of actions in the action set.</param>
/// <returns>The action value is returned.</returns>
public int act(SimpleDatum sd, SimpleDatum sdClip, int nActionCount)
{
    setData(m_netOutput, sd, sdClip);

    // Run forward up to (but not including) the final layer.
    m_netOutput.ForwardFromTo(0, m_netOutput.layers.Count - 2);

    Blob<T> output = m_netOutput.blob_by_name("logits");
    if (output == null)
        throw new Exception("Missing expected 'logits' blob!");

    // Choose greedy action
    float[] rgLogits = Utility.ConvertVecF<T>(output.mutable_cpu_data);

    return argmax(rgLogits);
}
/// <summary>
/// The ApplyDistortEx method applies the distortion policy to the simple datum.
/// </summary>
/// <param name="sd">Specifies the SimpleDatum to distort.</param>
/// <param name="p">Specifies the distortion parameters that define the distortion policy.</param>
/// <returns>The distorted SimpleDatum is returned.</returns>
public SimpleDatum ApplyDistortEx(SimpleDatum sd, DistortionParameter p)
{
    // Randomly pick one of the two adjustment orderings (50/50).
    double dfProb = m_random.NextDouble();
    ImageTools.ADJUSTCONTRAST_ORDERING ordering = (dfProb > 0.5) ?
        ImageTools.ADJUSTCONTRAST_ORDERING.BRIGHTNESS_CONTRAST_GAMMA :
        ImageTools.ADJUSTCONTRAST_ORDERING.BRIGHTNESS_GAMMA_CONTRAST;

    randomBrightnessContrastSaturation(sd, p.brightness_prob, p.brightness_delta, p.contrast_prob, p.contrast_lower, p.contrast_upper, p.saturation_prob, p.saturation_lower, p.saturation_upper, ordering);

    return randomChannelOrder(sd, p.random_order_prob);
}
/// <summary>
/// Get the image with a specific image index.
/// </summary>
/// <param name="nIdx">Specifies the image index.</param>
/// <param name="bLoadDataCriteria">Specifies whether or not to load the data criteria along with the image.</param>
/// <param name="bLoadDebugData">Specifies whether or not to load the debug data with the image.</param>
/// <param name="loadMethod">Specifies the image loading method used.</param>
/// <returns>If found, the image is returned.</returns>
public SimpleDatum GetImage(int nIdx, bool bLoadDataCriteria, bool bLoadDebugData, IMAGEDB_LOAD_METHOD loadMethod)
{
    SimpleDatum sd = m_rgImages[nIdx];

    if (sd == null)
    {
        if (m_refreshManager != null)
        {
            // A refresh manager is active - walk backward to the nearest
            // non-null image (presumably a LoadLimit refresh scenario - the
            // exception text below suggests this; confirm with RefreshManager).
            while (nIdx > 0 && m_rgImages[nIdx] == null)
            {
                nIdx--;
            }

            sd = m_rgImages[nIdx];

            if (sd == null)
            {
                throw new Exception("No images should be null when using LoadLimit loading!");
            }
        }
        else
        {
            // If the background loader is not running and we are not in an
            // on-demand mode, start loading the image set (background mode
            // loads asynchronously).
            if (!m_evtRunning.WaitOne(0) && (loadMethod != IMAGEDB_LOAD_METHOD.LOAD_ON_DEMAND && loadMethod != IMAGEDB_LOAD_METHOD.LOAD_ON_DEMAND_NOCACHE))
            {
                Load((loadMethod == IMAGEDB_LOAD_METHOD.LOAD_ON_DEMAND_BACKGROUND) ? true : false);
            }

            // Load this single image directly from the backing store.
            sd = directLoadImage(nIdx);
            if (sd == null)
            {
                throw new Exception("The image is still null yet should have loaded!");
            }

            // Only the caching on-demand mode stores the image back
            // into the in-memory array (NOCACHE intentionally does not).
            if (loadMethod == IMAGEDB_LOAD_METHOD.LOAD_ON_DEMAND)
            {
                m_rgImages[nIdx] = sd;
            }
        }
    }

    // Double check that the conditional data has loaded (if needed).
    if (bLoadDataCriteria || bLoadDebugData)
    {
        m_factory.LoadRawData(sd, bLoadDataCriteria, bLoadDebugData);
    }

    return(sd);
}
/// <summary>
/// This event is called by the Solver to get training data for this next training Step.
/// Within this event, the data is loaded for the next training step.
/// </summary>
/// <param name="sender">Specifies the sender of the event (e.g. the solver)</param>
/// <param name="args">n/a</param>
private void onTrainingStart(object sender, EventArgs args)
{
    // Locate the input blobs of the training net that this handler fills in.
    Blob<float> blobData = m_mycaffe.GetInternalNet(Phase.TRAIN).FindBlob("data");
    Blob<float> blobLabel = m_mycaffe.GetInternalNet(Phase.TRAIN).FindBlob("label");
    Blob<float> blobClip1 = m_mycaffe.GetInternalNet(Phase.TRAIN).FindBlob("clip1");

    // Load a batch of data where:
    //  'data' contains a batch of 10D detected images in sequence by label.
    //  'label' contains a batch of 1D future signals in sequence.
    List<float> rgYb = new List<float>();
    List<float> rgFYb = new List<float>();

    for (int i = 0; i < m_model.Batch; i++)
    {
        for (int t = 0; t < m_model.TimeSteps; t++)
        {
            // Get images one number at a time, in order by label, but randomly selected.
            SimpleDatum sd = m_imgDb.QueryImage(m_ds.TrainingSource.ID, 0, null, IMGDB_IMAGE_SELECTION_METHOD.RANDOM, m_nLabelSeq);

            // Run the image through the input model and collect the output
            // of the named inner-product blob as the encoded input features.
            m_mycaffeInput.Run(sd);

            Net<float> inputNet = m_mycaffeInput.GetInternalNet(Phase.RUN);
            Blob<float> input_ip = inputNet.FindBlob(m_strInputOutputBlobName);
            float[] rgY1 = input_ip.mutable_cpu_data;

            rgYb.AddRange(rgY1);

            // Generate the future-signal targets for the current label sequence.
            Dictionary<string, float[]> data = Signal.GenerateSample(1, m_nLabelSeq / 10.0f, 1, m_model.InputLabel, m_model.TimeSteps);
            float[] rgFY1 = data["FY"];

            // Add future steps corresponding to m_nLabelSeq time step;
            rgFYb.AddRange(rgFY1);

            // Advance the label sequence, wrapping 0..9.
            m_nLabelSeq++;
            if (m_nLabelSeq > 9)
                m_nLabelSeq = 0;
        }
    }

    float[] rgY = SimpleDatum.Transpose(rgYb.ToArray(), blobData.channels, blobData.num, blobData.count(2)); // Transpose for Sequence Major ordering.
    blobData.mutable_cpu_data = rgY;

    float[] rgFY = SimpleDatum.Transpose(rgFYb.ToArray(), blobLabel.channels, blobLabel.num, blobLabel.count(2)); // Transpose for Sequence Major ordering.
    blobLabel.mutable_cpu_data = rgFY;

    // Enable the clip for all steps except the very first of each batch item.
    blobClip1.SetData(1);
    blobClip1.SetData(0, 0, m_model.Batch);
}
/// <summary>
/// Distort the SimpleDatum.
/// </summary>
/// <param name="d">Specifies the SimpleDatum to distort.</param>
/// <returns>The distorted SimpleDatum is returned.</returns>
public SimpleDatum DistortImage(SimpleDatum d)
{
    DistortionParameter dp = m_param.distortion_param;

    // No distortion configured - return the data unchanged.
    if (dp == null)
        return d;

    // All distortion probabilities are zero - nothing would be applied.
    bool bAnyDistortion = (dp.brightness_prob != 0 || dp.contrast_prob != 0 || dp.saturation_prob != 0);
    if (!bAnyDistortion)
        return d;

    return m_imgTransforms.ApplyDistort(d, dp);
}
/// <summary>
/// Preprocesses the state data.
/// </summary>
/// <param name="s">Specifies the state and data to use.</param>
/// <param name="bUseRawInput">Specifies whether or not to use the raw data <i>true</i>, or a difference of the current and previous data <i>false</i>.</param>
/// <returns>The preprocessed data is returned.</returns>
public SimpleDatum Preprocess(StateBase s, bool bUseRawInput)
{
    SimpleDatum sd = new SimpleDatum(s.Data, true);

    if (bUseRawInput)
        return sd;

    if (m_sdLast == null)
        sd.Zero();
    else
        sd.Sub(m_sdLast);

    // BUGFIX: store a copy of the data (as the other Preprocess overload does)
    // rather than aliasing the caller's s.Data, which the caller may mutate
    // after this call and thereby corrupt the next difference.
    m_sdLast = new SimpleDatum(s.Data, true);

    return sd;
}
/// <summary>
/// Generate samples from the annotated Datum using the list of BatchSamplers.
/// </summary>
/// <param name="anno_datum">Specifies the annotated datum to sample from.</param>
/// <param name="rgBatchSamplers">Specifies the batch samplers defining the constraints.</param>
/// <returns>All samples bboxes that satisfy the constraints defined in the BatchSampler are returned.</returns>
public List<NormalizedBBox> GenerateBatchSamples(SimpleDatum anno_datum, List<BatchSampler> rgBatchSamplers)
{
    List<NormalizedBBox> rgSampledBBoxes = new List<NormalizedBBox>();
    List<NormalizedBBox> rgObjectBBoxes = GroupObjectBBoxes(anno_datum);

    foreach (BatchSampler sampler in rgBatchSamplers)
    {
        // Only samplers configured to use the original image contribute samples.
        if (!sampler.use_original_image)
            continue;

        NormalizedBBox unitBbox = new NormalizedBBox(0, 0, 1, 1);
        rgSampledBBoxes.AddRange(GenerateSamples(unitBbox, rgObjectBBoxes, sampler));
    }

    return rgSampledBBoxes;
}
/// <summary>
/// Initialize the DatasetEx by loading the training and testing data sources into memory.
/// </summary>
/// <param name="ds">Specifies the dataset to load.</param>
/// <param name="rgAbort">Specifies a set of wait handles used to cancel the load.</param>
/// <param name="nPadW">Optionally, specifies a pad to apply to the width of each item (default = 0).</param>
/// <param name="nPadH">Optionally, specifies a pad to apply to the height of each item (default = 0).</param>
/// <param name="log">Optionally, specifies an external Log to output status (default = null).</param>
/// <param name="loadMethod">Optionally, specifies the load method to use (default = LOAD_ALL).</param>
/// <param name="nImageDbLoadLimit">Optionally, specifies the load limit (default = 0).</param>
/// <param name="bSkipMeanCheck">Optionally, specifies to skip the mean check (default = false).</param>
/// <returns>Upon loading the dataset <i>true</i> is returned, otherwise on failure or abort <i>false</i> is returned.</returns>
public bool Initialize(DatasetDescriptor ds, WaitHandle[] rgAbort, int nPadW = 0, int nPadH = 0, Log log = null, IMAGEDB_LOAD_METHOD loadMethod = IMAGEDB_LOAD_METHOD.LOAD_ALL, int nImageDbLoadLimit = 0, bool bSkipMeanCheck = false)
{
    lock (m_syncObj)
    {
        if (loadMethod != IMAGEDB_LOAD_METHOD.LOAD_ALL && nImageDbLoadLimit > 0)
            throw new Exception("Currently the load-limit only works with the LOAD_ALL image loading method."); // BUGFIX: message typo 'LOAD_ALLL'.

        SimpleDatum imgMean = null;

        if (ds != null)
            m_ds = ds;

        // Variable sized images have no fixed mean - skip the mean check.
        if (m_ds.TrainingSource.ImageWidth == -1 || m_ds.TrainingSource.ImageHeight == -1)
        {
            // BUGFIX: guard 'log' which defaults to null.
            if (log != null)
                log.WriteLine("WARNING: Cannot create a mean image for data sources that contain variable sized images. The mean check will be skipped.");
            bSkipMeanCheck = true;
        }

        m_TrainingImages = loadImageset("Training", m_ds.TrainingSource, rgAbort, ref imgMean, out m_nLastTrainingImageIdx, nPadW, nPadH, log, loadMethod, nImageDbLoadLimit, m_nLastTrainingImageIdx, ds == null, bSkipMeanCheck);

        if (m_nLastTrainingImageIdx >= m_ds.TrainingSource.ImageCount)
            m_nLastTrainingImageIdx = 0;

        // Abort between the two (potentially long) loads when signaled.
        if (EventWaitHandle.WaitAny(rgAbort, 0) != EventWaitHandle.WaitTimeout)
            return false;

        m_TestingImages = loadImageset("Testing", m_ds.TestingSource, rgAbort, ref imgMean, out m_nLastTestingImageIdx, nPadW, nPadH, log, loadMethod, nImageDbLoadLimit, m_nLastTestingImageIdx, ds == null, bSkipMeanCheck);

        if (m_nLastTestingImageIdx >= m_ds.TestingSource.ImageCount)
            m_nLastTestingImageIdx = 0;

        if (EventWaitHandle.WaitAny(rgAbort, 0) != EventWaitHandle.WaitTimeout)
            return false;

        return true;
    }
}
/// <summary>
/// Verify that the three image-mean query paths (by source, by dataset, and
/// the direct getter) all return byte-identical mean data for each dataset.
/// </summary>
public void TestMean()
{
    List<string> rgDs = new List<string>() { "MNIST", "CIFAR-10", "MNIST" };
    IXImageDatabase db = new MyCaffeImageDatabase();

    foreach (string strDs in rgDs)
    {
        SettingsCaffe settings = new SettingsCaffe();

        Stopwatch sw = Stopwatch.StartNew();
        db.InitializeWithDsName(settings, strDs);
        string str = sw.ElapsedMilliseconds.ToString();
        Trace.WriteLine(strDs + " Initialization Time: " + str + " ms.");

        DatasetDescriptor ds = db.GetDatasetByName(strDs);

        // Query the mean via all three available paths.
        SimpleDatum d1 = db.QueryImageMean(ds.TrainingSource.ID);
        SimpleDatum d2 = db.QueryImageMeanFromDataset(ds.ID);
        SimpleDatum d3 = db.GetImageMean(ds.TrainingSource.ID);

        byte[] rgB1 = d1.ByteData;
        byte[] rgB2 = d2.ByteData;
        byte[] rgB3 = d3.ByteData;

        Assert.AreEqual(rgB1.Length, rgB2.Length);
        Assert.AreEqual(rgB2.Length, rgB3.Length);

        // All three means must match byte-for-byte.
        for (int i = 0; i < rgB1.Length; i++)
        {
            Assert.AreEqual(rgB1[i], rgB2[i]);
            Assert.AreEqual(rgB2[i], rgB3[i]);
        }
    }

    db.CleanUp();

    IDisposable idisp = db as IDisposable;
    if (idisp != null)
        idisp.Dispose();
}
/// <summary>
/// Returns the image mean for the ImageSet.
/// </summary>
/// <param name="log">Specifies the Log used to output status.</param>
/// <param name="rgAbort">Specifies a set of wait handles for aborting the operation.</param>
/// <returns>The SimpleDatum with the image mean is returned.</returns>
public SimpleDatum GetImageMean(Log log, WaitHandle[] rgAbort)
{
    // Return the cached mean when available.
    if (m_imgMean != null)
        return m_imgMean;

    if (m_rgImages.Length == 0)
    {
        if (log != null)
            log.WriteLine("WARNING: Cannot create image mean with no images!");

        return null;
    }

    // The mean can only be calculated over the full, fully-loaded image set.
    if (m_loadMethod != IMAGEDB_LOAD_METHOD.LOAD_ALL)
        throw new Exception("Can only create image mean when using LOAD_ALL.");

    if (m_nLoadLimit != 0)
        throw new Exception("Can only create image mean when LoadLimit = 0.");

    // Allow the owner to calculate (or cancel) the mean calculation.
    if (OnCalculateImageMean != null)
    {
        CalculateImageMeanArgs args = new CalculateImageMeanArgs(m_rgImages);
        OnCalculateImageMean(this, args);

        if (args.Cancelled)
            return null;

        m_imgMean = args.ImageMean;
        return m_imgMean;
    }

    m_imgMean = SimpleDatum.CalculateMean(log, m_rgImages, rgAbort);
    m_imgMean.SetLabel(0);

    return m_imgMean;
}
/// <summary>
/// Verify that cached raw images round-trip through the factory: each saved
/// datum carries an increasing prefix of byte values with zeros beyond it.
/// </summary>
/// <param name="bSaveImagesToFile">Specifies whether the source saves images to file.</param>
public void TestPutRawImage(bool bSaveImagesToFile)
{
    DatasetFactory factory = new DatasetFactory();

    factory.DeleteSources("Test123");
    int nSrcId = factory.AddSource("Test123", 1, 10, 10, false, 0, bSaveImagesToFile);
    factory.Open(nSrcId, 10);

    byte[] rgBytes = new byte[10 * 10];

    // Each iteration extends the non-zero prefix by one byte before caching.
    for (int i = 0; i < 20; i++)
    {
        rgBytes[i] = (byte)i;

        SimpleDatum sd = new SimpleDatum(false, 1, 10, 10, i, DateTime.MinValue, rgBytes.ToList(), null, 0, false, i);
        factory.PutRawImageCache(i, sd);
    }

    factory.ClearImageCash(true);

    List<RawImage> rgImg = factory.GetRawImagesAt(0, 20);

    for (int i = 0; i < rgImg.Count; i++)
    {
        SimpleDatum sd = factory.LoadDatum(rgImg[i]);

        bool bEncoded;
        byte[] rgData = sd.GetByteData(out bEncoded);

        // Bytes up to index i hold their index value; the rest are zero.
        for (int j = 0; j < 100; j++)
        {
            int nExpected = (j <= i) ? j : 0;
            Assert.AreEqual(rgData[j], nExpected);
        }
    }

    factory.DeleteSources("Test123");
    factory.Close();
}
/// <summary>
/// Run a forward pass on the data and sample a binary action from the single
/// output probability.
/// </summary>
/// <param name="sd">Specifies the data to run the model on.</param>
/// <param name="fAprob">Returns the action probability output by the model.</param>
/// <returns>Action 0 is returned with probability fAprob, otherwise action 1.</returns>
public int act(SimpleDatum sd, out float fAprob)
{
    List<Datum> rgData = new List<Datum>();
    rgData.Add(new Datum(sd));

    m_memData.AddDatumVector(rgData, 1, true, true);

    // Run the forward pass with the loss layer skipped.
    double dfLoss;
    m_bSkipLoss = true;
    BlobCollection<T> res = m_net.Forward(out dfLoss);
    m_bSkipLoss = false;

    // Find the first non-loss output blob - the action probability.
    float[] rgfAprob = null;

    for (int i = 0; i < res.Count; i++)
    {
        if (res[i].type == Blob<T>.BLOB_TYPE.LOSS)
            continue;

        rgfAprob = Utility.ConvertVecF<T>(res[i].update_cpu_data());
        break;
    }

    if (rgfAprob == null)
        throw new Exception("Could not find a non-loss output! Your model should output the loss and the action probabilities.");

    if (rgfAprob.Length != 1)
        throw new Exception("The simple policy gradient only supports a single data output!");

    fAprob = rgfAprob[0];

    // Roll the dice!
    return (m_random.NextDouble() < (double)fAprob) ? 0 : 1;
}