Esempio n. 1
0
        /// <summary>
        /// Creates the name prefix and <see cref="ParameterDict"/> for a newly constructed block.
        /// </summary>
        /// <param name="prefix">
        /// Explicit prefix for the block. When no block scope is active, a null prefix is replaced by a
        /// name obtained from the thread's <c>NameManager</c> followed by "_". Inside an active scope, a
        /// null or whitespace prefix is replaced by <paramref name="hint"/>, a per-hint counter and "_".
        /// </param>
        /// <param name="params">
        /// Optional dictionary to share parameters with. When supplied, the returned dictionary keeps
        /// <c>@params.Prefix</c> and uses <paramref name="params"/> as its shared store, in both branches.
        /// </param>
        /// <param name="hint">Base name used to generate a prefix when none is supplied.</param>
        /// <returns>The block's full prefix and the parameter dictionary it should use.</returns>
        public static (string, ParameterDict) Create(string prefix, ParameterDict @params, string hint)
        {
            var current = _current.IsValueCreated ? _current.Value : null;

            if (current == null)
            {
                // No enclosing block scope: this is a top-level block.
                if (prefix == null)
                {
                    if (!NameManager.current.IsValueCreated)
                    {
                        NameManager.current.Value = new NameManager();
                    }

                    prefix = NameManager.current.Value.Get(null, hint) + "_";
                }

                @params = @params == null
                    ? new ParameterDict(prefix)
                    : new ParameterDict(@params.Prefix, @params);

                return (prefix, @params);
            }

            if (string.IsNullOrWhiteSpace(prefix))
            {
                // Generate "<hint><n>_", unique per hint within the enclosing scope. The trailing "_"
                // matches the top-level branch so generated names stay separated (e.g. "dense0_").
                current._counter.TryGetValue(hint, out var count);
                prefix = hint + count + "_";
                current._counter[hint] = count + 1;
            }

            if (@params == null)
            {
                // Nest under the parent block's parameter prefix and reuse the parent's shared store.
                var parent = current._block.Params;
                @params = new ParameterDict(parent.Prefix + prefix, parent.Shared);
            }
            else
            {
                // Explicitly supplied params are shared as-is: keep their prefix so names resolve to the
                // entries already held by the shared dictionary (same rule as the top-level branch).
                @params = new ParameterDict(@params.Prefix, @params);
            }

            return (current._block.Prefix + prefix, @params);
        }
Esempio n. 2
0
        /// <summary>
        /// Creates a CTC loss block after validating the prediction and label layouts.
        /// </summary>
        /// <param name="layout">Layout of the prediction input; must be "NTC" or "TNC".</param>
        /// <param name="label_layout">Layout of the label input; must be "NT" or "TN".</param>
        /// <param name="weight">Optional weight, forwarded to the base constructor.</param>
        /// <param name="batch_axis">
        /// Forwarded to the base constructor; <c>BatchAxis</c> is then overwritten with the position of
        /// 'N' in <paramref name="label_layout"/>.
        /// </param>
        /// <param name="prefix">Forwarded to the base constructor.</param>
        /// <param name="params">Forwarded to the base constructor.</param>
        /// <exception cref="ArgumentException">
        /// <paramref name="layout"/> or <paramref name="label_layout"/> is not one of the accepted values.
        /// </exception>
        public CTCLoss(string layout = "NTC", string label_layout = "NT", float? weight = null, int? batch_axis = 0,
                       string prefix = null, ParameterDict @params = null) : base(weight, batch_axis, prefix, @params)
        {
            if (layout != "NTC" && layout != "TNC")
            {
                throw new ArgumentException(
                    $"Only 'NTC' and 'TNC' layouts for pred are supported. Got: {layout}", nameof(layout));
            }

            if (label_layout != "NT" && label_layout != "TN")
            {
                // Report the layouts this check actually accepts for labels.
                throw new ArgumentException(
                    $"Only 'NT' and 'TN' layouts for label are supported. Got: {label_layout}", nameof(label_layout));
            }

            Layout      = layout;
            LabelLayout = label_layout;
            // Index of 'N' in the label layout: 0 for "NT", 1 for "TN".
            BatchAxis   = label_layout.IndexOf('N');
        }
Esempio n. 3
0
        /// <summary>
        /// Recursively gathers this block's registered parameters and those of every child block into a
        /// single flat dictionary.
        /// </summary>
        /// <param name="prefix">
        /// Path of this block. When it is not blank, "." is appended before each key; registered
        /// parameters are keyed by path + key, children are visited with path + child key.
        /// </param>
        /// <returns>A new <see cref="ParameterDict"/> holding all collected parameters.</returns>
        public virtual ParameterDict CollectParamsWithPrefix(string prefix = "")
        {
            var collected = new ParameterDict();
            var path = string.IsNullOrWhiteSpace(prefix) ? prefix : prefix + ".";

            foreach (var entry in _reg_params)
            {
                collected[path + entry.Key] = entry.Value;
            }

            foreach (var child in _childrens)
            {
                var childParams = child.Value.CollectParamsWithPrefix(path + child.Key);
                collected.Update(childParams);
            }

            return collected;
        }
