コード例 #1
0
ファイル: MyCaffeTrainerDual.cs プロジェクト: lulzzz/MyCaffe
        /// <summary>
        /// Configures the custom trainer from a property string and records the parent callback.
        /// </summary>
        /// <param name="strProperties">Specifies the key-value pair of properties each separated by ';'.  For example the expected
        /// format is 'key1'='value1';'key2'='value2';...</param>
        /// <param name="icallback">Specifies the parent callback.</param>
        /// <remarks>
        /// Recognized properties:
        /// 'Threads' - thread count (default = 1).
        /// 'RewardType' - 'VAL'/'VALUE' or 'AVE'/'AVERAGE', case-insensitive (default = 'VAL'); any other value
        /// leaves the current reward type unchanged.
        /// 'TrainerType' - 'PG.SIMPLE', 'PG.ST', 'PG'/'PG.MT' (all run in Stage.RL) or 'RNN.SIMPLE' (runs in Stage.RNN).
        /// </remarks>
        /// <exception cref="Exception">Thrown when 'TrainerType' is not one of the recognized values.</exception>
        public void Initialize(string strProperties, IXMyCaffeCustomTrainerCallback icallback)
        {
            m_icallback  = icallback;
            m_properties = new PropertySet(strProperties);
            m_nThreads   = m_properties.GetPropertyAsInt("Threads", 1);

            // RewardType is optional; a null result falls back to "VAL".
            string strRewardRaw = m_properties.GetProperty("RewardType", false);
            string strReward    = (strRewardRaw == null) ? "VAL" : strRewardRaw.ToUpper();

            switch (strReward)
            {
                case "VAL":
                case "VALUE":
                    m_rewardType = REWARD_TYPE.VALUE;
                    break;

                case "AVE":
                case "AVERAGE":
                    m_rewardType = REWARD_TYPE.AVERAGE;
                    break;
            }

            string strTrainerType = m_properties.GetProperty("TrainerType");

            switch (strTrainerType)
            {
                case "PG.SIMPLE":   // bare bones model (Sigmoid only)
                    m_trainerType = TRAINER_TYPE.PG_SIMPLE;
                    break;

                case "PG.ST":       // single thread (Sigmoid and Softmax)
                    m_trainerType = TRAINER_TYPE.PG_ST;
                    break;

                case "PG":
                case "PG.MT":       // multi-thread (Sigmoid and Softmax)
                    m_trainerType = TRAINER_TYPE.PG_MT;
                    break;

                case "RNN.SIMPLE":
                    m_trainerType = TRAINER_TYPE.RNN_SIMPLE;
                    break;

                default:
                    throw new Exception("Unknown trainer type '" + strTrainerType + "'!");
            }

            // Only the RNN trainer runs in the RNN stage; all PG trainers run in the RL stage.
            m_stage = (m_trainerType == TRAINER_TYPE.RNN_SIMPLE) ? Stage.RNN : Stage.RL;
        }
コード例 #2
0
 /// <summary>
 /// Creates the custom query, optionally configuring it from a packed parameter string.
 /// </summary>
 /// <param name="strParam">Optionally, specifies the packed parameters holding the 'Connection', 'Table', 'Field'
 /// and 'EndIdx' key=value pairs.  When null, the fields keep whatever values they were initialized with.</param>
 public CustomQuery2(string strParam = null)
 {
     if (strParam == null)
     {
         return;
     }

     PropertySet ps = new PropertySet(ParamPacker.UnPack(strParam));

     m_strConnection = ps.GetProperty("Connection");
     m_strTable      = ps.GetProperty("Table");
     m_strField      = ps.GetProperty("Field");
     // EndIdx is optional and defaults to int.MaxValue when absent.
     m_nEndIdx       = ps.GetPropertyAsInt("EndIdx", int.MaxValue);
 }
コード例 #3
0
        /// <summary>
        /// The PreprocessInput allows derivative data layers to convert a property set of input
        /// data into the bottom blob collection used as input.
        /// </summary>
        /// <param name="customInput">Specifies the custom input data; it must contain an 'InputData' property.</param>
        /// <param name="colBottom">Optionally, specifies the bottom data to fill.  When null, a new collection is
        /// created holding one blob per input described by the layer parameter's 'PrepareRunModelInputs'.</param>
        /// <returns>The bottom data is returned.</returns>
        /// <remarks>The blobs returned should match the blob descriptions returned in the LayerParameter's
        /// overrides for 'PrepareRunModelInputs' and 'PrepareRunModel'.
        /// The input is validated before any blobs are created so that a bad request does not allocate
        /// blobs that are then abandoned.</remarks>
        /// <exception cref="ArgumentNullException">Thrown when customInput is null.</exception>
        /// <exception cref="Exception">Thrown when the 'InputData' property cannot be found.</exception>
        public override BlobCollection<T> PreProcessInput(PropertySet customInput, BlobCollection<T> colBottom = null)
        {
            if (customInput == null)
            {
                throw new ArgumentNullException("customInput");
            }

            // Validate first: allocating blobs and then throwing would leave them unreleased.
            string strEncInput = customInput.GetProperty("InputData");

            if (strEncInput == null)
            {
                throw new Exception("Could not find the expected input property 'InputData'!");
            }

            if (colBottom == null)
            {
                // Build one blob per declared run-model input, shaped as described by the layer parameter.
                string   strInput = m_param.PrepareRunModelInputs();
                RawProto proto    = RawProto.Parse(strInput);
                Dictionary<string, BlobShape> rgInput = NetParameter.InputFromProto(proto);
                colBottom = new BlobCollection<T>();

                foreach (KeyValuePair<string, BlobShape> kv in rgInput)
                {
                    Blob<T> blob = new Blob<T>(m_cuda, m_log);
                    blob.Name = kv.Key;
                    blob.Reshape(kv.Value);
                    colBottom.Add(blob);
                }
            }

            PreProcessInput(strEncInput, null, colBottom);

            return colBottom;
        }
コード例 #4
0
        /// <summary>
        /// Handles a department selection in cbDepartment.  Only selections flagged by <c>used_mouse</c> are
        /// processed: the new OpType is pushed to the cut-list data, the DEPARTMENT property is updated when
        /// testing, and the operations are reset when the index differs from <c>starting_index</c>.
        /// </summary>
        private void OnSelected(object sender, EventArgs e)
        {
            if (!used_mouse)
            {
                return;
            }

            // Combo box indices are zero-based while OpType values start at 1.
            OpType = this.cbDepartment.SelectedIndex + 1;
            PropertySet.cutlistData.OpType = OpType;

            if (Properties.Settings.Default.Testing)
            {
                // Look the property up once rather than once per assignment.
                SwProperty department = PropertySet.GetProperty("DEPARTMENT");
                // NOTE(review): ComboBox.SelectedText is the highlighted text in the editable portion of the
                // box, not the selected item's text; cbDepartment.Text may be what was intended — confirm.
                department.Value    = cbDepartment.SelectedText;
                department.ResValue = cbDepartment.SelectedText;
            }

            int idx = this.OpType - 1; // Don't sort the table, and this works well.
            cbDepartment.SelectedIndex = idx;
            cbDepartment.DisplayMember = "TYPEDESC";

            if (idx != starting_index)
            {
                PropertySet.ResetOps();
            }

            used_mouse = false;
        }
コード例 #5
0
        /// <summary>
        /// The constructor.
        /// </summary>
        /// <param name="nQueryCount">Specifies the size of each query.</param>
        /// <param name="dtStart">Specifies the start date used for data collection.</param>
        /// <param name="nTimeSpanInMs">Specifies the time increment used between each data item.</param>
        /// <param name="nSegmentSize">Specifies the amount of data to query on the back-end from each custom query.</param>
        /// <param name="nMaxCount">Specifies the maximum number of items to allow in memory.</param>
        /// <param name="strSchema">Specifies the database schema.</param>
        /// <param name="rgCustomQueries">Optionally, specifies any custom queries to add directly (may be null).</param>
        /// <remarks>
        /// The database schema defines the number of custom queries to use along with their names.  A simple key=value; list
        /// defines the streaming database schema using the following format:
        /// \code{.cpp}
        ///  "ConnectionCount=2;
        ///   Connection0_CustomQueryName=Test1;
        ///   Connection0_CustomQueryParam=param_string1;
        ///   Connection1_CustomQueryName=Test2;
        ///   Connection1_CustomQueryParam=param_string2"
        /// \endcode
        /// Each param_string specifies the parameters of the custom query and may include the database connection string, database
        /// table, and database fields to query.
        ///
        /// Every named custom query must have the query type CUSTOM_QUERY_TYPE.TIME.  Once all data queries are created they are
        /// started, the consolidation task is launched, and the constructor calls m_colData.WaitData(10000) before returning.
        /// </remarks>
        /// <exception cref="Exception">Thrown when a named custom query cannot be found or does not support CUSTOM_QUERY_TYPE.TIME.</exception>
        public MgrQueryTime(int nQueryCount, DateTime dtStart, int nTimeSpanInMs, int nSegmentSize, int nMaxCount, string strSchema, List<IXCustomQuery> rgCustomQueries)
        {
            m_colCustomQuery.Load();
            m_colData      = new DataItemCollection(nQueryCount);
            m_nQueryCount  = nQueryCount;
            m_nSegmentSize = nSegmentSize;

            m_schema = new PropertySet(strSchema);

            // The extra queries are optional, so tolerate a null list.
            if (rgCustomQueries != null)
            {
                foreach (IXCustomQuery icustomquery in rgCustomQueries)
                {
                    m_colCustomQuery.Add(icustomquery);
                }
            }

            int nConnections = m_schema.GetPropertyAsInt("ConnectionCount");

            for (int i = 0; i < nConnections; i++)
            {
                string strConTag           = "Connection" + i.ToString();
                string strCustomQuery      = m_schema.GetProperty(strConTag + "_CustomQueryName");
                string strCustomQueryParam = m_schema.GetProperty(strConTag + "_CustomQueryParam");

                IXCustomQuery iqry = m_colCustomQuery.Find(strCustomQuery);
                if (iqry == null)
                {
                    throw new Exception("Could not find the custom query '" + strCustomQuery + "'!");
                }

                if (iqry.QueryType != CUSTOM_QUERY_TYPE.TIME)
                {
                    throw new Exception("The custom query '" + iqry.Name + "' does not support the 'CUSTOM_QUERY_TYPE.TIME'!");
                }

                DataQuery dq = new DataQuery(iqry.Clone(strCustomQueryParam), dtStart, TimeSpan.FromMilliseconds(nTimeSpanInMs), nSegmentSize, nMaxCount);
                m_colDataQuery.Add(dq);

                m_nFieldCount += (dq.FieldCount - 1);  // subtract each sync field.
            }

            m_nFieldCount += 1; // add the sync field
            m_colDataQuery.Start();

            m_evtCancel.Reset();
            m_taskConsolidate = Task.Factory.StartNew(new Action(consolidateThread));
            m_evtEnabled.Set();
            m_colData.WaitData(10000);
        }
