/// <summary>
/// Resolves the prefix and <see cref="ParameterDict"/> to use, based on the supplied
/// arguments and on the scope currently stored in <c>_current</c> (if any).
/// </summary>
/// <param name="prefix">
/// Requested prefix. Without an active scope only <c>null</c> triggers automatic naming;
/// inside a scope any null, empty or whitespace-only value does.
/// </param>
/// <param name="params">
/// Existing dictionary to wrap as the shared dictionary of the result, or <c>null</c>
/// to create a fresh one.
/// </param>
/// <param name="hint">Base name used when a prefix has to be generated.</param>
/// <returns>The resolved prefix together with the newly created <see cref="ParameterDict"/>.</returns>
public static (string, ParameterDict) Create(string prefix, ParameterDict @params, string hint)
{
    var current = _current.IsValueCreated ? _current.Value : null;

    if (current == null)
    {
        // No enclosing scope: fall back to the NameManager for a generated prefix.
        if (prefix == null)
        {
            if (!NameManager.current.IsValueCreated)
            {
                NameManager.current.Value = new NameManager();
            }

            prefix = NameManager.current.Value.Get(null, hint) + "_";
        }

        @params = @params == null
            ? new ParameterDict(prefix)
            : new ParameterDict(@params.Prefix, @params);

        return (prefix, @params);
    }

    if (string.IsNullOrWhiteSpace(prefix))
    {
        // Single lookup; 'count' stays 0 when the hint has not been seen in this scope yet.
        current._counter.TryGetValue(hint, out var count);

        // NOTE(review): unlike the scope-less branch above, no "_" separator is appended
        // to the generated prefix here - confirm this is intended.
        prefix = hint + count;
        current._counter[hint] = count + 1;
    }

    if (@params == null)
    {
        // Derive the dictionary from the enclosing block's parameters.
        var parent = current._block.Params;
        @params = new ParameterDict(parent.Prefix + prefix, parent.Shared);
    }
    else
    {
        @params = new ParameterDict(@params.Prefix + prefix, @params);
    }

    return (current._block.Prefix + prefix, @params);
}
/// <summary>
/// Creates a CTC loss, validating the prediction and label layouts.
/// </summary>
/// <param name="layout">Layout of the predictions; must be <c>"NTC"</c> or <c>"TNC"</c>.</param>
/// <param name="label_layout">Layout of the labels; must be <c>"NT"</c> or <c>"TN"</c>.</param>
/// <param name="weight">Forwarded to the base constructor.</param>
/// <param name="batch_axis">
/// Forwarded to the base constructor. <c>BatchAxis</c> is assigned afterwards from the
/// position of <c>'N'</c> in <paramref name="label_layout"/>.
/// </param>
/// <param name="prefix">Forwarded to the base constructor.</param>
/// <param name="params">Forwarded to the base constructor.</param>
/// <exception cref="ArgumentException">
/// Thrown when <paramref name="layout"/> or <paramref name="label_layout"/> is not one of the supported values.
/// </exception>
public CTCLoss(string layout = "NTC", string label_layout = "NT", float? weight = null, int? batch_axis = 0,
               string prefix = null, ParameterDict @params = null)
    : base(weight, batch_axis, prefix, @params)
{
    if (layout != "NTC" && layout != "TNC")
    {
        throw new ArgumentException($"Only 'NTC' and 'TNC' layouts for pred are supported. Got: {layout}");
    }

    if (label_layout != "NT" && label_layout != "TN")
    {
        // The message previously listed the prediction layouts ('NTC'/'TNC') by mistake.
        throw new ArgumentException($"Only 'NT' and 'TN' layouts for label are supported. Got: {label_layout}");
    }

    Layout = layout;
    LabelLayout = label_layout;

    // The batch axis is wherever 'N' sits in the (already validated) label layout.
    BatchAxis = label_layout.IndexOf('N');
}
/// <summary>
/// Gathers the entries of <c>_reg_params</c> and, recursively, those of every entry in
/// <c>_childrens</c> into a single <see cref="ParameterDict"/> with dot-separated keys.
/// </summary>
/// <param name="prefix">
/// Leading part of every key. When it is not null, empty or whitespace a <c>"."</c> is
/// appended before the entry name.
/// </param>
/// <returns>A new <see cref="ParameterDict"/> holding all collected entries.</returns>
public virtual ParameterDict CollectParamsWithPrefix(string prefix = "")
{
    // A blank prefix is used as-is; otherwise it becomes "prefix.".
    var qualifier = string.IsNullOrWhiteSpace(prefix) ? prefix : prefix + ".";
    var collected = new ParameterDict();

    foreach (var registered in _reg_params)
    {
        collected[qualifier + registered.Key] = registered.Value;
    }

    foreach (var child in _childrens)
    {
        // Each child collects its own entries under "<qualifier><child key>".
        var nested = child.Value.CollectParamsWithPrefix(qualifier + child.Key);
        collected.Update(nested);
    }

    return collected;
}
/// <summary>
/// Collects the entries returned by <c>Params.Items()</c> and, recursively, those of every
/// value in <c>_childrens</c> into one <see cref="ParameterDict"/> with dot-separated keys.
/// </summary>
/// <param name="prefix">
/// Leading part of every key. A <c>"."</c> is appended to it unless it is null, empty or whitespace.
/// </param>
/// <returns>A new <see cref="ParameterDict"/> holding all collected entries.</returns>
private ParameterDict CollectParamsWithPrefix(string prefix = "")
{
    var scope = prefix;
    if (!string.IsNullOrWhiteSpace(scope))
    {
        scope = scope + ".";
    }

    var result = new ParameterDict();

    foreach (var entry in Params.Items())
    {
        var qualifiedName = scope + entry.Key;
        result[qualifiedName] = entry.Value;
    }

    foreach (var child in _childrens.Values)
    {
        // Children are keyed by their own Name under the current scope.
        var childParams = child.CollectParamsWithPrefix(scope + child.Name);
        result.Update(childParams);
    }

    return result;
}
/// <summary>
/// Creates a trainer over the given parameters: registers every parameter with this
/// trainer, records whether any parameter uses non-default weight or gradient storage,
/// initializes the optimizer and finally calls <c>ResetKVstore()</c>.
/// </summary>
/// <param name="params">Parameters to register; they are indexed in the order returned by <c>Keys()</c>.</param>
/// <param name="optimizer">Optimizer passed to <c>InitOptimizer</c>; its <c>RescaleGrad</c> is stored as the scale.</param>
/// <param name="kvstore">Stored under the <c>"kvstore"</c> key of the kvstore settings.</param>
/// <param name="compression_params">Stored as-is in <c>_compression_params</c>.</param>
/// <param name="update_on_kvstore">Stored under the <c>"update_on_kvstore"</c> key of the kvstore settings.</param>
public Trainer(ParameterDict @params, Optimizer optimizer, string kvstore = "device",
               Dictionary<string, object> compression_params = null, bool? update_on_kvstore = null)
{
    var keys = @params.Keys();
    _params = new List<Parameter>(keys.Length);

    for (var i = 0; i < keys.Length; i++)
    {
        var key = keys[i];
        var param = @params[key];

        // Remember the position of each parameter so it can be looked up by name.
        _param2idx[key] = i;
        _params.Add(param);
        param.SetTrainer(this);

        if (param.Stype != StorageStype.Default)
        {
            _contains_sparse_weight = true;
        }

        if (param.Grad_Stype != StorageStype.Default)
        {
            _contains_sparse_grad = true;
        }
    }

    _compression_params = compression_params;
    _contexts = CheckContexts();
    InitOptimizer(optimizer);
    _scale = optimizer.RescaleGrad;

    // The requested kvstore settings are only recorded here; the kvstore fields themselves start out unset.
    _kvstore_params = new Dictionary<string, object>
    {
        { "kvstore", kvstore },
        { "update_on_kvstore", update_on_kvstore },
    };
    _kvstore = null;
    _update_on_kvstore = null;
    _params_to_init = new List<Parameter>();

    ResetKVstore();
}
/// <summary>
/// Creates a symbol block. The base constructor receives an empty prefix and a new
/// <see cref="ParameterDict"/> (empty prefix) that shares <paramref name="params"/>;
/// the remaining work is delegated to <c>Construct</c>.
/// </summary>
/// <param name="outputs">Passed to <c>Construct</c>.</param>
/// <param name="inputs">Passed to <c>Construct</c>.</param>
/// <param name="params">
/// Optional dictionary used as the shared dictionary of the block's parameters and passed to <c>Construct</c>.
/// </param>
public SymbolBlock(SymbolList outputs, SymbolList inputs, ParameterDict @params = null)
    : base("", new ParameterDict(prefix: "", shared: @params))
{
    Construct(outputs, inputs, @params);
}
/// <summary>
/// Creates an empty parameter dictionary.
/// </summary>
/// <param name="prefix">Value stored in <c>Prefix</c>.</param>
/// <param name="shared">Value stored in <c>Shared</c>; may be <c>null</c>.</param>
public ParameterDict(string prefix = "", ParameterDict shared = null)
{
    this.Prefix = prefix;
    this.Shared = shared;

    // Backing store starts out empty.
    this._params = new Dictionary<string, Parameter>();
}
/// <summary>
/// Creates the loss by forwarding every argument, unchanged and in declaration order,
/// to the base constructor. This constructor adds no logic of its own.
/// </summary>
/// <param name="axis">Forwarded to the base constructor.</param>
/// <param name="sparse_label">Forwarded to the base constructor.</param>
/// <param name="from_logits">Forwarded to the base constructor.</param>
/// <param name="weight">Forwarded to the base constructor.</param>
/// <param name="batch_axis">Forwarded to the base constructor.</param>
/// <param name="prefix">Forwarded to the base constructor.</param>
/// <param name="params">Forwarded to the base constructor.</param>
public SoftmaxCELoss(
    int axis = -1,
    bool sparse_label = true,
    bool from_logits = false,
    float? weight = null,
    int? batch_axis = 0,
    string prefix = null,
    ParameterDict @params = null)
    : base(axis, sparse_label, from_logits, weight, batch_axis, prefix, @params)
{
}
/// <summary>
/// Creates the loss: <paramref name="weight"/>, <paramref name="batch_axis"/>,
/// <paramref name="prefix"/> and <paramref name="params"/> go to the base constructor,
/// while the remaining options are stored in private fields.
/// </summary>
/// <param name="axis">Stored in <c>_axis</c>.</param>
/// <param name="sparse_label">Stored in <c>_sparse_label</c>.</param>
/// <param name="from_logits">Stored in <c>_from_logits</c>.</param>
/// <param name="weight">Forwarded to the base constructor.</param>
/// <param name="batch_axis">Forwarded to the base constructor.</param>
/// <param name="prefix">Forwarded to the base constructor.</param>
/// <param name="params">Forwarded to the base constructor.</param>
public SoftmaxCrossEntropyLoss(
    int axis = -1,
    bool sparse_label = true,
    bool from_logits = false,
    float? weight = null,
    int? batch_axis = 0,
    string prefix = null,
    ParameterDict @params = null)
    : base(weight, batch_axis, prefix, @params)
{
    this._from_logits = from_logits;
    this._sparse_label = sparse_label;
    this._axis = axis;
}