예제 #1
0
        /// <summary>
        /// Produces a copy of this network parameter.
        /// </summary>
        /// <remarks>
        /// The input names, input shapes and input dimensions are copied via Utility.Clone, and the
        /// net state is cloned when one is present. The layer list is only copied when
        /// <paramref name="bCloneLayers"/> is <i>true</i>; otherwise the copy keeps whatever layer list
        /// the default NetParameter constructor provides.
        /// </remarks>
        /// <param name="bCloneLayers">When <i>true</i>, each layer is cloned as well.</param>
        /// <param name="nSolverCount">Optionally, specifies a solver count for the clone; when <i>null</i> this instance's solver count is used.</param>
        /// <param name="nSolverRank">Optionally, specifies a solver rank for the clone; when <i>null</i> this instance's solver rank is used.</param>
        /// <returns>A new instance of this parameter is returned.</returns>
        public NetParameter Clone(bool bCloneLayers = true, int?nSolverCount = null, int?nSolverRank = null)
        {
            NetParameter copy = new NetParameter();

            // Scalar settings are copied by value.
            copy.m_nProjectID = m_nProjectID;
            copy.m_strName = m_strName;

            // List-valued settings get their own copies.
            copy.m_rgstrInput = Utility.Clone<string>(m_rgstrInput);
            copy.m_rgInputShape = Utility.Clone<BlobShape>(m_rgInputShape);
            copy.m_rgInputDim = Utility.Clone<int>(m_rgInputDim);

            copy.m_bForceBackward = m_bForceBackward;
            copy.m_state = (m_state == null) ? null : m_state.Clone();
            copy.m_bDebugInfo = m_bDebugInfo;

            if (bCloneLayers)
            {
                copy.m_rgLayers = Utility.Clone<LayerParameter>(m_rgLayers);
            }

            // Caller-supplied solver settings win; otherwise inherit this instance's values.
            copy.m_nSolverCount = nSolverCount ?? m_nSolverCount;
            copy.m_nSolverRank = nSolverRank ?? m_nSolverRank;

            return copy;
        }
예제 #2
0
        /// <summary>
        /// Deserializes a new network parameter from a binary reader.
        /// </summary>
        /// <remarks>
        /// Values are consumed from the stream in this fixed order: name, input names, input shapes,
        /// input dimensions, force-backward flag, net state, debug-info flag and finally the layers.
        /// The order presumably mirrors the matching save routine; confirm before changing either side.
        /// The project ID and solver count/rank are not read here.
        /// </remarks>
        /// <param name="br">Specifies the binary reader to read from.</param>
        /// <returns>The newly loaded instance is returned.</returns>
        public static NetParameter Load(BinaryReader br)
        {
            // Object-initializer member assignments execute in textual order, so the
            // stream is read in exactly the sequence listed below.
            return new NetParameter
            {
                m_strName = br.ReadString(),
                m_rgstrInput = Utility.Load<string>(br),
                m_rgInputShape = Utility.Load<BlobShape>(br),
                m_rgInputDim = Utility.Load<int>(br),
                m_bForceBackward = br.ReadBoolean(),
                m_state = NetState.Load(br),
                m_bDebugInfo = br.ReadBoolean(),
                m_rgLayers = Utility.Load<LayerParameter>(br)
            };
        }