コード例 #6
0
 /// <summary>
 /// The constructor.
 /// </summary>
 /// <param name="strParam">Optionally, specifies the packed parameters, which should contain the 'FilePath'=path key=value pair.
 /// When null, the path keeps whatever value it was initialized with.</param>
 public StandardQueryWAVFile(string strParam = null)
 {
     if (strParam == null)
     {
         return;
     }

     PropertySet ps = new PropertySet(ParamPacker.UnPack(strParam));
     m_strPath = ps.GetProperty("FilePath");
 }
コード例 #7
0
        /// <summary>
        /// The constructor.
        /// </summary>
        /// <param name="strSchema">Specifies the database schema.</param>
        /// <param name="rgCustomQueries">Optionally, specifies any custom queries to directly add (may be null).</param>
        /// <remarks>
        /// The database schema defines the number of custom queries to use along with their names.  A simple key=value; list
        /// defines the streaming database schema using the following format:
        /// \code{.cpp}
        ///  "ConnectionCount=1;
        ///   Connection0_CustomQueryName=Test1;
        ///   Connection0_CustomQueryParam=param_string1"
        /// \endcode
        /// Each param_string specifies the parameters of the custom query and may include the database connection string, database
        /// table, and database fields to query.
        ///
        /// StandardQueryTextFile and StandardQueryWAVFile are always registered.  Exactly one connection is supported, and its
        /// query must be of type BYTE, REAL_FLOAT or REAL_DOUBLE.  The selected query is cloned with the unpacked parameters and opened.
        /// </remarks>
        /// <exception cref="Exception">Thrown when the connection count is not 1, the query cannot be found, or the query type is unsupported.</exception>
        public MgrQueryGeneral(string strSchema, List<IXCustomQuery> rgCustomQueries)
        {
            m_colCustomQuery.Load();
            m_schema = new PropertySet(strSchema);

            m_colCustomQuery.Add(new StandardQueryTextFile());
            m_colCustomQuery.Add(new StandardQueryWAVFile());

            // The extra queries are optional, so tolerate a null list.
            if (rgCustomQueries != null)
            {
                foreach (IXCustomQuery icustomquery in rgCustomQueries)
                {
                    m_colCustomQuery.Add(icustomquery);
                }
            }

            int nConnections = m_schema.GetPropertyAsInt("ConnectionCount");

            if (nConnections != 1)
            {
                throw new Exception("Currently, the general query type only supports 1 connection.");
            }

            string strConTag           = "Connection0";
            string strCustomQuery      = m_schema.GetProperty(strConTag + "_CustomQueryName");
            string strCustomQueryParam = m_schema.GetProperty(strConTag + "_CustomQueryParam");

            IXCustomQuery iqry = m_colCustomQuery.Find(strCustomQuery);

            if (iqry == null)
            {
                throw new Exception("Could not find the custom query '" + strCustomQuery + "'!");
            }

            if (iqry.QueryType != CUSTOM_QUERY_TYPE.BYTE && iqry.QueryType != CUSTOM_QUERY_TYPE.REAL_FLOAT && iqry.QueryType != CUSTOM_QUERY_TYPE.REAL_DOUBLE)
            {
                throw new Exception("The custom query '" + iqry.Name + "' must support the 'CUSTOM_QUERY_TYPE.BYTE' or 'CUSTOM_QUERY_TYPE.REAL_FLOAT' or 'CUSTOM_QUERY_TYPE.REAL_DOUBLE' query type!");
            }

            // NOTE(review): the parameters are unpacked here, and the standard query constructors also call
            // ParamPacker.UnPack on what they receive; if Clone forwards to those constructors the string is
            // unpacked twice — confirm this is intended.
            string strParam = ParamPacker.UnPack(strCustomQueryParam);

            m_iquery = iqry.Clone(strParam);
            m_iquery.Open();
        }
コード例 #8
0
        /// <summary>
        /// Populates the controls: reads the property and revision sets, fills the boxes, and restores the last
        /// customer selection when RememberLastCustomer is set and the CUSTOMER property is empty.
        /// </summary>
        private void GetData()
        {
            PropertySet.Read();
            RevSet.Read();
            FillBoxes();

            bool bRestoreCustomer = Properties.Settings.Default.RememberLastCustomer
                                    && (PropertySet.GetProperty("CUSTOMER").Value == string.Empty);

            if (!bRestoreCustomer)
            {
                return;
            }

            cbCustomer.SelectedIndex = Properties.Settings.Default.LastCustomerSelection;
        }
コード例 #9
0
        /// <summary>
        /// Initialize the gym with the specified properties.
        /// </summary>
        /// <param name="log">Specifies the output log to use.</param>
        /// <param name="properties">Specifies the properties containing Gym specific initialization parameters.</param>
        /// <remarks>
        /// The ModelGym reads the following initialization properties.
        ///
        /// 'GpuID' - the GPU to run on.
        /// 'ModelDescription' - the model description of the model to use.
        /// 'Dataset' - the name of the dataset to use.
        /// 'Weights' - the model trained weights (read as a property blob).
        /// 'CudaPath' - the path of the CudaDnnDLL to use.
        /// 'BatchSize' - the batch size used when running images through the model (default = 16).
        /// 'RecreateData' - when 'True' the data is re-run through the model, otherwise if already run the data is loaded from file (faster) (default = false).
        /// 'ProjectName' - the project name; 'default' is used when it is missing or empty.
        ///
        /// The image database is opened with LOAD_ON_DEMAND_BACKGROUND, and the first training image (index 0)
        /// determines the 1 x channels x height x width shape the model is loaded with.
        /// </remarks>
        public void Initialize(Log log, PropertySet properties)
        {
            m_nGpuID        = properties.GetPropertyAsInt("GpuID");
            m_strModelDesc  = properties.GetProperty("ModelDescription");
            m_strDataset    = properties.GetProperty("Dataset");
            m_rgWeights     = properties.GetPropertyBlob("Weights");
            m_nBatchSize    = properties.GetPropertyAsInt("BatchSize", 16);
            m_bRecreateData = properties.GetPropertyAsBool("RecreateData", false);

            string strProjectName = properties.GetProperty("ProjectName");
            m_strProject = string.IsNullOrEmpty(strProjectName) ? "default" : strProjectName;

            string strCudaPath = properties.GetProperty("CudaPath");

            SettingsCaffe settings = new SettingsCaffe
            {
                GpuIds            = m_nGpuID.ToString(),
                ImageDbLoadMethod = IMAGEDB_LOAD_METHOD.LOAD_ON_DEMAND_BACKGROUND
            };

            m_imgdb = new MyCaffeImageDatabase2(log);
            m_imgdb.InitializeWithDsName1(settings, m_strDataset);
            m_ds = m_imgdb.GetDatasetByName(m_strDataset);

            // The first training image defines the input shape used to load the model.
            SimpleDatum firstDatum = m_imgdb.QueryImage(m_ds.TrainingSource.ID, 0, IMGDB_LABEL_SELECTION_METHOD.NONE, IMGDB_IMAGE_SELECTION_METHOD.NONE);
            BlobShape   inputShape = new BlobShape(1, firstDatum.Channels, firstDatum.Height, firstDatum.Width);

            // Reuse an existing cancel event if one was already created.
            m_evtCancel = m_evtCancel ?? new CancelEvent();

            m_mycaffe = new MyCaffeControl<float>(settings, log, m_evtCancel, null, null, null, null, strCudaPath);
            m_mycaffe.LoadToRun(m_strModelDesc, m_rgWeights, inputShape);

            m_log = log;
        }
