Exemplo n.º 1
0
        /// <summary>
        /// Produces a copy of this SolverParameter by serializing it with ToProto
        /// and parsing the resulting RawProto back with FromProto.
        /// </summary>
        /// <remarks>
        /// Only the state captured by ToProto carries over to the copy.
        /// </remarks>
        /// <returns>A newly constructed SolverParameter holding the round-tripped settings.</returns>
        public SolverParameter Clone()
        {
            var proto = ToProto("clone");
            return SolverParameter.FromProto(proto);
        }
Exemplo n.º 2
0
        /// <summary>
        /// Parses a new SolverParameter from a RawProto.
        /// </summary>
        /// <remarks>
        /// A new SolverParameter is constructed first. Each scalar setting is then assigned only
        /// when the matching child value is present in <paramref name="rp"/>; otherwise the value
        /// set by the SolverParameter constructor is kept. 'test_iter' and 'stepvalue' are the
        /// exception: they are always overwritten with the result of RawProto.FindArray.
        /// 'test_net_param' and 'test_state' entries are appended, one per matching child.
        /// </remarks>
        /// <param name="rp">Specifies the RawProto representing the SolverParameter.</param>
        /// <returns>The new SolverParameter instance is returned.</returns>
        /// <exception cref="Exception">
        /// Thrown when 'snapshot_format' or 'type' holds an unrecognized value.
        /// </exception>
        /// <exception cref="FormatException">
        /// Thrown by the int/long/double/bool Parse calls when a value is malformed.
        /// </exception>
        public static SolverParameter FromProto(RawProto rp)
        {
            string          strVal;
            SolverParameter p = new SolverParameter();

            RawProto rpNetParam = rp.FindChild("net_param");

            if (rpNetParam != null)
            {
                p.net_param = NetParameter.FromProto(rpNetParam);
            }

            RawProto rpTrainNetParam = rp.FindChild("train_net_param");

            if (rpTrainNetParam != null)
            {
                p.train_net_param = NetParameter.FromProto(rpTrainNetParam);
            }

            RawProtoCollection rgpTn = rp.FindChildren("test_net_param");

            foreach (RawProto rpTest in rgpTn)
            {
                p.test_net_param.Add(NetParameter.FromProto(rpTest));
            }

            RawProto rpTrainState = rp.FindChild("train_state");

            if (rpTrainState != null)
            {
                p.train_state = NetState.FromProto(rpTrainState);
            }

            RawProtoCollection rgpNs = rp.FindChildren("test_state");

            foreach (RawProto rpNs in rgpNs)
            {
                p.test_state.Add(NetState.FromProto(rpNs));
            }

            p.test_iter = rp.FindArray <int>("test_iter");

            if ((strVal = rp.FindValue("test_interval")) != null)
            {
                p.test_interval = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("test_compute_loss")) != null)
            {
                p.test_compute_loss = bool.Parse(strVal);
            }

            if ((strVal = rp.FindValue("test_initialization")) != null)
            {
                p.test_initialization = bool.Parse(strVal);
            }

            // NOTE(review): the numeric Parse calls in this method use the current thread culture.
            // Prototxt written with '.' decimals may mis-parse under cultures that use ','.
            // Confirm how ToProto formats numbers before switching both sides to
            // CultureInfo.InvariantCulture; changing only one side would break Clone round-trips.
            if ((strVal = rp.FindValue("base_lr")) != null)
            {
                p.base_lr = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("display")) != null)
            {
                p.display = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("average_loss")) != null)
            {
                p.average_loss = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("max_iter")) != null)
            {
                p.max_iter = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("iter_size")) != null)
            {
                p.iter_size = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("lr_policy")) != null)
            {
                p.lr_policy = strVal;
            }

            if ((strVal = rp.FindValue("gamma")) != null)
            {
                p.gamma = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("power")) != null)
            {
                p.power = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("momentum")) != null)
            {
                p.momentum = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("weight_decay")) != null)
            {
                p.weight_decay = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("regularization_type")) != null)
            {
                p.regularization_type = strVal;
            }

            if ((strVal = rp.FindValue("stepsize")) != null)
            {
                p.stepsize = int.Parse(strVal);
            }

            p.stepvalue = rp.FindArray <int>("stepvalue");

            if ((strVal = rp.FindValue("clip_gradients")) != null)
            {
                p.clip_gradients = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("snapshot")) != null)
            {
                p.snapshot = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("snapshot_prefix")) != null)
            {
                p.snapshot_prefix = strVal;
            }

            if ((strVal = rp.FindValue("snapshot_diff")) != null)
            {
                p.snapshot_diff = bool.Parse(strVal);
            }

            if ((strVal = rp.FindValue("snapshot_format")) != null)
            {
                switch (strVal)
                {
                case "BINARYPROTO":
                    p.snapshot_format = SnapshotFormat.BINARYPROTO;
                    break;

                // NOTE(review): "HDF5" is mapped to BINARYPROTO, not to an HDF5 member.
                // This may be a deliberate fallback (e.g. if SnapshotFormat has no usable HDF5
                // member) or a copy/paste slip. Confirm against the SnapshotFormat enum.
                case "HDF5":
                    p.snapshot_format = SnapshotFormat.BINARYPROTO;
                    break;

                default:
                    throw new Exception("Unknown 'snapshot_format' value: " + strVal);
                }
            }

            if ((strVal = rp.FindValue("device_id")) != null)
            {
                p.device_id = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("random_seed")) != null)
            {
                p.random_seed = long.Parse(strVal);
            }

            if ((strVal = rp.FindValue("type")) != null)
            {
                // Solver names are matched case-insensitively. ToLowerInvariant keeps the
                // mapping independent of the current culture.
                string strVal1 = strVal.ToLowerInvariant();

                switch (strVal1)
                {
                case "sgd":
                    p.type = SolverType.SGD;
                    break;

                case "nesterov":
                    p.type = SolverType.NESTEROV;
                    break;

                case "adagrad":
                    p.type = SolverType.ADAGRAD;
                    break;

                case "adadelta":
                    p.type = SolverType.ADADELTA;
                    break;

                case "adam":
                    p.type = SolverType.ADAM;
                    break;

                case "rmsprop":
                    p.type = SolverType.RMSPROP;
                    break;

                // "lbfgs" is the correct spelling (matching SolverType.LBFGS). The misspelled
                // "lbgfs" that was previously the only accepted form is still recognized, so
                // existing files keep loading.
                case "lbfgs":
                case "lbgfs":
                    p.type = SolverType.LBFGS;
                    break;

                default:
                    throw new Exception("Unknown solver 'type' value: " + strVal);
                }
            }

            if ((strVal = rp.FindValue("delta")) != null)
            {
                p.delta = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("momentum2")) != null)
            {
                p.momentum2 = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("rms_decay")) != null)
            {
                p.rms_decay = double.Parse(strVal);
            }

            if ((strVal = rp.FindValue("debug_info")) != null)
            {
                p.debug_info = bool.Parse(strVal);
            }

            if ((strVal = rp.FindValue("lbgfs_corrections")) != null)
            {
                p.lbgfs_corrections = int.Parse(strVal);
            }

            if ((strVal = rp.FindValue("snapshot_after_train")) != null)
            {
                p.snapshot_after_train = bool.Parse(strVal);
            }

            if ((strVal = rp.FindValue("custom_trainer")) != null)
            {
                p.custom_trainer = strVal;
            }

            if ((strVal = rp.FindValue("custom_trainer_properties")) != null)
            {
                p.custom_trainer_properties = strVal;
            }

            if ((strVal = rp.FindValue("output_average_results")) != null)
            {
                p.output_average_results = bool.Parse(strVal);
            }

            if ((strVal = rp.FindValue("snapshot_include_weights")) != null)
            {
                p.snapshot_include_weights = bool.Parse(strVal);
            }

            if ((strVal = rp.FindValue("snapshot_include_state")) != null)
            {
                p.snapshot_include_state = bool.Parse(strVal);
            }

            return(p);
        }