예제 #3
0
        /// <summary>
        /// Builds a new network parameter from a parsed RawProto tree.
        /// </summary>
        /// <remarks>
        /// The 'name', 'force_backward', 'state' and 'debug_info' settings are only assigned when present.
        /// 'input' and 'input_dim' are always assigned from FindArray. Every 'input_shape' child is appended
        /// to input_shape, and every child named either 'layer' or 'layers' is appended to the layer list.
        /// </remarks>
        /// <param name="rp">Specifies the RawProto to parse.</param>
        /// <returns>A new instance of the parameter is returned.</returns>
        public static NetParameter FromProto(RawProto rp)
        {
            NetParameter net = new NetParameter();

            string strName = rp.FindValue("name");
            if (strName != null)
            {
                net.name = strName;
            }

            net.input = rp.FindArray<string>("input");

            foreach (RawProto rpShape in rp.FindChildren("input_shape"))
            {
                net.input_shape.Add(BlobShape.FromProto(rpShape));
            }

            net.input_dim = rp.FindArray<int>("input_dim");

            string strForceBackward = rp.FindValue("force_backward");
            if (strForceBackward != null)
            {
                net.force_backward = bool.Parse(strForceBackward);
            }

            RawProto rpNetState = rp.FindChild("state");
            if (rpNetState != null)
            {
                net.state = NetState.FromProto(rpNetState);
            }

            string strDebugInfo = rp.FindValue("debug_info");
            if (strDebugInfo != null)
            {
                net.debug_info = bool.Parse(strDebugInfo);
            }

            // Both "layer" and "layers" children are accepted (presumably so older prototxt still loads).
            foreach (RawProto rpLayer in rp.FindChildren("layer", "layers"))
            {
                net.layer.Add(LayerParameter.FromProto(rpLayer));
            }

            return net;
        }