コード例 #10
0
        /// <summary>
        /// Links cbDepartment to the DEPT property (typed as a number) and, when testing, also to the
        /// DEPARTMENT property (typed as text).
        /// </summary>
        public void LinkControls()
        {
            SwProperty newprop = PropertySet.GetProperty("DEPT");

            newprop.Type = SolidWorks.Interop.swconst.swCustomInfoType_e.swCustomInfoNumber;
            PropertySet.LinkControlToProperty("DEPT", true, cbDepartment);

            if (Properties.Settings.Default.Testing)
            {
                // Must use the same key as the LinkControlToProperty call below ("DEPARTMENT");
                // the misspelled "DEPTARTMENT" targeted a different (or missing) property.
                SwProperty oldprop = PropertySet.GetProperty("DEPARTMENT");
                oldprop.Type = SolidWorks.Interop.swconst.swCustomInfoType_e.swCustomInfoText;
                PropertySet.LinkControlToProperty("DEPARTMENT", true, cbDepartment);
            }
        }
コード例 #11
0
 /// <summary>
 /// Run the trained model on one line of input and log both the input and the model's reply.
 /// </summary>
 /// <param name="mycaffe">Specifies the mycaffe instance running the sequence run model.  When null, the
 /// m_mycaffe field is used instead.</param>
 /// <param name="bw">Specifies the background worker (not used here).</param>
 /// <param name="strInput">Specifies the input data to run the model on.</param>
 /// <remarks>Any exception raised while running is written to the log instead of being propagated.</remarks>
 private void runModel(MyCaffeControl<float> mycaffe, BackgroundWorker bw, string strInput)
 {
     // Use the instance the caller passed in; fall back to the field for callers that pass null.
     MyCaffeControl<float> caffe = mycaffe ?? m_mycaffe;

     try
     {
         m_log.WriteLine("You: " + strInput);
         // nK is passed to Run: 3 when beam search is enabled, otherwise 1.
         int nK = (m_input.UseBeamSearch) ? 3 : 1;
         // NOTE(review): strInput is embedded verbatim in a ';'-separated key=value property string;
         // input containing ';' or '=' may be mis-parsed — confirm how PropertySet splits its input.
         PropertySet input   = new PropertySet("InputData=" + strInput);
         PropertySet results = caffe.Run(input, nK);
         m_log.WriteLine("Robot: " + results.GetProperty("Results").TrimEnd(' ', '|'), true);
     }
     catch (Exception excpt)
     {
         m_log.WriteLine("Robot: " + excpt.Message);
     }
 }
コード例 #12
0
        /// <summary>
        /// Configures the custom trainer from a property string and records the parent callback.
        /// </summary>
        /// <param name="strProperties">Specifies the key-value pair of properties each separated by ';'.  For example the expected
        /// format is 'key1'='value1';'key2'='value2';...</param>
        /// <param name="icallback">Specifies the parent callback.</param>
        /// <remarks>Only the 'RNN.SIMPLE' trainer type (a bare bones model) is accepted.</remarks>
        /// <exception cref="Exception">Thrown when 'TrainerType' is anything other than 'RNN.SIMPLE'.</exception>
        public void Initialize(string strProperties, IXMyCaffeCustomTrainerCallback icallback)
        {
            m_icallback  = icallback;
            m_properties = new PropertySet(strProperties);

            string strType = m_properties.GetProperty("TrainerType");

            if (strType != "RNN.SIMPLE")
            {
                throw new Exception("Unknown trainer type '" + strType + "'!");
            }

            m_trainerType = TRAINER_TYPE.RNN_SIMPLE;
        }
コード例 #13
0
ファイル: RenderHelper.cs プロジェクト: secondii/Yutai
        /// <summary>
        /// Reads a feature renderer from a file containing an Int32 byte count followed by that many bytes of a
        /// persisted property set, and returns the set's "Render" property.
        /// </summary>
        /// <param name="string_0">Specifies the path of the file to read.</param>
        /// <returns>The "Render" property as an IFeatureRenderer, or null when it is not one.</returns>
        public static IFeatureRenderer ReadRender(string string_0)
        {
            byte[] numArray;

            // Dispose the reader (and with it the file stream) as soon as the payload is read, and open
            // read-only: the default FileStream access for FileMode.Open is ReadWrite, which fails on read-only files.
            using (BinaryReader binaryReader = new BinaryReader(new System.IO.FileStream(string_0, FileMode.Open, FileAccess.Read)))
            {
                int nLength = binaryReader.ReadInt32();
                numArray = binaryReader.ReadBytes(nLength);
            }

            IMemoryBlobStream memoryBlobStreamClass = new MemoryBlobStream();
            IObjectStream     objectStreamClass     = new ObjectStream()
            {
                Stream = memoryBlobStreamClass
            };

            ((IMemoryBlobStreamVariant)memoryBlobStreamClass).ImportFromVariant(numArray);
            IPropertySet propertySetClass = new PropertySet();

            (propertySetClass as IPersistStream).Load(objectStreamClass);
            return propertySetClass.GetProperty("Render") as IFeatureRenderer;
        }
コード例 #14
0
        /// <summary>
        /// Write properties to drawing document.
        /// </summary>
        /// <param name="md">The current ModelDoc object; expected to be a DrawingDoc.</param>
        /// <remarks>
        /// Before writing, the user is asked whether to make the drawing REV match the file name's revision, and is
        /// warned when the item number does not match the customer.  After writing the property and revision sets,
        /// the boxes are refilled, the drawing is force-rebuilt, and a new DirtTracker is attached with IsDirty cleared.
        /// </remarks>
        public void Write(ModelDoc2 md)
        {
            // Detach from the current tracker, if there is one; a fresh tracker is attached at the end.
            if (dirtTracker != null)
            {
                dirtTracker.Besmirched -= dirtTracker_Besmirched;
            }

            if (!check_rev())
            {
                if (System.Windows.Forms.MessageBox.Show(
                        string.Format(@"Make drawing REV ({0}) match filename ({1})?", cbRevision.Text, fileRev), @"REV mismatch.", MessageBoxButtons.YesNo)
                    == DialogResult.Yes)
                {
                    cbRevision.Text = fileRev;
                    PropertySet.GetProperty("REVISION LEVEL").Value = fileRev;
                }
            }

            if (!check_itemnumber())
            {
                string itnu = label4.Text.Trim();
                string cuss = cbCustomer.Text.Split('-')[0].Trim();
                System.Windows.Forms.MessageBox.Show(
                    string.Format("The item number '{0}' doesn't match the customer '{1}'.", itnu, cuss),
                    "Wrong customer?",
                    MessageBoxButtons.OK);
            }

            this.PropertySet.ReadControls();
            this.PropertySet.Write(md);
            this.RevSet.Write(md);
            FillBoxes();

            // Only drawings can be force-rebuilt; skip the rebuild instead of dereferencing a null cast result.
            DrawingDoc drawing = md as DrawingDoc;
            if (drawing != null)
            {
                drawing.ForceRebuild();
            }

            // Start tracking changes afresh and mark the form clean.
            dirtTracker = new DirtTracker(this);
            IsDirty = false;
            dirtTracker.Besmirched += dirtTracker_Besmirched;
        }
コード例 #15
0
        /// <summary>
        /// Binds cbDepartment to the department property (DEPARTMENT when testing, DEPTID otherwise) and sets
        /// OpType from its value.  When the property does not exist, a numeric property with value "1" is
        /// created, bound to the control, and OpType is set to 1.
        /// </summary>
        private void LinkControlToProperty()
        {
            string pn = Properties.Settings.Default.Testing ? "DEPARTMENT" : "DEPTID";

            if (PropertySet.Contains(pn))
            {
                SwProperty prop = PropertySet.GetProperty(pn);
                prop.Ctl = cbDepartment;
                string dept = prop.Value;

                // The stored value is either a numeric OpType ID or a department name.
                int tp;
                if (int.TryParse(dept, out tp))
                {
                    OpType = tp;
                }
                else
                {
                    OpType = PropertySet.cutlistData.GetOpTypeIDByName(dept);
                }
            }
            else
            {
                SolidWorks.Interop.swconst.swCustomInfoType_e t = SolidWorks.Interop.swconst.swCustomInfoType_e.swCustomInfoNumber;
                SwProperty p = new SwProperty(pn, t, "1", true);
                p.SwApp = SwApp;
                p.Ctl   = cbDepartment;
                PropertySet.Add(p);
                OpType = 1;
            }
        }
コード例 #16
0
    /// <summary>
    /// Builds a property set "pset" holding four properties plus a nested set "pset2" (stored in the
    /// "pts" property), then looks each property up by name and logs its name and value.
    /// </summary>
    void Start()
    {
        m_sceneController = new SceneController();
        BaseState bs = new BaseState();

        Property pt1 = new StringProperty("pt1");
        Property pt2 = new Vector2Property("vpt");
        Property pt3 = new Matrix4x4Property("matpt");
        Property pt4 = new QuaternionProperty("qpt");
        pt2.value = new Vector2(23, 11);
        pt1.value = "test";
        pt3.value = Matrix4x4.identity;
        pt4.value = Quaternion.identity;

        Property pt5 = new ColorProperty("cpt");
        Property pt6 = new IntProperty("ipt");
        Property pt7 = new Vector3Property("v3pt");
        pt5.value = Color.red;
        pt6.value = 612;
        pt7.value = new Vector3(12, 11, 33);

        PropertySet        outer     = new PropertySet("pset", null);
        PrpertySetProperty nestedRef = new PrpertySetProperty("pts");
        PropertySet        inner     = new PropertySet("pset2", null);

        inner.SetProperty(pt5);
        inner.SetProperty(pt6);
        inner.SetProperty(pt7);
        nestedRef.value = inner;

        outer.SetProperty(pt1);
        outer.SetProperty(pt2);
        outer.SetProperty(pt3);
        outer.SetProperty(pt4);
        outer.SetProperty(nestedRef);

        foreach (string strName in new string[] { "pt1", "vpt", "matpt", "qpt" })
        {
            Property found = outer.GetProperty(strName);
            Debug.Log(found.name + "  " + found.value);
        }

        PropertySet nested = outer.GetProperty("pts").value as PropertySet;

        foreach (string strName in new string[] { "cpt", "ipt", "v3pt" })
        {
            Property found = nested.GetProperty(strName);
            Debug.Log(found.name + "  " + found.value);
        }
    }
