/// <summary>
/// Produces an independent SolverParameter by serializing this instance with
/// ToProto (using the name "clone") and parsing the result back with FromProto.
/// </summary>
/// <returns>The SolverParameter parsed from the serialized form of this instance.</returns>
public SolverParameter Clone()
{
    RawProto proto = ToProto("clone");

    return SolverParameter.FromProto(proto);
}
/// <summary>
/// Parses a new SolverParameter from a RawProto.
/// </summary>
/// <remarks>
/// Each setting is looked up by name in <paramref name="rp"/>; a setting that is not found
/// leaves the corresponding property of the newly constructed SolverParameter untouched.
/// The 'test_iter' and 'stepvalue' arrays are always assigned from <c>FindArray</c>.
/// Scalar values are converted with the <c>Parse</c> method of their type, so malformed
/// text surfaces as the exception thrown by that <c>Parse</c> call.
/// The solver 'type' is matched case-insensitively; both the correct spelling "lbfgs" and
/// the legacy misspelling "lbgfs" select <c>SolverType.LBFGS</c>.
/// </remarks>
/// <param name="rp">Specifies the RawProto representing the SolverParameter.</param>
/// <returns>The new SolverParameter instance is returned.</returns>
/// <exception cref="Exception">Thrown when 'snapshot_format' or 'type' holds an unrecognized value.</exception>
public static SolverParameter FromProto(RawProto rp)
{
    string strVal;
    SolverParameter p = new SolverParameter();

    // Nested net definitions and net states.
    RawProto rpNetParam = rp.FindChild("net_param");
    if (rpNetParam != null)
    {
        p.net_param = NetParameter.FromProto(rpNetParam);
    }

    RawProto rpTrainNetParam = rp.FindChild("train_net_param");
    if (rpTrainNetParam != null)
    {
        p.train_net_param = NetParameter.FromProto(rpTrainNetParam);
    }

    RawProtoCollection rgpTn = rp.FindChildren("test_net_param");
    foreach (RawProto rpTest in rgpTn)
    {
        p.test_net_param.Add(NetParameter.FromProto(rpTest));
    }

    RawProto rpTrainState = rp.FindChild("train_state");
    if (rpTrainState != null)
    {
        p.train_state = NetState.FromProto(rpTrainState);
    }

    RawProtoCollection rgpNs = rp.FindChildren("test_state");
    foreach (RawProto rpNs in rgpNs)
    {
        p.test_state.Add(NetState.FromProto(rpNs));
    }

    // Testing settings.
    p.test_iter = rp.FindArray<int>("test_iter");

    if ((strVal = rp.FindValue("test_interval")) != null)
    {
        p.test_interval = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("test_compute_loss")) != null)
    {
        p.test_compute_loss = bool.Parse(strVal);
    }

    if ((strVal = rp.FindValue("test_initialization")) != null)
    {
        p.test_initialization = bool.Parse(strVal);
    }

    // Learning rate, iteration and regularization settings.
    if ((strVal = rp.FindValue("base_lr")) != null)
    {
        p.base_lr = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("display")) != null)
    {
        p.display = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("average_loss")) != null)
    {
        p.average_loss = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("max_iter")) != null)
    {
        p.max_iter = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("iter_size")) != null)
    {
        p.iter_size = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("lr_policy")) != null)
    {
        p.lr_policy = strVal;
    }

    if ((strVal = rp.FindValue("gamma")) != null)
    {
        p.gamma = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("power")) != null)
    {
        p.power = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("momentum")) != null)
    {
        p.momentum = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("weight_decay")) != null)
    {
        p.weight_decay = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("regularization_type")) != null)
    {
        p.regularization_type = strVal;
    }

    if ((strVal = rp.FindValue("stepsize")) != null)
    {
        p.stepsize = int.Parse(strVal);
    }

    p.stepvalue = rp.FindArray<int>("stepvalue");

    if ((strVal = rp.FindValue("clip_gradients")) != null)
    {
        p.clip_gradients = double.Parse(strVal);
    }

    // Snapshot settings.
    if ((strVal = rp.FindValue("snapshot")) != null)
    {
        p.snapshot = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("snapshot_prefix")) != null)
    {
        p.snapshot_prefix = strVal;
    }

    if ((strVal = rp.FindValue("snapshot_diff")) != null)
    {
        p.snapshot_diff = bool.Parse(strVal);
    }

    if ((strVal = rp.FindValue("snapshot_format")) != null)
    {
        switch (strVal)
        {
            // NOTE(review): "HDF5" is accepted but stored as BINARYPROTO; presumably
            // HDF5 snapshots are not supported by this port - confirm this is intended.
            case "BINARYPROTO":
            case "HDF5":
                p.snapshot_format = SnapshotFormat.BINARYPROTO;
                break;

            default:
                throw new Exception("Unknown 'snapshot_format' value: " + strVal);
        }
    }

    if ((strVal = rp.FindValue("device_id")) != null)
    {
        p.device_id = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("random_seed")) != null)
    {
        p.random_seed = long.Parse(strVal);
    }

    // Solver type; matched case-insensitively and independent of the current culture.
    if ((strVal = rp.FindValue("type")) != null)
    {
        switch (strVal.ToLowerInvariant())
        {
            case "sgd":
                p.type = SolverType.SGD;
                break;

            case "nesterov":
                p.type = SolverType.NESTEROV;
                break;

            case "adagrad":
                p.type = SolverType.ADAGRAD;
                break;

            case "adadelta":
                p.type = SolverType.ADADELTA;
                break;

            case "adam":
                p.type = SolverType.ADAM;
                break;

            case "rmsprop":
                p.type = SolverType.RMSPROP;
                break;

            // "lbfgs" is the spelling of the SolverType.LBFGS member name; the
            // misspelled "lbgfs" is still accepted so existing inputs keep working.
            case "lbfgs":
            case "lbgfs":
                p.type = SolverType.LBFGS;
                break;

            default:
                throw new Exception("Unknown solver 'type' value: " + strVal);
        }
    }

    // Solver-specific settings.
    if ((strVal = rp.FindValue("delta")) != null)
    {
        p.delta = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("momentum2")) != null)
    {
        p.momentum2 = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("rms_decay")) != null)
    {
        p.rms_decay = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("debug_info")) != null)
    {
        p.debug_info = bool.Parse(strVal);
    }

    // The "lbgfs_corrections" key and property name are kept exactly as declared elsewhere.
    if ((strVal = rp.FindValue("lbgfs_corrections")) != null)
    {
        p.lbgfs_corrections = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("snapshot_after_train")) != null)
    {
        p.snapshot_after_train = bool.Parse(strVal);
    }

    if ((strVal = rp.FindValue("custom_trainer")) != null)
    {
        p.custom_trainer = strVal;
    }

    if ((strVal = rp.FindValue("custom_trainer_properties")) != null)
    {
        p.custom_trainer_properties = strVal;
    }

    if ((strVal = rp.FindValue("output_average_results")) != null)
    {
        p.output_average_results = bool.Parse(strVal);
    }

    if ((strVal = rp.FindValue("snapshot_include_weights")) != null)
    {
        p.snapshot_include_weights = bool.Parse(strVal);
    }

    if ((strVal = rp.FindValue("snapshot_include_state")) != null)
    {
        p.snapshot_include_state = bool.Parse(strVal);
    }

    return p;
}