예제 #4
0
        /// <summary>
        /// Parses a new SolverParameter from a RawProto.
        /// </summary>
        /// <remarks>
        /// Scalar settings are only assigned when the corresponding field is present in <paramref name="rp"/>;
        /// absent fields keep whatever defaults <c>new SolverParameter()</c> provides. The array settings
        /// 'test_iter' and 'stepvalue' are always assigned from FindArray. Numeric values are parsed with
        /// int/long/double.Parse and no explicit format provider. The solver 'type' is matched
        /// case-insensitively, whereas 'snapshot_format' is matched exactly.
        /// </remarks>
        /// <param name="rp">Specifies the RawProto representing the SolverParameter.</param>
        /// <returns>The new SolverParameter instance is returned.</returns>
        /// <exception cref="Exception">Thrown when 'snapshot_format' or 'type' holds an unrecognized value.</exception>
        /// <exception cref="FormatException">Thrown when a numeric or boolean field cannot be parsed.</exception>
        public static SolverParameter FromProto(RawProto rp)
        {
            string          strVal;
            SolverParameter p = new SolverParameter();

            // Network definitions.
            RawProto rpNetParam = rp.FindChild("net_param");

            if (rpNetParam != null)
            {
                p.net_param = NetParameter.FromProto(rpNetParam);
            }

            RawProto rpTrainNetParam = rp.FindChild("train_net_param");

            if (rpTrainNetParam != null)
            {
                p.train_net_param = NetParameter.FromProto(rpTrainNetParam);
            }

            RawProtoCollection rgpTn = rp.FindChildren("test_net_param");

            foreach (RawProto rpTest in rgpTn)
            {
                p.test_net_param.Add(NetParameter.FromProto(rpTest));
            }

            // Net states.
            RawProto rpTrainState = rp.FindChild("train_state");

            if (rpTrainState != null)
            {
                p.train_state = NetState.FromProto(rpTrainState);
            }

            RawProtoCollection rgpNs = rp.FindChildren("test_state");

            foreach (RawProto rpNs in rgpNs)
            {
                p.test_state.Add(NetState.FromProto(rpNs));
            }

            // Testing settings.
            p.test_iter = rp.FindArray<int>("test_iter");

            if ((strVal = rp.FindValue("test_interval")) != null)
            {
                p.test_interval = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("test_compute_loss")) != null)
            {
                p.test_compute_loss = bool.Parse(strVal);
            }

            if ((strVal = rp.FindValue("test_initialization")) != null)
            {
                p.test_initialization = bool.Parse(strVal);
            }

            // Learning-rate and optimization settings.
            if ((strVal = rp.FindValue("base_lr")) != null)
            {
                p.base_lr = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("display")) != null)
            {
                p.display = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("average_loss")) != null)
            {
                p.average_loss = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("max_iter")) != null)
            {
                p.max_iter = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("iter_size")) != null)
            {
                p.iter_size = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("lr_policy")) != null)
            {
                p.lr_policy = strVal;
            }

            if ((strVal = rp.FindValue("gamma")) != null)
            {
                p.gamma = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("power")) != null)
            {
                p.power = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("momentum")) != null)
            {
                p.momentum = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("weight_decay")) != null)
            {
                p.weight_decay = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("regularization_type")) != null)
            {
                p.regularization_type = strVal;
            }

            if ((strVal = rp.FindValue("stepsize")) != null)
            {
                p.stepsize = int.Parse(strVal);
            }

            p.stepvalue = rp.FindArray<int>("stepvalue");

            if ((strVal = rp.FindValue("clip_gradients")) != null)
            {
                p.clip_gradients = double.Parse(strVal);
            }

            // Snapshot settings.
            if ((strVal = rp.FindValue("snapshot")) != null)
            {
                p.snapshot = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("snapshot_prefix")) != null)
            {
                p.snapshot_prefix = strVal;
            }

            if ((strVal = rp.FindValue("snapshot_diff")) != null)
            {
                p.snapshot_diff = bool.Parse(strVal);
            }

            if ((strVal = rp.FindValue("snapshot_format")) != null)
            {
                switch (strVal)
                {
                case "BINARYPROTO":
                    p.snapshot_format = SnapshotFormat.BINARYPROTO;
                    break;

                // NOTE(review): "HDF5" is deliberately mapped to BINARYPROTO here, presumably because
                // HDF5 snapshots are not supported by this implementation - confirm.
                case "HDF5":
                    p.snapshot_format = SnapshotFormat.BINARYPROTO;
                    break;

                default:
                    throw new Exception("Unknown 'snapshot_format' value: " + strVal);
                }
            }

            if ((strVal = rp.FindValue("device_id")) != null)
            {
                p.device_id = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("random_seed")) != null)
            {
                p.random_seed = long.Parse(strVal);
            }

            // Solver type (case-insensitive).
            if ((strVal = rp.FindValue("type")) != null)
            {
                // Invariant lower-casing keeps the match independent of the current culture.
                string strVal1 = strVal.ToLowerInvariant();

                switch (strVal1)
                {
                case "sgd":
                    p.type = SolverType.SGD;
                    break;

                case "nesterov":
                    p.type = SolverType.NESTEROV;
                    break;

                case "adagrad":
                    p.type = SolverType.ADAGRAD;
                    break;

                case "adadelta":
                    p.type = SolverType.ADADELTA;
                    break;

                case "adam":
                    p.type = SolverType.ADAM;
                    break;

                case "rmsprop":
                    p.type = SolverType.RMSPROP;
                    break;

                // "lbfgs" is the correct spelling (matches SolverType.LBFGS); the misspelled
                // "lbgfs" previously required here is still accepted for backward compatibility.
                case "lbfgs":
                case "lbgfs":
                    p.type = SolverType.LBFGS;
                    break;

                default:
                    throw new Exception("Unknown solver 'type' value: " + strVal);
                }
            }

            // Solver-specific settings.
            if ((strVal = rp.FindValue("delta")) != null)
            {
                p.delta = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("momentum2")) != null)
            {
                p.momentum2 = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("rms_decay")) != null)
            {
                p.rms_decay = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("debug_info")) != null)
            {
                p.debug_info = bool.Parse(strVal);
            }

            // The field name itself is spelled "lbgfs_corrections"; kept as-is to match the property.
            if ((strVal = rp.FindValue("lbgfs_corrections")) != null)
            {
                p.lbgfs_corrections = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("snapshot_after_train")) != null)
            {
                p.snapshot_after_train = bool.Parse(strVal);
            }

            // Custom trainer and output settings.
            if ((strVal = rp.FindValue("custom_trainer")) != null)
            {
                p.custom_trainer = strVal;
            }

            if ((strVal = rp.FindValue("custom_trainer_properties")) != null)
            {
                p.custom_trainer_properties = strVal;
            }

            if ((strVal = rp.FindValue("output_average_results")) != null)
            {
                p.output_average_results = bool.Parse(strVal);
            }

            if ((strVal = rp.FindValue("snapshot_include_weights")) != null)
            {
                p.snapshot_include_weights = bool.Parse(strVal);
            }

            if ((strVal = rp.FindValue("snapshot_include_state")) != null)
            {
                p.snapshot_include_state = bool.Parse(strVal);
            }

            return(p);
        }