コード例 #17
0
        /// <summary>
        /// Initialize the gym with the specified properties.
        /// </summary>
        /// <param name="log">Specifies the output log to use.</param>
        /// <param name="properties">Specifies the properties containing Gym specific initialization parameters.</param>
        /// <remarks>
        /// The AtariGym uses the following initialization properties.
        ///   GameROM='path to .rom file' (required; '~' or '[sp]' in the path are passed to Utility.Replace to become spaces)
        ///   EnableColor (default false), AllowNegativeRewards (default false), TerminateOnRallyEnd (default false),
        ///   UseGrayscale (default false), Preprocess (default true), ActionForceGray (default false),
        ///   EnableNumSkip (default true), FrameSkip (default -1; a negative value yields the frame-skip list {2, 3, 4}),
        ///   EnableBinaryActions (default false; when true the FIRE action is left out of the action set).
        /// Calling Initialize again shuts down the previous ALE instance and rebuilds the action set.
        /// </remarks>
        /// <exception cref="Exception">Thrown when properties is null or the game ROM file does not exist.</exception>
        public void Initialize(Log log, PropertySet properties)
        {
            // Reject null up front: 'properties' is dereferenced for several settings below, so a
            // check placed after them could never be reached.
            if (properties == null)
            {
                throw new Exception("The properties must be specified with the 'GameROM' set the the Game ROM file path.");
            }

            m_log = log;

            // Release the emulator left over from a previous Initialize call, if any.
            if (m_ale != null)
            {
                m_ale.Shutdown();
                m_ale = null;
            }

            m_ale = new ALE();
            m_ale.Initialize();
            m_ale.EnableDisplayScreen       = false;
            m_ale.EnableSound               = false;
            m_ale.EnableColorData           = properties.GetPropertyAsBool("EnableColor", false);
            m_ale.EnableRestrictedActionSet = true;
            m_ale.EnableColorAveraging      = true;
            m_ale.AllowNegativeRewards      = properties.GetPropertyAsBool("AllowNegativeRewards", false);
            m_ale.EnableTerminateOnRallyEnd = properties.GetPropertyAsBool("TerminateOnRallyEnd", false);
            m_ale.RandomSeed                = (int)DateTime.Now.Ticks;
            m_ale.RepeatActionProbability   = 0.0f; // disable action repeatability

            string strROM = properties.GetProperty("GameROM");

            // The path may encode spaces as '~' or as '[sp]'; decode whichever form is used.
            if (strROM.Contains('~'))
            {
                strROM = Utility.Replace(strROM, '~', ' ');
            }
            else
            {
                strROM = Utility.Replace(strROM, "[sp]", ' ');
            }

            if (!File.Exists(strROM))
            {
                throw new Exception("Could not find the game ROM file specified '" + strROM + "'!");
            }

            if (properties.GetPropertyAsBool("UseGrayscale", false))
            {
                m_ct = COLORTYPE.CT_GRAYSCALE;
            }

            m_bPreprocess    = properties.GetPropertyAsBool("Preprocess", true);
            m_bForceGray     = properties.GetPropertyAsBool("ActionForceGray", false);
            m_bEnableNumSkip = properties.GetPropertyAsBool("EnableNumSkip", true);
            m_nFrameSkip     = properties.GetPropertyAsInt("FrameSkip", -1);

            m_ale.Load(strROM);
            m_rgActionsRaw = m_ale.ActionSpace;
            m_random       = new CryptoRandom();
            m_rgFrameSkip  = new List<int>();

            if (m_nFrameSkip < 0)
            {
                for (int i = 2; i < 5; i++)
                {
                    m_rgFrameSkip.Add(i);
                }
            }
            else
            {
                m_rgFrameSkip.Add(m_nFrameSkip);
            }

            // Rebuild the action map from scratch so a repeated Initialize does not add duplicate keys.
            m_rgActions.Clear();
            m_rgActions.Add(ACTION.ACT_PLAYER_A_LEFT.ToString(), (int)ACTION.ACT_PLAYER_A_LEFT);
            m_rgActions.Add(ACTION.ACT_PLAYER_A_RIGHT.ToString(), (int)ACTION.ACT_PLAYER_A_RIGHT);

            if (!properties.GetPropertyAsBool("EnableBinaryActions", false))
            {
                m_rgActions.Add(ACTION.ACT_PLAYER_A_FIRE.ToString(), (int)ACTION.ACT_PLAYER_A_FIRE);
            }

            m_rgActionSet = m_rgActions.ToList();

            Reset(false);
        }
コード例 #18
0
ファイル: MainForm.cs プロジェクト: batuZ/Samples
        /// <summary>
        /// Form load handler.  It runs these steps in order:
        /// 1. Initializes the render control with OpenGL.
        /// 2. Sets up the sky box from the media folder, and stops with a message if that folder is missing.
        /// 3. Loads the ground model and its images, then flies the camera to the model.
        /// 4. Prepares the geometry editor and the help provider.
        /// </summary>
        private void MainForm_Load(object sender, System.EventArgs e)
        {
            if (_geoFactory == null)
            {
                _geoFactory = new GeometryFactory();
            }
            if (_gc == null)
            {
                _gc = new GeometryConvertor();
            }

            // Initialize the RenderControl.
            IPropertySet ps = new PropertySet();

            ps.SetProperty("RenderSystem", gviRenderSystem.gviRenderOpenGL);
            this.axRenderControl1.Initialize(true, ps);
            this.axRenderControl1.Camera.FlyTime = 1;

            rootId = this.axRenderControl1.ObjectManager.GetProjectTree().RootID;

            // Set up the sky box; without the SDK media folder nothing else can be loaded.
            if (!System.IO.Directory.Exists(strMediaPath))
            {
                // Message: "Please do not rename the SDK directory."
                MessageBox.Show("请不要随意更改SDK目录名");
                return;
            }

            string  tmpSkyboxPath = strMediaPath + @"\skybox";
            ISkyBox skybox        = this.axRenderControl1.ObjectManager.GetSkyBox(0);
            skybox.SetImagePath(gviSkyboxImageIndex.gviSkyboxImageBack, tmpSkyboxPath + "\\1_BK.jpg");
            skybox.SetImagePath(gviSkyboxImageIndex.gviSkyboxImageBottom, tmpSkyboxPath + "\\1_DN.jpg");
            skybox.SetImagePath(gviSkyboxImageIndex.gviSkyboxImageFront, tmpSkyboxPath + "\\1_FR.jpg");
            skybox.SetImagePath(gviSkyboxImageIndex.gviSkyboxImageLeft, tmpSkyboxPath + "\\1_LF.jpg");
            skybox.SetImagePath(gviSkyboxImageIndex.gviSkyboxImageRight, tmpSkyboxPath + "\\1_RT.jpg");
            skybox.SetImagePath(gviSkyboxImageIndex.gviSkyboxImageTop, tmpSkyboxPath + "\\1_UP.jpg");

            #region Create the ground model
            try
            {
                string           osgPath = (strMediaPath + @"\mdb+osg\Ground\00650100agc4001.osg");
                IResourceFactory resFac  = new ResourceFactory();
                // 'imgs' is an out argument, so there is no need to create an instance up front.
                IPropertySet     imgs;
                resFac.CreateModelAndImageFromFile(osgPath, out imgs, out ModelSrc, out MatrixSrc);
                this.axRenderControl1.ObjectManager.AddModel("test", ModelSrc);

                if (imgs != null)
                {
                    string[] keys = imgs.GetAllKeys();
                    foreach (string imgName in keys)
                    {
                        IImage img = imgs.GetProperty(imgName) as IImage;
                        this.axRenderControl1.ObjectManager.AddImage(imgName, img);
                    }
                }

                ModelPointSrc = _geoFactory.CreateGeometry(gviGeometryType.gviGeometryModelPoint, gviVertexAttribute.gviVertexAttributeZ) as IModelPoint;
                ModelPointSrc.ModelEnvelope = ModelSrc.Envelope;
                ModelPointSrc.FromMatrix(MatrixSrc);
                ModelPointSrc.ModelName = "test";

                RenderModelPointSrc = this.axRenderControl1.ObjectManager.CreateRenderModelPoint(ModelPointSrc, null, rootId);
                this.axRenderControl1.Camera.FlyToObject(RenderModelPointSrc.Guid, gviActionCode.gviActionFlyTo);
            }
            catch (COMException ex)
            {
                System.Diagnostics.Trace.WriteLine(ex.Message);
                return;
            }
            #endregion

            _geoEditor    = this.axRenderControl1.ObjectEditor;
            _multiPolygon = _geoFactory.CreateGeometry(gviGeometryType.gviGeometryMultiPolygon, gviVertexAttribute.gviVertexAttributeZ) as IMultiPolygon;

            this.helpProvider1.SetShowHelp(this.axRenderControl1, true);
            this.helpProvider1.SetHelpString(this.axRenderControl1, "");
            this.helpProvider1.HelpNamespace = "GeometryConvert3.html";
        }