Esempio n. 4
0
        /// <summary>
        /// Builds a flat dictionary of this block's <c>Params</c> plus, recursively, those of each child,
        /// keyed by a "."-separated path.
        /// </summary>
        /// <param name="prefix">Path of this block; a "." separator is added only when it is not blank.</param>
        /// <returns>A new <see cref="ParameterDict"/> with every collected parameter.</returns>
        private ParameterDict CollectParamsWithPrefix(string prefix = "")
        {
            var separator = string.IsNullOrWhiteSpace(prefix) ? string.Empty : ".";
            var root = prefix + separator;
            var result = new ParameterDict();

            foreach (var pair in Params.Items())
            {
                result[root + pair.Key] = pair.Value;
            }

            foreach (var child in _childrens.Values)
            {
                var nested = child.CollectParamsWithPrefix(root + child.Name);
                result.Update(nested);
            }

            return result;
        }
Esempio n. 5
0
        /// <summary>
        /// Creates a trainer for the parameters in <paramref name="params"/> using <paramref name="optimizer"/>.
        /// </summary>
        /// <param name="params">Parameters to train; each one is indexed and registered with this trainer.</param>
        /// <param name="optimizer">Optimizer to initialize; its <c>RescaleGrad</c> becomes the initial scale.</param>
        /// <param name="kvstore">Stored under the "kvstore" key of the kvstore settings.</param>
        /// <param name="compression_params">Stored as-is for later use.</param>
        /// <param name="update_on_kvstore">Stored under the "update_on_kvstore" key of the kvstore settings.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="params"/> or <paramref name="optimizer"/> is null.
        /// </exception>
        public Trainer(ParameterDict @params, Optimizer optimizer, string kvstore = "device",
                       Dictionary<string, object> compression_params = null, bool? update_on_kvstore = null)
        {
            // Fail before any parameter is bound to this trainer.
            if (@params == null)
            {
                throw new ArgumentNullException(nameof(@params));
            }

            if (optimizer == null)
            {
                throw new ArgumentNullException(nameof(optimizer));
            }

            _params = new List<Parameter>();
            var keys = @params.Keys();

            for (var i = 0; i < keys.Length; i++)
            {
                var param = @params[keys[i]];
                // Map each key to the parameter's position in _params.
                _param2idx[keys[i]] = i;
                _params.Add(param);
                param.SetTrainer(this);

                // Remember whether any weight or gradient uses a non-default storage type.
                if (param.Stype != StorageStype.Default)
                {
                    _contains_sparse_weight = true;
                }

                if (param.Grad_Stype != StorageStype.Default)
                {
                    _contains_sparse_grad = true;
                }
            }

            _compression_params = compression_params;
            _contexts           = CheckContexts();
            InitOptimizer(optimizer);
            _scale          = optimizer.RescaleGrad;
            _kvstore_params = new Dictionary<string, object>
            {
                { "kvstore", kvstore },
                { "update_on_kvstore", update_on_kvstore }
            };
            // No kvstore yet; ResetKVstore() runs once every field above has been set.
            _kvstore           = null;
            _update_on_kvstore = null;
            _params_to_init    = new List<Parameter>();
            ResetKVstore();
        }
Esempio n. 6
0
 /// <summary>
 /// Builds a block from the given output and input symbols.
 /// </summary>
 /// <param name="outputs">Symbols passed to <c>Construct</c> as the outputs.</param>
 /// <param name="inputs">Symbols passed to <c>Construct</c> as the inputs.</param>
 /// <param name="params">
 /// Optional parameters. The base block receives an empty-prefix <see cref="ParameterDict"/> whose shared
 /// dictionary is <paramref name="params"/>; <c>Construct</c> receives <paramref name="params"/> unchanged.
 /// </param>
 public SymbolBlock(SymbolList outputs, SymbolList inputs, ParameterDict @params = null)
     : base("", new ParameterDict("", @params)) => Construct(outputs, inputs, @params);
Esempio n. 7
0
 /// <summary>
 /// Creates an empty parameter dictionary.
 /// </summary>
 /// <param name="prefix">Value stored in <c>Prefix</c>.</param>
 /// <param name="shared">Optional dictionary stored in <c>Shared</c>.</param>
 public ParameterDict(string prefix = "", ParameterDict shared = null)
 {
     (Prefix, Shared) = (prefix, shared);
     _params = new Dictionary<string, Parameter>();
 }
Esempio n. 8
0
 /// <summary>
 /// Forwards every argument, in order and unchanged, to the base constructor.
 /// NOTE(review): presumably a short alias of SoftmaxCrossEntropyLoss — confirm against the class header.
 /// </summary>
 public SoftmaxCELoss(int axis = -1,
                      bool sparse_label = true,
                      bool from_logits = false,
                      float? weight = null,
                      int? batch_axis = 0,
                      string prefix = null,
                      ParameterDict @params = null)
     : base(axis, sparse_label, from_logits, weight, batch_axis, prefix, @params)
 {
 }
Esempio n. 9
0
 /// <summary>
 /// Stores the axis, label-format and logits settings; weight, batch axis, prefix and params go to the
 /// base constructor.
 /// </summary>
 /// <param name="axis">Stored in <c>_axis</c>.</param>
 /// <param name="sparse_label">Stored in <c>_sparse_label</c>.</param>
 /// <param name="from_logits">Stored in <c>_from_logits</c>.</param>
 public SoftmaxCrossEntropyLoss(int axis = -1,
                                bool sparse_label = true,
                                bool from_logits = false,
                                float? weight = null,
                                int? batch_axis = 0,
                                string prefix = null,
                                ParameterDict @params = null)
     : base(weight, batch_axis, prefix, @params)
 {
     _from_logits  = from_logits;
     _sparse_label = sparse_label;
     _axis         = axis;
 }