/// <summary>
/// Builds a copy of this parameter, optionally including its layers and optionally overriding the solver count and rank.
/// </summary>
/// <param name="bCloneLayers">When <i>true</i>, the layer collection is cloned into the copy; when <i>false</i>, the layers are not copied.</param>
/// <param name="nSolverCount">Optionally, specifies the solver count for the copy; when <i>null</i>, the solver count of this instance is used.</param>
/// <param name="nSolverRank">Optionally, specifies the solver rank for the copy; when <i>null</i>, the solver rank of this instance is used.</param>
/// <returns>The newly created copy is returned.</returns>
public NetParameter Clone(bool bCloneLayers = true, int? nSolverCount = null, int? nSolverRank = null)
{
    NetParameter copy = new NetParameter();

    copy.m_nProjectID = m_nProjectID;
    copy.m_strName = m_strName;
    copy.m_rgstrInput = Utility.Clone<string>(m_rgstrInput);
    copy.m_rgInputShape = Utility.Clone<BlobShape>(m_rgInputShape);
    copy.m_rgInputDim = Utility.Clone<int>(m_rgInputDim);
    copy.m_bForceBackward = m_bForceBackward;

    // The state is explicitly reset to null when this instance has no state.
    if (m_state != null)
    {
        copy.m_state = m_state.Clone();
    }
    else
    {
        copy.m_state = null;
    }

    copy.m_bDebugInfo = m_bDebugInfo;

    if (bCloneLayers)
    {
        copy.m_rgLayers = Utility.Clone<LayerParameter>(m_rgLayers);
    }

    // Caller supplied values win; otherwise fall back to the values of this instance.
    copy.m_nSolverCount = nSolverCount ?? m_nSolverCount;
    copy.m_nSolverRank = nSolverRank ?? m_nSolverRank;

    return copy;
}
/// <summary>
/// Reads a parameter from a binary reader and returns it as a new instance.
/// </summary>
/// <remarks>
/// The values are read in this fixed order: name, input names, input shapes, input dimensions,
/// force-backward flag, state, debug-info flag and finally the layers.
/// </remarks>
/// <param name="br">Specifies the binary reader to read from.</param>
/// <returns>The newly loaded instance is returned.</returns>
public static NetParameter Load(BinaryReader br)
{
    // Member initializers are evaluated top-to-bottom, so the reads below
    // consume the stream in exactly the order they are listed.
    NetParameter loaded = new NetParameter()
    {
        m_strName = br.ReadString(),
        m_rgstrInput = Utility.Load<string>(br),
        m_rgInputShape = Utility.Load<BlobShape>(br),
        m_rgInputDim = Utility.Load<int>(br),
        m_bForceBackward = br.ReadBoolean(),
        m_state = NetState.Load(br),
        m_bDebugInfo = br.ReadBoolean(),
        m_rgLayers = Utility.Load<LayerParameter>(br)
    };

    return loaded;
}
/// <summary>
/// Builds a new NetParameter from the values found in a RawProto.
/// </summary>
/// <remarks>
/// The scalar entries 'name', 'force_backward', 'state' and 'debug_info' are only applied when present;
/// 'input' and 'input_dim' are always assigned from the arrays found, and each 'input_shape' and
/// 'layer' (or 'layers') child is parsed and appended to the corresponding collection.
/// </remarks>
/// <param name="rp">Specifies the RawProto to parse.</param>
/// <returns>The newly created parameter is returned.</returns>
public static NetParameter FromProto(RawProto rp)
{
    NetParameter param = new NetParameter();

    string strName = rp.FindValue("name");
    if (strName != null)
    {
        param.name = strName;
    }

    param.input = rp.FindArray<string>("input");

    foreach (RawProto rpShape in rp.FindChildren("input_shape"))
    {
        param.input_shape.Add(BlobShape.FromProto(rpShape));
    }

    param.input_dim = rp.FindArray<int>("input_dim");

    string strForceBackward = rp.FindValue("force_backward");
    if (strForceBackward != null)
    {
        param.force_backward = bool.Parse(strForceBackward);
    }

    RawProto rpNetState = rp.FindChild("state");
    if (rpNetState != null)
    {
        param.state = NetState.FromProto(rpNetState);
    }

    string strDebugInfo = rp.FindValue("debug_info");
    if (strDebugInfo != null)
    {
        param.debug_info = bool.Parse(strDebugInfo);
    }

    // Both the 'layer' and the 'layers' child names are searched.
    foreach (RawProto rpLayer in rp.FindChildren("layer", "layers"))
    {
        param.layer.Add(LayerParameter.FromProto(rpLayer));
    }

    return param;
}
/// <summary>
/// Parses a new SolverParameter from a RawProto.
/// </summary>
/// <remarks>
/// Scalar entries are only applied when they are present in the RawProto; the array entries
/// 'test_iter' and 'stepvalue' are always assigned from the arrays found, and each 'test_net_param'
/// and 'test_state' child is parsed and appended to the corresponding collection.
/// The solver 'type' is matched case-insensitively, and both the correct spelling 'lbfgs' and the
/// historical misspelling 'lbgfs' select <see cref="SolverType.LBFGS"/>.
/// </remarks>
/// <param name="rp">Specifies the RawProto representing the SolverParameter.</param>
/// <returns>The new SolverParameter instance is returned.</returns>
/// <exception cref="Exception">Thrown when the 'snapshot_format' or 'type' value is not recognized.</exception>
public static SolverParameter FromProto(RawProto rp)
{
    string strVal;
    SolverParameter p = new SolverParameter();

    // --- Nested net parameters and net states ---

    RawProto rpNetParam = rp.FindChild("net_param");
    if (rpNetParam != null)
    {
        p.net_param = NetParameter.FromProto(rpNetParam);
    }

    RawProto rpTrainNetParam = rp.FindChild("train_net_param");
    if (rpTrainNetParam != null)
    {
        p.train_net_param = NetParameter.FromProto(rpTrainNetParam);
    }

    RawProtoCollection rgpTn = rp.FindChildren("test_net_param");
    foreach (RawProto rpTest in rgpTn)
    {
        p.test_net_param.Add(NetParameter.FromProto(rpTest));
    }

    RawProto rpTrainState = rp.FindChild("train_state");
    if (rpTrainState != null)
    {
        p.train_state = NetState.FromProto(rpTrainState);
    }

    RawProtoCollection rgpNs = rp.FindChildren("test_state");
    foreach (RawProto rpNs in rgpNs)
    {
        p.test_state.Add(NetState.FromProto(rpNs));
    }

    // --- Testing settings ---

    p.test_iter = rp.FindArray<int>("test_iter");

    if ((strVal = rp.FindValue("test_interval")) != null)
    {
        p.test_interval = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("test_compute_loss")) != null)
    {
        p.test_compute_loss = bool.Parse(strVal);
    }

    if ((strVal = rp.FindValue("test_initialization")) != null)
    {
        p.test_initialization = bool.Parse(strVal);
    }

    // --- Learning rate, iteration and regularization settings ---

    if ((strVal = rp.FindValue("base_lr")) != null)
    {
        p.base_lr = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("display")) != null)
    {
        p.display = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("average_loss")) != null)
    {
        p.average_loss = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("max_iter")) != null)
    {
        p.max_iter = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("iter_size")) != null)
    {
        p.iter_size = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("lr_policy")) != null)
    {
        p.lr_policy = strVal;
    }

    if ((strVal = rp.FindValue("gamma")) != null)
    {
        p.gamma = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("power")) != null)
    {
        p.power = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("momentum")) != null)
    {
        p.momentum = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("weight_decay")) != null)
    {
        p.weight_decay = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("regularization_type")) != null)
    {
        p.regularization_type = strVal;
    }

    if ((strVal = rp.FindValue("stepsize")) != null)
    {
        p.stepsize = int.Parse(strVal);
    }

    p.stepvalue = rp.FindArray<int>("stepvalue");

    if ((strVal = rp.FindValue("clip_gradients")) != null)
    {
        p.clip_gradients = double.Parse(strVal);
    }

    // --- Snapshot settings ---

    if ((strVal = rp.FindValue("snapshot")) != null)
    {
        p.snapshot = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("snapshot_prefix")) != null)
    {
        p.snapshot_prefix = strVal;
    }

    if ((strVal = rp.FindValue("snapshot_diff")) != null)
    {
        p.snapshot_diff = bool.Parse(strVal);
    }

    if ((strVal = rp.FindValue("snapshot_format")) != null)
    {
        switch (strVal)
        {
            case "BINARYPROTO":
                p.snapshot_format = SnapshotFormat.BINARYPROTO;
                break;

            // NOTE(review): 'HDF5' is mapped to BINARYPROTO, exactly as in the original code.
            // Presumably HDF5 snapshots are not supported - confirm before changing this mapping.
            case "HDF5":
                p.snapshot_format = SnapshotFormat.BINARYPROTO;
                break;

            default:
                throw new Exception("Unknown 'snapshot_format' value: " + strVal);
        }
    }

    // --- Device, seed and solver type ---

    if ((strVal = rp.FindValue("device_id")) != null)
    {
        p.device_id = int.Parse(strVal);
    }

    if ((strVal = rp.FindValue("random_seed")) != null)
    {
        p.random_seed = long.Parse(strVal);
    }

    if ((strVal = rp.FindValue("type")) != null)
    {
        // Lower-case with the invariant culture so that matching the fixed ASCII
        // names below does not depend on the current thread culture.
        string strType = strVal.ToLowerInvariant();

        switch (strType)
        {
            case "sgd":
                p.type = SolverType.SGD;
                break;

            case "nesterov":
                p.type = SolverType.NESTEROV;
                break;

            case "adagrad":
                p.type = SolverType.ADAGRAD;
                break;

            case "adadelta":
                p.type = SolverType.ADADELTA;
                break;

            case "adam":
                p.type = SolverType.ADAM;
                break;

            case "rmsprop":
                p.type = SolverType.RMSPROP;
                break;

            // 'lbfgs' is the spelling that matches SolverType.LBFGS; previously only the
            // misspelled 'lbgfs' was accepted, so a correctly spelled type was rejected.
            // The misspelling is still accepted so that existing descriptions keep loading.
            case "lbfgs":
            case "lbgfs":
                p.type = SolverType.LBFGS;
                break;

            default:
                throw new Exception("Unknown solver 'type' value: " + strVal);
        }
    }

    // --- Solver specific settings ---

    if ((strVal = rp.FindValue("delta")) != null)
    {
        p.delta = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("momentum2")) != null)
    {
        p.momentum2 = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("rms_decay")) != null)
    {
        p.rms_decay = double.Parse(strVal);
    }

    if ((strVal = rp.FindValue("debug_info")) != null)
    {
        p.debug_info = bool.Parse(strVal);
    }

    // The key and property keep the 'lbgfs' spelling; renaming them would change the interface.
    if ((strVal = rp.FindValue("lbgfs_corrections")) != null)
    {
        p.lbgfs_corrections = int.Parse(strVal);
    }

    // --- Training, custom trainer and output settings ---

    if ((strVal = rp.FindValue("snapshot_after_train")) != null)
    {
        p.snapshot_after_train = bool.Parse(strVal);
    }

    if ((strVal = rp.FindValue("custom_trainer")) != null)
    {
        p.custom_trainer = strVal;
    }

    if ((strVal = rp.FindValue("custom_trainer_properties")) != null)
    {
        p.custom_trainer_properties = strVal;
    }

    if ((strVal = rp.FindValue("output_average_results")) != null)
    {
        p.output_average_results = bool.Parse(strVal);
    }

    if ((strVal = rp.FindValue("snapshot_include_weights")) != null)
    {
        p.snapshot_include_weights = bool.Parse(strVal);
    }

    if ((strVal = rp.FindValue("snapshot_include_state")) != null)
    {
        p.snapshot_include_state = bool.Parse(strVal);
    }

    return p;
}