コード例 #19
0
        /// <summary>
        /// Step the gym one step in the data.
        /// </summary>
        /// <remarks>
        /// In the RUN phase a window of the already computed scores is returned.  In every other phase the scores
        /// are loaded with load() and, when <c>m_bRecreateData</c> is set or the loaded count does not match the
        /// training source image count, recreated by running the model over every training image and then saved
        /// with save().
        /// </remarks>
        /// <param name="nAction">Specifies the action to run on the gym.</param>
        /// <param name="bGetLabel">Not used.</param>
        /// <param name="extraProp">Optionally, specifies extra properties.  Required in the RUN phase, where
        /// 'DataCountRequested' and 'SeedTime' are read from it.</param>
        /// <returns>A tuple containing state data, the reward, and the done state is returned, or <c>null</c>
        /// when the score recreation is cancelled.</returns>
        /// <exception cref="Exception">Thrown in the RUN phase when <paramref name="extraProp"/> is <c>null</c>.</exception>
        public Tuple <State, double, bool> Step(int nAction, bool bGetLabel = false, PropertySet extraProp = null)
        {
            DataState       data   = new DataState();
            ScoreCollection scores = null;

            if (ActivePhase == Phase.RUN)
            {
                if (extraProp == null)
                {
                    throw new Exception("The extra properties are needed when querying data during the RUN phase.");
                }

                int    nDataCount   = extraProp.GetPropertyAsInt("DataCountRequested");
                string strStartTime = extraProp.GetProperty("SeedTime");

                // Default to the last nDataCount scores unless a parsable seed time selects the start.
                int      nStartIdx = m_scores.Count - nDataCount;
                DateTime dt;
                if (DateTime.TryParse(strStartTime, out dt))
                {
                    nStartIdx = m_scores.FindIndexAt(dt, nDataCount);
                }

                scores = m_scores.CopyFrom(nStartIdx, nDataCount);
            }
            else
            {
                m_scores = load(out m_nDim, out m_nWidth);

                if (m_bRecreateData || m_scores.Count != m_ds.TrainingSource.ImageCount)
                {
                    int       nImageCount = m_ds.TrainingSource.ImageCount;
                    Stopwatch sw          = new Stopwatch();
                    sw.Start();

                    m_scores = new ScoreCollection();

                    // Always rebuild from the first image - a previously cancelled rebuild returns early
                    // and would otherwise leave the index part-way through the data set.
                    m_nCurrentIdx = 0;

                    while (m_nCurrentIdx < nImageCount)
                    {
                        // Query the next batch of images sequentially by index, never past the last image.
                        List <SimpleDatum> rgSd = new List <SimpleDatum>();

                        for (int i = 0; i < m_nBatchSize && m_nCurrentIdx + i < nImageCount; i++)
                        {
                            SimpleDatum sd = m_imgdb.QueryImage(m_ds.TrainingSource.ID, m_nCurrentIdx + i, IMGDB_LABEL_SELECTION_METHOD.NONE, IMGDB_IMAGE_SELECTION_METHOD.NONE);
                            rgSd.Add(sd);
                        }

                        List <ResultCollection> rgRes = m_mycaffe.Run(rgSd, ref m_blobWork);

                        if (m_nWidth == 0)
                        {
                            m_nWidth = rgRes[0].ResultsOriginal.Count;
                            m_nDim   = rgRes[0].ResultsOriginal.Count * 2;
                        }

                        // Fill SimpleDatum with the ordered label,score pairs starting with the detected label.
                        for (int i = 0; i < rgRes.Count; i++)
                        {
                            m_scores.Add(new Score(rgSd[i].TimeStamp, rgSd[i].Index, rgRes[i]));
                            m_nCurrentIdx++;
                        }

                        // Report progress at most once a second (restart the timer after each report).
                        if (sw.Elapsed.TotalMilliseconds > 1000)
                        {
                            m_log.Progress = (double)m_nCurrentIdx / (double)m_ds.TrainingSource.ImageCount;
                            m_log.WriteLine("Running model on image " + m_nCurrentIdx.ToString() + " of " + m_ds.TrainingSource.ImageCount.ToString() + " of '" + m_strDataset + "' dataset.");
                            sw.Restart();
                        }

                        // Honor cancellation after every batch, not only once the first second has passed.
                        if (m_evtCancel.WaitOne(0))
                        {
                            return(null);
                        }
                    }

                    save(m_nDim, m_nWidth, m_scores);
                }
                else
                {
                    m_nCurrentIdx = m_scores.Count;
                }

                scores = m_scores;
            }

            float[]     rgfRes = scores.Data;
            SimpleDatum sdRes  = new SimpleDatum(scores.Count, m_nWidth, 2, rgfRes, 0, rgfRes.Length);

            data.SetData(sdRes);
            m_nCurrentIdx = 0;

            return(new Tuple <State, double, bool>(data, 0, false));
        }
コード例 #20
0
ファイル: BeamSearch.cs プロジェクト: Pandinosaurus/MyCaffe
        /// <summary>
        /// Perform the beam-search.
        /// </summary>
        /// <param name="input">Specifies the input data (e.g. the encoder input)</param>
        /// <param name="nK">Specifies the beam width for the search.</param>
        /// <param name="dfThreshold">Specifies the threshold where detected items with probabilities less than the threshold are ignored (default = 0.01).</param>
        /// <param name="nMax">Specifies the maximum length to process (default = 80); this bounds the number of expansion rounds.</param>
        /// <returns>The list of top sequences is returned.  Each tuple holds the accumulated score (the running sum of
        /// -log(probability) over the sequence, so lower is better), a flag that is set when the sequence had no further
        /// candidates and was carried over unchanged, and the ordered (string, index, probability) items of the sequence.</returns>
        /// <remarks>
        /// The beam-search algorithm is inspired by the article
        /// @see [How to Implement a Beam Search Decoder for Natural Language Processing](https://machinelearningmastery.com/beam-search-decoder-natural-language-processing/) by Jason Brownlee, "Machine Learning Mastery", 2018
        /// Items whose string is empty are treated as end-of-sequence and are never appended to a sequence.
        /// </remarks>
        public List <Tuple <double, bool, List <Tuple <string, int, double> > > > Search(PropertySet input, int nK, double dfThreshold = 0.01, int nMax = 80)
        {
            List <Tuple <double, bool, List <Tuple <string, int, double> > > > rgSequences = new List <Tuple <double, bool, List <Tuple <string, int, double> > > >();

            // Start the beam with a single empty, not-done sequence with a score of 0.
            rgSequences.Add(new Tuple <double, bool, List <Tuple <string, int, double> > >(0, false, new List <Tuple <string, int, double> >()));

            BlobCollection <T> colBottom = m_layer.PreProcessInput(input, null);
            double             dfLoss;
            string             strInput = input.GetProperty("InputData");
            bool bDone = false;

            // Initial forward pass: the top-nK items above the threshold seed the first expansion.
            BlobCollection <T> colTop = m_net.Forward(colBottom, out dfLoss);
            List <Tuple <string, int, double> > rgRes = m_layer.PostProcessOutput(colTop[0], nK);

            rgRes = rgRes.Where(p => p.Item3 >= dfThreshold).ToList();

            // rgrgRes[i] holds the pending next-step items for rgSequences[i] (the two lists are index-aligned).
            List <List <Tuple <string, int, double> > > rgrgRes = new List <List <Tuple <string, int, double> > >();

            rgrgRes.Add(rgRes);

            while (!bDone && nMax > 0)
            {
                int nProcessedCount = 0;

                List <Tuple <double, bool, List <Tuple <string, int, double> > > > rgCandidates = new List <Tuple <double, bool, List <Tuple <string, int, double> > > >();

                for (int i = 0; i < rgSequences.Count; i++)
                {
                    if (rgrgRes[i].Count > 0)
                    {
                        // Extend the sequence with every non-empty pending item; score += -log(probability).
                        for (int j = 0; j < rgrgRes[i].Count; j++)
                        {
                            if (rgrgRes[i][j].Item1.Length > 0)
                            {
                                double dfScore = rgSequences[i].Item1 - Math.Log(rgrgRes[i][j].Item3);

                                List <Tuple <string, int, double> > rgSequence1 = new List <Tuple <string, int, double> >();
                                rgSequence1.AddRange(rgSequences[i].Item3);
                                rgSequence1.Add(rgrgRes[i][j]);

                                rgCandidates.Add(new Tuple <double, bool, List <Tuple <string, int, double> > >(dfScore, false, rgSequence1));
                                nProcessedCount++;
                            }
                        }
                    }
                    else
                    {
                        // No pending items: carry the sequence over unchanged and mark it as done.
                        rgCandidates.Add(new Tuple <double, bool, List <Tuple <string, int, double> > >(rgSequences[i].Item1, true, rgSequences[i].Item3));
                    }
                }

                if (nProcessedCount > 0)
                {
                    // Keep the nK lowest-scoring (most probable) candidates as the new beam.
                    rgSequences = rgCandidates.OrderBy(p => p.Item1).Take(nK).ToList();
                    rgrgRes     = new List <List <Tuple <string, int, double> > >();

                    for (int i = 0; i < rgSequences.Count; i++)
                    {
                        if (!rgSequences[i].Item2)
                        {
                            rgRes = new List <Tuple <string, int, double> >();

                            // Reset state.
                            m_layer.PreProcessInput(strInput, 1, colBottom);
                            m_net.Forward(colBottom, out dfLoss, true);

                            // Re-run through each branch to get correct state at the leaf
                            for (int j = 0; j < rgSequences[i].Item3.Count; j++)
                            {
                                int nIdx = rgSequences[i].Item3[j].Item2;

                                m_layer.PreProcessInput(strInput, nIdx, colBottom);
                                colTop = m_net.Forward(colBottom, out dfLoss, true);

                                // Only the output after the last item of the sequence supplies its next candidates.
                                // NOTE(review): rgrgRes stays aligned with rgSequences only because every not-done
                                // sequence has at least one item here; an empty one would add no entry.
                                if (j == rgSequences[i].Item3.Count - 1)
                                {
                                    List <Tuple <string, int, double> > rgRes1 = m_layer.PostProcessOutput(colTop[0], nK);
                                    rgRes1 = rgRes1.Where(p => p.Item3 >= dfThreshold).ToList();

                                    // Empty-string items are end-of-sequence markers and are dropped.
                                    for (int k = 0; k < rgRes1.Count; k++)
                                    {
                                        if (rgRes1[k].Item1.Length > 0)
                                        {
                                            rgRes.Add(rgRes1[k]);
                                        }
                                        else
                                        {
                                            Trace.WriteLine("EOS");
                                        }
                                    }

                                    rgrgRes.Add(rgRes);
                                }
                            }
                        }
                        else
                        {
                            // Done sequences get no further candidates.
                            rgrgRes.Add(new List <Tuple <string, int, double> >());
                        }
                    }
                }
                else
                {
                    // Nothing could be extended - the search is complete.
                    bDone = true;
                }

                nMax--;
            }

            return(rgSequences);
        }
コード例 #21
0
 /// <summary>
 /// Gets the value associated with the provided key.
 /// <para>Returns <paramref name="defaultValue"/> if the key is not found.</para>
 /// </summary>
 /// <typeparam name="I">Type of the value to return.</typeparam>
 /// <param name="key">Key of the value to get.</param>
 /// <param name="defaultValue">Default value to be returned if the key is not found.</param>
 /// <returns>The value stored under <paramref name="key"/>, or <paramref name="defaultValue"/> when the key is not found.</returns>
 /// <exception cref="Contxt.IncorrectTypeException">Thrown if the property in the set is not of the same type as <typeparamref name="I"/></exception>
 /// <seealso cref="Contxt.PropertySet.GetProperty{T}(string, T)"/>
 public I Get <I>(string key, I defaultValue)
 {
     // Thin pass-through to the underlying property set.
     return(propertySet.GetProperty <I>(key, defaultValue));
 }
コード例 #22
0
        /// <summary>
        /// Select known data in fields, and link controls to properties.  Properties missing from the
        /// property set are created, bound to their control and added to the set.
        /// </summary>
        /// <remarks>
        /// Also derives <c>fileTitle</c>, <c>fileName</c> and <c>fileRev</c> from the active document title,
        /// looks up the cutlist status for the part/revision and refreshes the field values.
        /// </remarks>
        private void FillBoxes()
        {
            SwProperty partNo = this.PropertySet.GetProperty("PartNo");
            SwProperty custo  = this.PropertySet.GetProperty("CUSTOMER");
            SwProperty by     = this.PropertySet.GetProperty("DrawnBy");
            SwProperty d      = this.PropertySet.GetProperty("DATE");
            SwProperty rl     = PropertySet.GetProperty("REVISION LEVEL");

            fileTitle = (PropertySet.SwApp.ActiveDoc as ModelDoc2).GetTitle().Replace(@".SLDDRW", string.Empty);
            fileName  = fileTitle.Split(' ')[0].Trim();

            // The revision is the first word after " REV" (any case) in the title; otherwise fall back to the
            // REVISION LEVEL property, and finally to "100".
            string strTitleRev = null;
            int    nRevIdx     = fileTitle.IndexOf(@" REV", StringComparison.OrdinalIgnoreCase);
            if (nRevIdx >= 0)
            {
                string[] rgRevWords = fileTitle.Substring(nRevIdx + 4).Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
                if (rgRevWords.Length > 0)
                {
                    strTitleRev = rgRevWords[0].Trim();
                }
            }

            if (strTitleRev != null)
            {
                fileRev = strTitleRev;
            }
            else if (rl != null)
            {
                fileRev = rl.ResValue;
            }
            else
            {
                fileRev = "100";
            }

            if (partNo != null)
            {
                label4.Text = fileName;
                partNo.Ctl  = tbItemNo;
            }
            else
            {
                partNo       = new SwProperty("PartNo", swCustomInfoType_e.swCustomInfoText, "$PRP:\"SW-File Name\"", true);
                partNo.SwApp = SwApp;
                partNo.Ctl   = tbItemNo;
                this.PropertySet.Add(partNo);
            }

            if (custo != null)
            {
                custo.Ctl = cbCustomer;
            }
            else
            {
                custo       = new SwProperty("CUSTOMER", swCustomInfoType_e.swCustomInfoText, string.Empty, true);
                custo.SwApp = SwApp;
                custo.Ctl   = cbCustomer;
                this.PropertySet.Add(custo);
            }

            if (by != null)
            {
                // Default an empty author to the current user.
                if (by.Value == string.Empty)
                {
                    by.ID       = PropertySet.CutlistData.GetCurrentAuthor().ToString();
                    by.Value    = PropertySet.CutlistData.GetCurrentAuthorInitial();
                    by.ResValue = by.Value;
                }
                by.Ctl = this.cbAuthor;
            }
            else
            {
                by       = new SwProperty("DrawnBy", swCustomInfoType_e.swCustomInfoText, string.Empty, true);
                by.SwApp = SwApp;
                by.Ctl   = cbAuthor;
                this.PropertySet.Add(by);
            }

            if (d != null)
            {
                d.Ctl = this.dpDate;
            }
            else
            {
                d       = new SwProperty("DATE", swCustomInfoType_e.swCustomInfoDate, string.Empty, true);
                d.SwApp = SwApp;
                d.Ctl   = dpDate;
                this.PropertySet.Add(d);
            }

            if (rl != null)
            {
                rl.Ctl = cbRevision;
            }
            else
            {
                rl       = new SwProperty("REVISION LEVEL", swCustomInfoType_e.swCustomInfoText, "100", true);
                rl.SwApp = SwApp;
                rl.Ctl   = cbRevision;
                PropertySet.Add(rl);
            }

            // Bind the material (M1..M5) and finish (FINISH 1..FINISH 5) controls.  Each property is checked
            // on its own, so an existing M property without a FINISH property no longer dereferences null,
            // and a missing property is created only once even when several controls match its name.
            for (int i = 1; i < 6; i++)
            {
                string strMaterial  = "M" + i.ToString();
                string strFinishCtl = "FINISH" + i.ToString();    // control names carry no space...
                string strFinish    = "FINISH " + i.ToString();   // ...but the property names do.

                foreach (Control c in tableLayoutPanel3.Controls)
                {
                    // Ordinal, case-insensitive match: ToUpper() is culture-sensitive (e.g. Turkish 'i').
                    if (c.Name.IndexOf(strMaterial, StringComparison.OrdinalIgnoreCase) >= 0)
                    {
                        if (PropertySet.Contains(strMaterial))
                        {
                            PropertySet.GetProperty(strMaterial).Ctl = c;
                        }
                        else
                        {
                            SwProperty mx = new SwProperty(strMaterial, swCustomInfoType_e.swCustomInfoText, string.Empty, true);
                            mx.SwApp = SwApp;
                            mx.Ctl   = c;
                            PropertySet.Add(mx);
                        }
                    }

                    if (c.Name.IndexOf(strFinishCtl, StringComparison.OrdinalIgnoreCase) >= 0)
                    {
                        if (PropertySet.Contains(strFinish))
                        {
                            PropertySet.GetProperty(strFinish).Ctl = c;
                        }
                        else
                        {
                            SwProperty fx = new SwProperty(strFinish, swCustomInfoType_e.swCustomInfoText, string.Empty, true);
                            fx.SwApp = SwApp;
                            fx.Ctl   = c;
                            PropertySet.Add(fx);
                        }
                    }
                }
            }

            DataSet ds = PropertySet.CutlistData.GetCutlistData(fileName.Trim(),
                                                                PropertySet.GetProperty("REVISION LEVEL").Value.Trim());
            int stat = 0;

            // The status box is only enabled when the cutlist has a row with a parsable state id.
            if (ds.Tables[0].Rows.Count > 0 && int.TryParse(ds.Tables[0].Rows[0][(int)CutlistData.CutlistDataFields.STATEID].ToString(), out stat))
            {
                cbStatus.Enabled       = true;
                cbStatus.SelectedValue = stat;
            }
            else
            {
                cbStatus.Enabled = false;
            }

            PropertySet.UpdateFields();
            tbItemNoRes.Text = PropertySet.GetProperty("PartNo").ResValue;
        }
コード例 #23
0
        /// <summary>
        /// Constructs the brain, binding it to the internal net of the given phase and preparing the data, clip
        /// and (outside the RUN phase) label inputs.
        /// </summary>
        /// <param name="mycaffe">Specifies the MyCaffe instance providing the internal nets, solver and log.</param>
        /// <param name="properties">Specifies the trainer properties ('SequenceLength', 'DisableVocabulary', 'Threads' and 'Scale' are read here).</param>
        /// <param name="random">Specifies the random number generator.</param>
        /// <param name="icallback">Specifies the trainer callback.</param>
        /// <param name="phase">Specifies the phase whose internal net is used.</param>
        /// <param name="rgVocabulary">Optionally, specifies the vocabulary; when given (and the vocabulary is not disabled) its count must equal the last INNERPRODUCT num_output.</param>
        /// <param name="bUsePreloadData">Specifies whether preloaded data is used (stored for later use).</param>
        /// <param name="strRunProperties">Optionally, specifies the run properties ('Temperature', 'PhaseOnRun' and 'OutputBlob' are read here).</param>
        /// <exception cref="Exception">Thrown when a required net, layer or input/output blob cannot be found.</exception>
        public Brain(MyCaffeControl <T> mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallbackRNN icallback, Phase phase, BucketCollection rgVocabulary, bool bUsePreloadData, string strRunProperties = null)
        {
            string strOutputBlob = null;

            if (strRunProperties != null)
            {
                m_runProperties = new PropertySet(strRunProperties);
            }

            m_icallback             = icallback;
            m_mycaffe               = mycaffe;
            m_properties            = properties;
            m_random                = random;
            m_rgVocabulary          = rgVocabulary;
            m_bUsePreloadData       = bUsePreloadData;
            m_nSolverSequenceLength = m_properties.GetPropertyAsInt("SequenceLength", -1);
            m_bDisableVocabulary    = m_properties.GetPropertyAsBool("DisableVocabulary", false);
            m_nThreads              = m_properties.GetPropertyAsInt("Threads", 1);
            m_dfScale               = m_properties.GetPropertyAsDouble("Scale", 1.0);

            // The data pool is only initialized when more than one thread is requested.
            if (m_nThreads > 1)
            {
                m_dataPool.Initialize(m_nThreads, icallback);
            }

            if (m_runProperties != null)
            {
                // The temperature is forced into the range [0, 1].
                m_dfTemperature = Math.Abs(m_runProperties.GetPropertyAsDouble("Temperature", 0));
                if (m_dfTemperature > 1.0)
                {
                    m_dfTemperature = 1.0;
                }

                string strPhaseOnRun = m_runProperties.GetProperty("PhaseOnRun", false);
                switch (strPhaseOnRun)
                {
                case "RUN":
                    m_phaseOnRun = Phase.RUN;
                    break;

                case "TEST":
                    m_phaseOnRun = Phase.TEST;
                    break;

                case "TRAIN":
                    m_phaseOnRun = Phase.TRAIN;
                    break;
                }

                // When running, optionally use the net of another phase; an output blob name is then required.
                if (phase == Phase.RUN && m_phaseOnRun != Phase.NONE)
                {
                    if (m_phaseOnRun != Phase.RUN)
                    {
                        m_mycaffe.Log.WriteLine("Warning: Running on the '" + m_phaseOnRun.ToString() + "' network.");
                    }

                    strOutputBlob = m_runProperties.GetProperty("OutputBlob", false);
                    if (strOutputBlob == null)
                    {
                        throw new Exception("You must specify the 'OutputBlob' when Running with a phase other than RUN.");
                    }

                    // '~' in the property value stands for ';', which cannot appear inside a property set value.
                    strOutputBlob = Utility.Replace(strOutputBlob, '~', ';');

                    phase = m_phaseOnRun;
                }
            }

            m_net = mycaffe.GetInternalNet(phase);
            if (m_net == null)
            {
                mycaffe.Log.WriteLine("WARNING: Test net does not exist, set test_iteration > 0.  Using TRAIN phase instead.");
                m_net = mycaffe.GetInternalNet(Phase.TRAIN);

                // Fail with a clear message instead of a NullReferenceException on m_net.layers below.
                if (m_net == null)
                {
                    throw new Exception("Could not find the '" + phase.ToString() + "' or the 'TRAIN' network!");
                }
            }

            // Find the first LSTM layer to determine how to load the data.
            // NOTE: Only LSTM has a special loading order, other layers use the standard N, C, H, W ordering.
            LSTMLayer <T>       lstmLayer       = null;
            LSTMSimpleLayer <T> lstmSimpleLayer = null;

            foreach (Layer <T> layer1 in m_net.layers)
            {
                if (layer1.layer_param.type == LayerParameter.LayerType.LSTM)
                {
                    lstmLayer  = layer1 as LSTMLayer <T>;
                    m_lstmType = LayerParameter.LayerType.LSTM;
                    break;
                }
                else if (layer1.layer_param.type == LayerParameter.LayerType.LSTM_SIMPLE)
                {
                    lstmSimpleLayer = layer1 as LSTMSimpleLayer <T>;
                    m_lstmType      = LayerParameter.LayerType.LSTM_SIMPLE;
                    break;
                }
            }

            if (lstmLayer == null && lstmSimpleLayer == null)
            {
                throw new Exception("Could not find the required LSTM or LSTM_SIMPLE layer!");
            }

            if (m_phaseOnRun != Phase.NONE && m_phaseOnRun != Phase.RUN && strOutputBlob != null)
            {
                if ((m_blobOutput = m_net.FindBlob(strOutputBlob)) == null)
                {
                    throw new Exception("Could not find the 'Output' layer top named '" + strOutputBlob + "'!");
                }
            }

            if ((m_blobData = m_net.FindBlob("data")) == null)
            {
                throw new Exception("Could not find the 'Input' layer top named 'data'!");
            }

            if ((m_blobClip = m_net.FindBlob("clip")) == null)
            {
                throw new Exception("Could not find the 'Input' layer top named 'clip'!");
            }

            Layer <T> layer = m_net.FindLastLayer(LayerParameter.LayerType.INNERPRODUCT);

            m_mycaffe.Log.CHECK(layer != null, "Could not find an ending INNERPRODUCT layer!");

            if (!m_bDisableVocabulary)
            {
                m_nVocabSize = (int)layer.layer_param.inner_product_param.num_output;
                if (rgVocabulary != null)
                {
                    m_mycaffe.Log.CHECK_EQ(m_nVocabSize, rgVocabulary.Count, "The vocabulary count = '" + rgVocabulary.Count.ToString() + "' and last inner product output count = '" + m_nVocabSize.ToString() + "' - these do not match but they should!");
                }
            }

            if (m_lstmType == LayerParameter.LayerType.LSTM)
            {
                // LSTM data is ordered (sequence, batch, ...).
                m_nSequenceLength = m_blobData.shape(0);
                m_nBatchSize      = m_blobData.shape(1);
            }
            else
            {
                // LSTM_SIMPLE takes its batch size from the layer parameter.
                m_nBatchSize      = (int)lstmSimpleLayer.layer_param.lstm_simple_param.batch_size;
                m_nSequenceLength = m_blobData.shape(0) / m_nBatchSize;

                // When running, process a single sequence at a time.
                if (phase == Phase.RUN)
                {
                    m_nBatchSize = 1;

                    List <int> rgNewShape = new List <int>()
                    {
                        m_nSequenceLength, 1
                    };
                    m_blobData.Reshape(rgNewShape);
                    m_blobClip.Reshape(rgNewShape);
                    m_net.Reshape();
                }
            }

            m_mycaffe.Log.CHECK_EQ(m_nSequenceLength, m_blobData.num, "The data num must equal the sequence lengh of " + m_nSequenceLength.ToString());

            m_rgDataInput = new T[m_nSequenceLength * m_nBatchSize];

            T[] rgClipInput = new T[m_nSequenceLength * m_nBatchSize];
            m_mycaffe.Log.CHECK_EQ(rgClipInput.Length, m_blobClip.count(), "The clip count must equal the sequence length * batch size: " + rgClipInput.Length.ToString());
            m_tZero = (T)Convert.ChangeType(0, typeof(T));
            m_tOne  = (T)Convert.ChangeType(1, typeof(T));

            // Clip is 0 on the first step of each sequence and 1 elsewhere: for LSTM the first m_nBatchSize
            // entries, for LSTM_SIMPLE every entry at a multiple of the sequence length.
            for (int i = 0; i < rgClipInput.Length; i++)
            {
                if (m_lstmType == LayerParameter.LayerType.LSTM)
                {
                    rgClipInput[i] = (i < m_nBatchSize) ? m_tZero : m_tOne;
                }
                else
                {
                    rgClipInput[i] = (i % m_nSequenceLength == 0) ? m_tZero : m_tOne;
                }
            }

            m_blobClip.mutable_cpu_data = rgClipInput;

            // Training/testing additionally needs the solver events and the label input.
            if (phase != Phase.RUN)
            {
                m_solver                      = mycaffe.GetInternalSolver();
                m_solver.OnStart             += m_solver_OnStart;
                m_solver.OnTestStart         += m_solver_OnTestStart;
                m_solver.OnTestingIteration  += m_solver_OnTestingIteration;
                m_solver.OnTrainingIteration += m_solver_OnTrainingIteration;

                if ((m_blobLabel = m_net.FindBlob("label")) == null)
                {
                    throw new Exception("Could not find the 'Input' layer top named 'label'!");
                }

                m_nSequenceLengthLabel = m_blobLabel.count(0, 2);
                m_rgLabelInput         = new T[m_nSequenceLengthLabel];
                m_mycaffe.Log.CHECK_EQ(m_rgLabelInput.Length, m_blobLabel.count(), "The label count must equal the label sequence length * batch size: " + m_rgLabelInput.Length.ToString());
                m_mycaffe.Log.CHECK(m_nSequenceLengthLabel == m_nSequenceLength * m_nBatchSize || m_nSequenceLengthLabel == 1, "The label sqeuence length must be 1 or equal the length of the sequence: " + m_nSequenceLength.ToString());
            }
        }
コード例 #24
0
        /// <summary>
        /// Constructs the brain, binding it to the internal net of the given phase and preparing the data, clip
        /// and (outside the RUN phase) label inputs.
        /// </summary>
        /// <param name="mycaffe">Specifies the MyCaffe instance providing the internal nets, solver and log.</param>
        /// <param name="properties">Specifies the trainer properties (stored for later use).</param>
        /// <param name="random">Specifies the random number generator.</param>
        /// <param name="icallback">Specifies the trainer callback.</param>
        /// <param name="phase">Specifies the phase whose internal net is used.</param>
        /// <param name="rgVocabulary">Optionally, specifies the vocabulary; when given its count must equal the last INNERPRODUCT num_output.</param>
        /// <param name="strRunProperties">Optionally, specifies the run properties ('Temperature', 'PhaseOnRun' and 'OutputBlob' are read here).</param>
        /// <exception cref="Exception">Thrown when a required net, layer or input/output blob cannot be found.</exception>
        public Brain(MyCaffeControl <T> mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallbackRNN icallback, Phase phase, BucketCollection rgVocabulary, string strRunProperties = null)
        {
            string strOutputBlob = null;

            if (strRunProperties != null)
            {
                m_runProperties = new PropertySet(strRunProperties);
            }

            m_icallback    = icallback;
            m_mycaffe      = mycaffe;
            m_properties   = properties;
            m_random       = random;
            m_rgVocabulary = rgVocabulary;

            if (m_runProperties != null)
            {
                m_dfTemperature = m_runProperties.GetPropertyAsDouble("Temperature", 0);
                string strPhaseOnRun = m_runProperties.GetProperty("PhaseOnRun", false);
                switch (strPhaseOnRun)
                {
                case "RUN":
                    m_phaseOnRun = Phase.RUN;
                    break;

                case "TEST":
                    m_phaseOnRun = Phase.TEST;
                    break;

                case "TRAIN":
                    m_phaseOnRun = Phase.TRAIN;
                    break;
                }

                // When running, optionally use the net of another phase; an output blob name is then required.
                if (phase == Phase.RUN && m_phaseOnRun != Phase.NONE)
                {
                    if (m_phaseOnRun != Phase.RUN)
                    {
                        m_mycaffe.Log.WriteLine("Warning: Running on the '" + m_phaseOnRun.ToString() + "' network.");
                    }

                    strOutputBlob = m_runProperties.GetProperty("OutputBlob", false);
                    if (strOutputBlob == null)
                    {
                        throw new Exception("You must specify the 'OutputBlob' when Running with a phase other than RUN.");
                    }

                    // '~' in the property value stands for ';', which cannot appear inside a property set value.
                    strOutputBlob = Utility.Replace(strOutputBlob, '~', ';');

                    phase = m_phaseOnRun;
                }
            }

            // GetInternalNet can return null (e.g. no TEST net when test_iteration is 0); fall back to the TRAIN
            // net and fail with a clear message if that is missing too, instead of a NullReferenceException below.
            m_net = mycaffe.GetInternalNet(phase);
            if (m_net == null)
            {
                mycaffe.Log.WriteLine("WARNING: Test net does not exist, set test_iteration > 0.  Using TRAIN phase instead.");
                m_net = mycaffe.GetInternalNet(Phase.TRAIN);

                if (m_net == null)
                {
                    throw new Exception("Could not find the '" + phase.ToString() + "' or the 'TRAIN' network!");
                }
            }

            // Find the first LSTM layer to determine how to load the data.
            // NOTE: Only LSTM has a special loading order, other layers use the standard N, C, H, W ordering.
            LSTMLayer <T>       lstmLayer       = null;
            LSTMSimpleLayer <T> lstmSimpleLayer = null;

            foreach (Layer <T> layer1 in m_net.layers)
            {
                if (layer1.layer_param.type == LayerParameter.LayerType.LSTM)
                {
                    lstmLayer  = layer1 as LSTMLayer <T>;
                    m_lstmType = LayerParameter.LayerType.LSTM;
                    break;
                }
                else if (layer1.layer_param.type == LayerParameter.LayerType.LSTM_SIMPLE)
                {
                    lstmSimpleLayer = layer1 as LSTMSimpleLayer <T>;
                    m_lstmType      = LayerParameter.LayerType.LSTM_SIMPLE;
                    break;
                }
            }

            if (lstmLayer == null && lstmSimpleLayer == null)
            {
                throw new Exception("Could not find the required LSTM or LSTM_SIMPLE layer!");
            }

            if (m_phaseOnRun != Phase.NONE && m_phaseOnRun != Phase.RUN && strOutputBlob != null)
            {
                if ((m_blobOutput = m_net.FindBlob(strOutputBlob)) == null)
                {
                    throw new Exception("Could not find the 'Output' layer top named '" + strOutputBlob + "'!");
                }
            }

            if ((m_blobData = m_net.FindBlob("data")) == null)
            {
                throw new Exception("Could not find the 'Input' layer top named 'data'!");
            }

            if ((m_blobClip = m_net.FindBlob("clip")) == null)
            {
                throw new Exception("Could not find the 'Input' layer top named 'clip'!");
            }

            Layer <T> layer = m_net.FindLastLayer(LayerParameter.LayerType.INNERPRODUCT);

            m_mycaffe.Log.CHECK(layer != null, "Could not find an ending INNERPRODUCT layer!");

            m_nVocabSize = (int)layer.layer_param.inner_product_param.num_output;
            if (rgVocabulary != null)
            {
                m_mycaffe.Log.CHECK_EQ(m_nVocabSize, rgVocabulary.Count, "The vocabulary count and last inner product output count should match!");
            }

            if (m_lstmType == LayerParameter.LayerType.LSTM)
            {
                // LSTM data is ordered (sequence, batch, ...).
                m_nSequenceLength = m_blobData.shape(0);
                m_nBatchSize      = m_blobData.shape(1);
            }
            else
            {
                // LSTM_SIMPLE takes its batch size from the layer parameter.
                m_nBatchSize      = (int)lstmSimpleLayer.layer_param.lstm_simple_param.batch_size;
                m_nSequenceLength = m_blobData.shape(0) / m_nBatchSize;

                // When running, process a single sequence at a time.
                if (phase == Phase.RUN)
                {
                    m_nBatchSize = 1;

                    List <int> rgNewShape = new List <int>()
                    {
                        m_nSequenceLength, 1
                    };
                    m_blobData.Reshape(rgNewShape);
                    m_blobClip.Reshape(rgNewShape);
                    m_net.Reshape();
                }
            }

            m_mycaffe.Log.CHECK_EQ(m_blobData.count(), m_blobClip.count(), "The data and clip blobs must have the same count!");

            m_rgDataInput = new T[m_nSequenceLength * m_nBatchSize];

            T[] rgClipInput = new T[m_nSequenceLength * m_nBatchSize];
            m_tZero = (T)Convert.ChangeType(0, typeof(T));
            m_tOne  = (T)Convert.ChangeType(1, typeof(T));

            // Clip is 0 on the first step of each sequence and 1 elsewhere: for LSTM the first m_nBatchSize
            // entries, for LSTM_SIMPLE every entry at a multiple of the sequence length.
            for (int i = 0; i < rgClipInput.Length; i++)
            {
                if (m_lstmType == LayerParameter.LayerType.LSTM)
                {
                    rgClipInput[i] = (i < m_nBatchSize) ? m_tZero : m_tOne;
                }
                else
                {
                    rgClipInput[i] = (i % m_nSequenceLength == 0) ? m_tZero : m_tOne;
                }
            }

            m_blobClip.mutable_cpu_data = rgClipInput;

            // Training/testing additionally needs the solver events and the label input.
            if (phase != Phase.RUN)
            {
                m_solver                      = mycaffe.GetInternalSolver();
                m_solver.OnStart             += m_solver_OnStart;
                m_solver.OnTestStart         += m_solver_OnTestStart;
                m_solver.OnTestingIteration  += m_solver_OnTestingIteration;
                m_solver.OnTrainingIteration += m_solver_OnTrainingIteration;

                if ((m_blobLabel = m_net.FindBlob("label")) == null)
                {
                    throw new Exception("Could not find the 'Input' layer top named 'label'!");
                }

                m_rgLabelInput = new T[m_nSequenceLength * m_nBatchSize];
                m_mycaffe.Log.CHECK_EQ(m_blobData.count(), m_blobLabel.count(), "The data and label blobs must have the same count!");
            }
        }