예제 #1
0
        /// <summary>
        /// Concatenates the per-gate "i2h"/"h2h" weight and bias entries of <paramref name="args"/>
        /// into one packed weight entry and one packed bias entry per group.
        /// </summary>
        /// <param name="args">
        /// Dictionary read with keys "{prefix}{group}{gate}_weight" and "{prefix}{group}{gate}_bias".
        /// </param>
        /// <returns>
        /// The same dictionary instance. When <c>GateNames</c> is null it is returned untouched;
        /// otherwise "{prefix}{group}_weight" and "{prefix}{group}_bias" have been set.
        /// The per-gate entries are not removed.
        /// </returns>
        public virtual NDArrayDict PackWeights(NDArrayDict args)
        {
            if (this.GateNames != null)
            {
                foreach (var group in new List<string> { "i2h", "h2h" })
                {
                    var gateWeights = new List<NDArray>();
                    var gateBiases  = new List<NDArray>();

                    // Collect in GateNames order; that order defines the layout of the packed arrays.
                    foreach (var gate in this.GateNames)
                    {
                        gateWeights.Add(args[$"{_prefix}{group}{gate}_weight"]);
                        gateBiases.Add(args[$"{_prefix}{group}{gate}_bias"]);
                    }

                    args[$"{_prefix}{group}_weight"] = nd.Concat(gateWeights);
                    args[$"{_prefix}{group}_bias"]   = nd.Concat(gateBiases);
                }
            }

            return args;
        }
예제 #2
0
        /// <summary>
        /// Loads a parameter file and passes its contents to <c>SetParams</c>.
        /// </summary>
        /// <remarks>
        /// Every key in the file must have the form "arg:NAME" or "aux:NAME". Entries are sorted
        /// into argument and auxiliary dictionaries according to that prefix.
        /// </remarks>
        /// <param name="fname">Path handed to <c>NDArray.Load</c>.</param>
        /// <exception cref="Exception">
        /// A key has no ':' separator or its prefix is neither "arg" nor "aux".
        /// </exception>
        public virtual void LoadParams(string fname)
        {
            var save_dict  = NDArray.Load(fname);
            var arg_params = new NDArrayDict();
            var aux_params = new NDArrayDict();

            foreach (var item in save_dict)
            {
                // Split once, on the first ':' only, so a parameter name that itself contains ':'
                // is kept whole instead of being truncated.
                var parts = item.Key.Split(new[] { ':' }, 2);
                if (parts.Length != 2)
                {
                    // A key without a separator is a malformed file, not an indexing error.
                    throw new Exception("Invalid param file: " + fname);
                }

                var arg_type = parts[0];
                var name     = parts[1];
                if (arg_type == "arg")
                {
                    arg_params[name] = item.Value;
                }
                else if (arg_type == "aux")
                {
                    aux_params[name] = item.Value;
                }
                else
                {
                    throw new Exception("Invalid param file: " + fname);
                }
            }

            SetParams(arg_params, aux_params);
        }
예제 #3
0
        /// <summary>
        /// Applies one update step that combines the rescaled (and optionally clipped) gradient,
        /// weight decay and a <c>Lamda</c>-scaled term based on the difference between the current
        /// weight and <c>state["prev_weight"]</c>.
        /// </summary>
        /// <param name="index">Parameter index used for the update count, learning rate and weight decay lookups.</param>
        /// <param name="weight">Current weight.</param>
        /// <param name="grad">Gradient of the weight.</param>
        /// <param name="state">
        /// State dictionary with "momentum" (may be null) and "prev_weight" entries.
        /// </param>
        public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
        {
            UpdateCount(index);
            var lr = GetLr(index);
            var wd = GetWd(index);

            grad = grad * RescaleGrad;
            if (ClipGradient.HasValue)
            {
                grad = nd.Clip(grad, -ClipGradient.Value, ClipGradient.Value);
            }

            // The step is the same in both branches and does not depend on the momentum buffer.
            var step = -lr * (grad + wd * weight + Lamda * grad * grad * (weight - state["prev_weight"]));

            NDArray mom;
            if (state["momentum"] != null)
            {
                state["momentum"] *= Momentum;
                state["momentum"] += step;
                mom = state["momentum"];
            }
            else
            {
                // No momentum buffer exists: use the plain step. The previous code evaluated
                // `state["momentum"] += ...` here, i.e. an addition with a null left operand.
                mom = step;
            }

            // NOTE(review): this stores a reference to `weight`, not a snapshot - confirm that is
            // intended given that the state is created with weight.Copy().
            state["prev_weight"] = weight;
            weight += mom;
        }
예제 #4
0
        /// <summary>
        /// Turns a list of arrays into a name-to-array dictionary.
        /// </summary>
        /// <param name="data">Arrays to name; null yields an empty dictionary.</param>
        /// <param name="allow_empty">When false, an empty (non-null) list is rejected.</param>
        /// <param name="default_name">
        /// Name used for a single array; with any other count the entries are named "_{i}_{default_name}".
        /// </param>
        /// <returns>A new dictionary holding the named arrays.</returns>
        /// <exception cref="Exception">The list is empty and <paramref name="allow_empty"/> is false.</exception>
        public static NDArrayDict InitData(NDArrayList data, bool allow_empty, string default_name)
        {
            var named = new NDArrayDict();

            // A null list is accepted regardless of allow_empty.
            if (data == null)
            {
                return named;
            }

            if (data.Length == 0 && !allow_empty)
            {
                throw new Exception("Data cannot be empty when allow_empty is false");
            }

            if (data.Length == 1)
            {
                named.Add(default_name, data[0]);
                return named;
            }

            for (var idx = 0; idx < data.Length; idx++)
            {
                named.Add($"_{idx}_{default_name}", data[idx]);
            }

            return named;
        }
예제 #5
0
        /// <summary>
        /// Runs the model forward over <paramref name="x"/> in batches and returns the collected outputs.
        /// </summary>
        /// <param name="x">Input samples; the first dimension is used as the sample count.</param>
        /// <param name="batchSize">Batch size for the forward passes; defaults to <c>x.Shape[0]</c>.</param>
        /// <returns>
        /// All executor outputs concatenated as floats and reshaped to (<c>x.Shape[0]</c>, -1).
        /// </returns>
        public NDArray Predict(NDArray x, uint? batchSize = null)
        {
            // The unused `result` NDArray and `inputShape` list from the earlier version are gone:
            // they were allocated on every call and never read.
            List<float> preds    = new List<float>();
            NDArrayIter dataIter = new NDArrayIter(new NDArray[] { x }, null);

            if (!batchSize.HasValue)
            {
                batchSize = x.Shape[0];
            }

            NDArrayDict predictArgs = new NDArrayDict();

            Model.InferArgsMap(mx.Device, predictArgs, args);
            predictArgs["X"]     = new NDArray(x.Shape);
            predictArgs["label"] = new NDArray(new Shape(batchSize.Value));
            using (var exec = Model.SimpleBind(mx.Device, predictArgs))
            {
                dataIter.BatchSize = batchSize.Value;
                dataIter.Reset();
                while (dataIter.IterNext())
                {
                    var batch = dataIter.Next();

                    // Feed the batch through the bound input array, then run inference (is_train = false).
                    batch.Data[0].CopyTo(predictArgs["X"]);
                    exec.Forward(false);
                    preds.AddRange(exec.Output.GetValues<float>());
                }
            }

            return new NDArray(preds.ToArray()).Reshape((int)x.Shape[0], -1);
        }
예제 #6
0
        /// <summary>
        /// Applies one update step that scales the gradient by the square root of the accumulated
        /// squared gradients kept in <c>state["history"]</c>.
        /// </summary>
        /// <param name="index">Parameter index used for the update count, learning rate and weight decay lookups.</param>
        /// <param name="weight">Current weight.</param>
        /// <param name="grad">Gradient; row-sparse gradients are delegated to <c>nd.SparseAdagradUpdate</c>.</param>
        /// <param name="state">State dictionary holding the "history" accumulator.</param>
        public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
        {
            UpdateCount(index);
            var lr        = GetLr(index);
            var wd        = GetWd(index);
            var is_sparse = grad.SType == StorageStype.RowSparse;
            var history   = state["history"];

            if (is_sparse)
            {
                // -1 is passed as the clip argument when no clipping threshold is configured.
                nd.SparseAdagradUpdate(weight, grad, history, lr, Epsilon, wd, RescaleGrad,
                                       ClipGradient.HasValue ? ClipGradient.Value : -1);
            }
            else
            {
                grad = grad * RescaleGrad;
                if (ClipGradient.HasValue)
                {
                    grad = nd.Clip(grad, -ClipGradient.Value, ClipGradient.Value);
                }

                history += nd.Square(grad);

                // `history += ...` only rebinds the local variable; store the accumulator back so
                // the next call sees the accumulated squared gradients.
                state["history"] = history;

                var div = grad / nd.Sqrt(history + Epsilon);

                // NOTE(review): `weight += ...` likewise rebinds the local parameter - confirm the
                // caller observes the updated weight.
                weight += (div + weight * wd) * -lr;
            }
        }
예제 #7
0
파일: NAG.cs 프로젝트: AvenSun/MxNet.Sharp
        /// <summary>
        /// Shared update routine: picks one of four update operators depending on whether a
        /// momentum buffer exists in <paramref name="state"/> and whether multi-precision is requested.
        /// </summary>
        /// <param name="index">Parameter index used for the update count, learning rate and weight decay lookups.</param>
        /// <param name="weight">Current weight.</param>
        /// <param name="grad">Gradient of the weight.</param>
        /// <param name="state">State dictionary; reads "momentum" and, in multi-precision mode, "weight32".</param>
        /// <param name="multi_precision">Selects the multi-precision operators when true.</param>
        private void _update_impl(int index, NDArray weight, NDArray grad, NDArrayDict state,
                                  bool multi_precision = false)
        {
            UpdateCount(index);
            var lr = GetLr(index);
            var wd = GetWd(index);

            var hasMomentum = state["momentum"] != null;

            // -1 is passed to the operator when no clipping threshold is configured.
            var clip = ClipGradient.HasValue ? ClipGradient.Value : -1;

            if (multi_precision)
            {
                if (hasMomentum)
                {
                    weight = nd.MPNAGMomUpdate(weight, grad, state["momentum"], state["weight32"], lr, Momentum, wd,
                                               RescaleGrad, clip);
                }
                else
                {
                    weight = nd.MpSgdUpdate(weight, grad, state["weight32"], lr, wd, RescaleGrad, clip);
                }

                return;
            }

            if (hasMomentum)
            {
                weight = nd.NAGMomUpdate(weight, grad, state["momentum"], lr, Momentum, wd, RescaleGrad, clip);
            }
            else
            {
                weight = nd.SgdUpdate(weight, grad, lr, wd, RescaleGrad, clip);
            }
        }
예제 #8
0
        /// <summary>
        /// Splits the packed "i2h"/"h2h" weight and bias entries of <paramref name="args"/> into
        /// one entry per gate.
        /// </summary>
        /// <param name="args">
        /// Dictionary containing "{prefix}{group}_weight" and "{prefix}{group}_bias" entries.
        /// </param>
        /// <returns>
        /// The same dictionary instance. When <c>GateNames</c> is null it is returned untouched;
        /// otherwise "{prefix}{group}{gate}_weight" / "_bias" hold copies of the per-gate slices.
        /// </returns>
        public virtual NDArrayDict UnpackWeights(NDArrayDict args)
        {
            if (GateNames == null)
            {
                return args;
            }

            var h = _num_hidden;

            foreach (var group_name in new string[] { "i2h", "h2h" })
            {
                var weight = args[$"{_prefix}{group_name}_weight"];
                var bias   = args[$"{_prefix}{group_name}_bias"];
                for (int j = 0; j < GateNames.Length; j++)
                {
                    var gate = GateNames[j];

                    // Gate j occupies rows [j*h, (j+1)*h) of the packed array.
                    var slice = $"{j * h}:{(j + 1) * h}";

                    args[$"{_prefix}{group_name}{gate}_weight"] = weight[slice].Copy();

                    // The bias must come from the packed bias; it was previously sliced from
                    // `weight` by mistake.
                    args[$"{_prefix}{group_name}{gate}_bias"] = bias[slice].Copy();
                }
            }

            return args;
        }
예제 #9
0
        /// <summary>
        /// Creates the optimizer state for one weight: an optional float32 master copy and an
        /// optional momentum buffer.
        /// </summary>
        /// <param name="index">Parameter index (not used here).</param>
        /// <param name="weight">Weight the state is created for.</param>
        /// <returns>
        /// Dictionary with "weight_master_copy" (set only for float16 weights when
        /// <c>MultiPrecision</c> is on) and "momentum" (set only when <c>Momentum</c> is non-zero);
        /// entries that do not apply are null.
        /// </returns>
        public override NDArrayDict CreateState(int index, NDArray weight)
        {
            var state = new NDArrayDict();

            state["weight_master_copy"] = null;
            state["momentum"]           = null;

            var momentumDType = weight.DataType;
            if (weight.DataType.Name == DType.Float16.Name)
            {
                if (MultiPrecision)
                {
                    state["weight_master_copy"] = weight.AsType(DType.Float32);

                    // In multi-precision mode the momentum accompanies the float32 master copy
                    // (as in the MXNet reference optimizer), so it must not be allocated as float16.
                    momentumDType = DType.Float32;
                }
                else
                {
                    Logger.Warning("Accumulating with float16 in optimizer can lead to " +
                                   "poor accuracy or slow convergence. " +
                                   "Consider using multi_precision=True option of the " +
                                   "SGD optimizer");
                }
            }

            if (Momentum != 0)
            {
                state["momentum"] = nd.Zeros(weight.Shape, weight.Context, momentumDType).ToSType(weight.SType);
            }

            return state;
        }
예제 #10
0
        /// <summary>
        /// Applies one update step using first/second moment estimates ("mean"/"variance") together
        /// with a momentum schedule that is accumulated in <c>MSchedule</c>.
        /// </summary>
        /// <param name="index">Parameter index used for the update count, learning rate and weight decay lookups.</param>
        /// <param name="weight">Current weight.</param>
        /// <param name="grad">Gradient of the weight.</param>
        /// <param name="state">State dictionary holding the "mean" and "variance" estimates.</param>
        public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
        {
            UpdateCount(index);
            var lr = GetLr(index);
            var wd = GetWd(index);

            var t = index_update_count[index];

            // Weight decay is folded into the gradient before clipping.
            grad = grad * RescaleGrad + wd * weight;
            if (ClipGradient.HasValue)
            {
                grad = nd.Clip(grad, -ClipGradient.Value, ClipGradient.Value);
            }

            // Momentum coefficients for the current and the next step.
            var momentum_t   = Beta1 * (1 - 0.5f * (float)Math.Pow(0.96, t * ScheduleDecay));
            var momentum_t_1 = Beta1 * (1 - 0.5f * (float)Math.Pow(0.96, (t + 1) * ScheduleDecay));

            MSchedule = MSchedule * momentum_t;
            var m_schedule_next = MSchedule * momentum_t_1;
            var m_t             = state["mean"];
            var v_t             = state["variance"];

            m_t *= Beta1;
            m_t += (1 - Beta1) * grad;
            v_t *= Beta2;
            v_t += (1 - Beta2) * grad * grad;

            // The compound assignments above only rebind the locals; persist the new estimates so
            // they carry over to the next step.
            state["mean"]     = m_t;
            state["variance"] = v_t;

            var grad_prime = grad / (1 - MSchedule);
            var m_t_prime  = m_t / (1 - m_schedule_next);
            var v_t_prime  = v_t / (1 - (float)Math.Pow(Beta2, t));
            var m_t_bar    = (1 - momentum_t) * grad_prime + momentum_t_1 * m_t_prime;

            // NOTE(review): `weight -= ...` rebinds the local parameter - confirm the caller
            // observes the updated weight.
            weight -= lr * m_t_bar / (nd.Sqrt(v_t_prime) + Epsilon);
        }
예제 #11
0
 /// <summary>
 /// Epoch-end callback: saves an RNN checkpoint every <c>_period</c> epochs.
 /// </summary>
 /// <param name="epoch">Zero-based epoch number that just finished.</param>
 /// <param name="symbol">Symbol passed through to the checkpoint.</param>
 /// <param name="arg_params">Argument parameters passed through to the checkpoint.</param>
 /// <param name="aux_params">Auxiliary parameters passed through to the checkpoint.</param>
 public void Invoke(int epoch, Symbol symbol, NDArrayDict arg_params, NDArrayDict aux_params)
 {
     // Treat a period below 1 as 1 so that a zero period cannot cause a DivideByZeroException.
     var period = System.Math.Max(1, _period);
     if ((epoch + 1) % period == 0)
     {
         // NOTE(review): the checkpoint is tagged with `epoch`, not `epoch + 1` - confirm this
         // is the intended numbering.
         RNN.SaveRNNCheckPoint(_cells, _prefix, epoch, symbol, arg_params, aux_params);
     }
 }
예제 #12
0
        /// <summary>
        /// Creates the optimizer state for one weight: a zero-filled "history" array.
        /// </summary>
        /// <param name="index">Parameter index (not used here).</param>
        /// <param name="weight">Weight whose shape, context, data type and storage type are mirrored.</param>
        /// <returns>Dictionary with the single "history" entry.</returns>
        public override NDArrayDict CreateState(int index, NDArray weight)
        {
            var result = new NDArrayDict("history");

            var zeros = nd.Zeros(weight.Shape, weight.Context, weight.DataType);
            result["history"] = zeros.ToSType(weight.SType);

            return result;
        }
예제 #13
0
 /// <summary>
 /// Epoch-end callback: saves a model checkpoint every <c>_period</c> epochs.
 /// </summary>
 /// <param name="epoch">Zero-based epoch number that just finished; the checkpoint is tagged with <c>epoch + 1</c>.</param>
 /// <param name="symbol">Symbol passed through to the checkpoint.</param>
 /// <param name="arg_params">Argument parameters passed through to the checkpoint.</param>
 /// <param name="aux_params">Auxiliary parameters passed through to the checkpoint.</param>
 public void Invoke(int epoch, Symbol symbol, NDArrayDict arg_params, NDArrayDict aux_params)
 {
     // The stored period is clamped to at least 1 (this updates the field itself).
     _period = Math.Max(1, _period);

     var finishedEpochs = epoch + 1;
     if (finishedEpochs % _period != 0)
     {
         return;
     }

     Model.SaveCheckpoint(_prefix, finishedEpochs, symbol, arg_params, aux_params);
 }
예제 #14
0
        /// <summary>
        /// Applies one update step with gradient accumulation: the accumulated gradient is applied
        /// only when the accumulation count is a multiple of <c>BatchScale</c>; otherwise an update
        /// with a zero learning rate is issued.
        /// </summary>
        /// <param name="index">Parameter index used for the learning rate, weight decay and accumulation lookups.</param>
        /// <param name="weight">Current weight.</param>
        /// <param name="grad">Gradient of the weight.</param>
        /// <param name="state">State dictionary; reads "weight_master_copy" and "momentum".</param>
        public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
        {
            var lr = GetLr(index);
            var wd = GetWd(index);

            UpdateCount(index);

            var accumulated = _cumulate_gradient(grad, index);

            // Not yet at an accumulation boundary: issue the update with lr = 0 and stop.
            if (accumulated.Nums % BatchScale != 0)
            {
                lr     = 0;
                weight = nd.SgdUpdate(weight, grad, lr, wd, RescaleGrad);
                return;
            }

            grad = accumulated.Grad / BatchScale;

            // Learning-rate multiplier: "lars" strategy or the count-based one.
            float lbmult;
            if (WarmupStrategy == "lars")
            {
                lbmult = _get_lars(weight, grad, wd);
            }
            else
            {
                lbmult = _get_lbmult(accumulated.Nums);
            }

            lr = lr * lbmult;

            // A master copy in the state selects the multi-precision operators.
            var use_multi_precision = state["weight_master_copy"] != null;
            var has_momentum        = state["momentum"] != null;

            // -1 is passed to the operator when no clipping threshold is configured.
            var clip = ClipGradient.HasValue ? ClipGradient.Value : -1;

            if (use_multi_precision)
            {
                if (has_momentum)
                {
                    weight = nd.MpSgdMomUpdate(weight, grad, state["momentum"], state["weight_master_copy"], lr,
                                               Momentum, wd, RescaleGrad, clip);
                }
                else
                {
                    weight = nd.MpSgdUpdate(weight, grad, state["weight_master_copy"], lr, wd, RescaleGrad, clip);
                }
            }
            else
            {
                if (has_momentum)
                {
                    weight = nd.SgdMomUpdate(weight, grad, state["momentum"], lr, Momentum, wd, RescaleGrad, clip);
                }
                else
                {
                    weight = nd.SgdUpdate(weight, grad, lr, wd, RescaleGrad, clip);
                }
            }
        }
예제 #15
0
        /// <summary>
        /// Creates the optimizer state for one weight: zero-filled "mean" and "variance" arrays.
        /// </summary>
        /// <param name="index">Parameter index (not used here).</param>
        /// <param name="weight">Weight whose shape, context and data type are mirrored.</param>
        /// <returns>Dictionary with the "mean" and "variance" entries.</returns>
        public override NDArrayDict CreateState(int index, NDArray weight)
        {
            // The weight's storage type is used only when LazyUpdate is on; otherwise the default one.
            var storage = LazyUpdate ? weight.SType : StorageStype.Default;
            var result  = new NDArrayDict("mean", "variance");

            foreach (var key in new[] { "mean", "variance" })
            {
                result[key] = nd.Zeros(weight.Shape, weight.Context, weight.DataType).ToSType(storage);
            }

            return result;
        }
예제 #16
0
        /// <summary>
        /// Creates the optimizer state for one weight: zero-filled "prev_d", "prev_v" and "prev_z" arrays.
        /// </summary>
        /// <param name="index">Parameter index (not used here).</param>
        /// <param name="weight">Weight whose shape, context and data type are mirrored.</param>
        /// <returns>Dictionary with the three state entries.</returns>
        public override NDArrayDict CreateState(int index, NDArray weight)
        {
            var result = new NDArrayDict();

            foreach (var key in new[] { "prev_d", "prev_v", "prev_z" })
            {
                result[key] = nd.Zeros(weight.Shape, weight.Context, weight.DataType);
            }

            return result;
        }
예제 #17
0
        /// <summary>
        /// Applies one update step by delegating to <c>nd.FtmlUpdate</c>.
        /// </summary>
        /// <param name="index">Parameter index used for the update count, learning rate and weight decay lookups.</param>
        /// <param name="weight">Current weight.</param>
        /// <param name="grad">Gradient of the weight.</param>
        /// <param name="state">State dictionary holding "prev_d", "prev_v" and "prev_z".</param>
        public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
        {
            UpdateCount(index);
            var lr   = GetLr(index);
            var wd   = GetWd(index);
            var step = index_update_count[index];

            var prevD = state["prev_d"];
            var prevV = state["prev_v"];
            var prevZ = state["prev_z"];

            // -1 is passed to the operator when no clipping threshold is configured.
            var clip = ClipGradient.HasValue ? ClipGradient.Value : -1;

            weight = nd.FtmlUpdate(weight, grad, prevD, prevV, prevZ, lr, step, Beta1, Beta2,
                                   Epsilon, wd, RescaleGrad, clip);
        }
예제 #18
0
파일: Ftlr.cs 프로젝트: AvenSun/MxNet.Sharp
        /// <summary>
        /// Applies one update step by delegating to <c>nd.FtrlUpdate</c>.
        /// </summary>
        /// <param name="index">Parameter index used for the update count, learning rate and weight decay lookups.</param>
        /// <param name="weight">Current weight.</param>
        /// <param name="grad">Gradient of the weight.</param>
        /// <param name="state">State dictionary holding the "z" and "n" entries.</param>
        public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
        {
            UpdateCount(index);
            var lr = GetLr(index);
            var wd = GetWd(index);

            // The update count is not an input of this operator, so the unused lookup was removed.
            // -1 is passed as the clip argument when no clipping threshold is configured.
            weight = nd.FtrlUpdate(weight, grad, state["z"], state["n"], lr, Lamda1, Beta, wd, RescaleGrad,
                                   ClipGradient.HasValue ? ClipGradient.Value : -1);
        }
예제 #19
0
        /// <summary>
        /// Builds a new dictionary in which every array of <paramref name="data"/> has been passed
        /// through <c>nd.Shuffle</c>.
        /// </summary>
        /// <param name="data">Named arrays to shuffle.</param>
        /// <returns>A new dictionary with the same keys and the shuffled arrays.</returns>
        public static NDArrayDict GetDataByIdx(NDArrayDict data)
        {
            var shuffled = new NDArrayDict();

            foreach (var entry in data)
            {
                // NOTE(review): each entry gets its own nd.Shuffle call - presumably an
                // independent permutation per array; verify that entries which must stay aligned
                // (e.g. data and labels) are not passed through here separately.
                var permuted = nd.Shuffle(entry.Value);
                shuffled.Add(entry.Key, permuted);
            }

            return shuffled;
        }
예제 #20
0
        /// <summary>
        /// Runs <c>PackWeights</c> of every cell over <paramref name="args"/>, feeding the result of
        /// one cell into the next.
        /// </summary>
        /// <param name="cells">Cells whose weights are packed, in order.</param>
        /// <param name="args">Parameter dictionary to pack.</param>
        /// <returns>The dictionary returned by the last cell, or <paramref name="args"/> when there are no cells.</returns>
        public static NDArrayDict CellsPackWeights(BaseRNNCell[] cells, NDArrayDict args)
        {
            // The SymbolList that used to be allocated here was never used and has been removed.
            foreach (var cell in cells)
            {
                args = cell.PackWeights(args);
            }

            return args;
        }
예제 #21
0
파일: NAG.cs 프로젝트: AvenSun/MxNet.Sharp
        /// <summary>
        /// Creates the optimizer state for one weight: a "momentum" entry that is filled with
        /// zeros only when <c>Momentum</c> is non-zero.
        /// </summary>
        /// <param name="index">Parameter index (not used here).</param>
        /// <param name="weight">Weight whose shape, context and data type are mirrored.</param>
        /// <returns>Dictionary created with the "momentum" key.</returns>
        public override NDArrayDict CreateState(int index, NDArray weight)
        {
            var result = new NDArrayDict("momentum");

            // Without momentum the entry keeps whatever the dictionary constructor put there.
            if (Momentum == 0)
            {
                return result;
            }

            result["momentum"] = nd.Zeros(weight.Shape, weight.Context, weight.DataType);
            return result;
        }
예제 #22
0
        /// <summary>
        /// Tells whether any array in <paramref name="data"/> has the given data type.
        /// </summary>
        /// <param name="data">Named arrays to inspect.</param>
        /// <param name="dtype">Data type to look for; compared by <c>Name</c>.</param>
        /// <returns>True as soon as one array's data type name matches; false otherwise.</returns>
        public static bool HasInstance(NDArrayDict data, DType dtype)
        {
            var found = false;

            foreach (var entry in data)
            {
                if (entry.Value.DataType.Name != dtype.Name)
                {
                    continue;
                }

                // First match is enough.
                found = true;
                break;
            }

            return found;
        }
예제 #23
0
        /// <summary>
        /// Saves the parameters returned by <c>CollectParamsWithPrefix</c> to a file.
        /// </summary>
        /// <param name="filename">Path handed to <c>NDArray.Save</c>.</param>
        public void SaveParameters(string filename)
        {
            var toSave = new NDArrayDict();

            // Each parameter contributes the array produced by its Reduce() call, under its own key.
            foreach (var entry in CollectParamsWithPrefix().Items())
            {
                toSave[entry.Key] = entry.Value.Reduce();
            }

            NDArray.Save(filename, toSave);
        }
예제 #24
0
 /// <summary>
 /// Training entry point. Not implemented: every call throws <see cref="NotImplementedException"/>.
 /// </summary>
 /// <exception cref="NotImplementedException">Always.</exception>
 public void Fit(
     DataIter train_data,
     DataIter eval_data = null,
     string eval_metric = "acc",
     IEpochEndCallback[] epoch_end_callback = null,
     IBatchEndCallback[] batch_end_callback = null,
     string kvstore = "local",
     string optimizer = "sgd",
     Dictionary<string, object> optimizer_params = null,
     IBatchEndCallback[] eval_end_callback = null,
     IBatchEndCallback[] eval_batch_end_callback = null,
     Initializer initializer = null,
     NDArrayDict arg_params = null,
     NDArrayDict aux_params = null,
     bool allow_missing = false,
     bool force_rebind = false,
     bool force_init = false,
     int begin_epoch = 0,
     int? num_epoch = null,
     EvalMetric validation_metric = null,
     Monitor monitor = null,
     Func<DataBatch, NDArrayDict> sparse_row_id_fn = null)
 {
     throw new NotImplementedException();
 }
예제 #25
0
        /// <summary>
        /// Assembles the batch at the current cursor position from <paramref name="data_source"/>.
        /// </summary>
        /// <remarks>
        /// Three cases are handled: the first batch after a "roll_over" (cached remainder plus the
        /// start of the data), the last batch under "pad" (tail of the data plus a wrap-around
        /// from the start), and the plain slice [Cursor, min(Cursor + BatchSize, num_data)).
        /// </remarks>
        /// <param name="data_source">Named arrays to slice.</param>
        /// <returns>The arrays making up the batch.</returns>
        /// <exception cref="Exception">
        /// The cursor is past the data, or a roll-over batch is requested without cached data.
        /// </exception>
        private NDArrayList _batchify(NDArrayDict data_source)
        {
            // NOTE(review): Cursor == num_data passes this guard and yields an empty slice below -
            // confirm whether `>=` was intended.
            if (Cursor > num_data)
            {
                throw new Exception("DataIter need reset");
            }

            var isRollOverStart = last_batch_handle == "roll_over" && -BatchSize < Cursor && Cursor < 0;
            if (isRollOverStart)
            {
                if (_cache_data == null && _cache_label == null)
                {
                    throw new Exception("Next epoch should have cached data");
                }

                // Prefer the cached data; fall back to the cached label.
                var cached = _cache_data != null ? _cache_data : _cache_label;
                var head   = _getdata(data_source, end: Cursor + BatchSize);

                // Consume whichever cache was used.
                if (_cache_data == null)
                {
                    _cache_label = null;
                }
                else
                {
                    _cache_data = null;
                }

                return _concat(cached, head);
            }

            var isPaddedTail = last_batch_handle == "pad" && Cursor + BatchSize > num_data;
            if (isPaddedTail)
            {
                // Number of samples missing from the last batch, taken from the start of the data.
                var pad     = BatchSize - num_data + Cursor;
                var tail    = _getdata(data_source, Cursor);
                var wrapped = _getdata(data_source, end: pad);
                return _concat(tail, wrapped);
            }

            int end_idx = Cursor + BatchSize < num_data ? Cursor + BatchSize : num_data;
            return _getdata(data_source, Cursor, end_idx);
        }
예제 #26
0
        /// <summary>
        /// Applies one update step by delegating to <c>nd.AdamUpdate</c>, after scaling the
        /// learning rate by sqrt(1 - Beta2^t) / (1 - Beta1^t).
        /// </summary>
        /// <param name="index">Parameter index used for the update count, learning rate and weight decay lookups.</param>
        /// <param name="weight">Current weight.</param>
        /// <param name="grad">Gradient of the weight.</param>
        /// <param name="state">State dictionary holding the "mean" and "variance" entries.</param>
        public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
        {
            UpdateCount(index);
            var lr = GetLr(index);
            var wd = GetWd(index);

            var step = index_update_count[index];

            var correction1 = 1 - (float)Math.Pow(Beta1, step);
            var correction2 = 1 - (float)Math.Pow(Beta2, step);
            lr *= (float)Math.Sqrt(correction2) / correction1;

            var mean     = state["mean"];
            var variance = state["variance"];

            // -1 is passed to the operator when no clipping threshold is configured.
            var clip = ClipGradient.HasValue ? ClipGradient.Value : -1;

            weight = nd.AdamUpdate(weight, grad, mean, variance, lr, Beta1, Beta2, Epsilon, wd,
                                   RescaleGrad, clip, LazyUpdate);
        }
예제 #27
0
파일: SGLD.cs 프로젝트: AvenSun/MxNet.Sharp
        /// <summary>
        /// Applies one update step: half a learning-rate step along the rescaled (and optionally
        /// clipped) gradient plus weight decay, followed by adding normal noise with standard
        /// deviation sqrt(lr).
        /// </summary>
        /// <param name="index">Parameter index used for the update count, learning rate and weight decay lookups.</param>
        /// <param name="weight">Current weight.</param>
        /// <param name="grad">Gradient of the weight.</param>
        /// <param name="state">Optimizer state (not used here).</param>
        public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
        {
            UpdateCount(index);
            var lr = GetLr(index);
            var wd = GetWd(index);

            grad = grad * RescaleGrad;
            if (ClipGradient.HasValue)
            {
                var limit = ClipGradient.Value;
                grad = nd.Clip(grad, -limit, limit);
            }

            // Deterministic part of the step.
            var drift = -lr / 2 * (grad + wd * weight);
            weight += drift;

            // Stochastic part: zero-mean normal noise shaped like the weight.
            var noise = nd.Random.Normal(0, (float)Math.Sqrt(lr), weight.Shape, dtype: weight.DataType,
                                         ctx: weight.Context);
            weight += noise;
        }
예제 #28
0
        /// <summary>
        /// Slices the single fused parameter array in <paramref name="args"/> into named arrays
        /// and stores copies of them back into the dictionary.
        /// </summary>
        /// <param name="args">Dictionary containing the fused array under <c>_parameter.Name</c>.</param>
        /// <returns>The same dictionary instance with the sliced entries added or overwritten.</returns>
        public override NDArrayDict UnpackWeights(NDArrayDict args)
        {
            var fused      = args[this._parameter.Name];
            var directions = this._directions.Count;
            var gates      = this.NumGates;
            var hidden     = this._num_hidden;

            // Input size recovered from the total element count of the fused array.
            var num_input = fused.Size / directions / hidden / gates
                            - (this._num_layers - 1) * (hidden + directions * hidden + 2)
                            - hidden - 2;

            var slices = this.SliceWeight(fused, num_input, this._num_hidden);

            // Copy every slice first, then write the copies into args.
            var copies = slices.ToDictionary(pair => pair.Key, pair => pair.Value.Copy());
            foreach (var pair in copies)
            {
                args[pair.Key] = pair.Value;
            }

            return args;
        }
예제 #29
0
        /// <summary>
        /// Creates an iterator over in-memory arrays.
        /// </summary>
        /// <param name="data">Data arrays; the first dimension of the first array is taken as the sample count.</param>
        /// <param name="label">Optional label arrays.</param>
        /// <param name="batch_size">Number of samples per batch.</param>
        /// <param name="shuffle">Stored shuffle flag.</param>
        /// <param name="last_batch_handle">Stored last-batch mode; "pad" and "roll_over" are handled when batches are assembled.</param>
        /// <param name="data_name">Name given to a single data array.</param>
        /// <param name="label_name">Name given to a single label array.</param>
        public NDArrayIter(NDArrayList data, NDArrayList label = null, int batch_size    = 1, bool shuffle           = false,
                           string last_batch_handle            = "pad", string data_name = "data", string label_name = "softmax_label")
        {
            this.data              = IOUtils.InitData(data, false, data_name);
            this.label             = IOUtils.InitData(label, false, label_name);
            BatchSize              = batch_size;

            // Start one batch before the data, as the MXNet reference iterator does. A positive
            // initial cursor can be mistaken by a "roll_over" Reset() for a partially consumed epoch.
            Cursor                 = -batch_size;
            num_data               = data[0].Shape[0];
            this.last_batch_handle = last_batch_handle;
            this.shuffle           = shuffle;

            Reset();
            data_list.Add(data);
            data_list.Add(label);
            _cache_data  = null;
            _cache_label = null;
        }
예제 #30
0
        /// <summary>
        /// Creates the optimizer state for one weight: a momentum buffer (null when
        /// <c>Momentum</c> is zero) and a copy of the current weight as "prev_weight".
        /// </summary>
        /// <param name="index">Parameter index (not used here).</param>
        /// <param name="weight">Weight the state is created for.</param>
        /// <returns>Dictionary with the "momentum" and "prev_weight" entries.</returns>
        public override NDArrayDict CreateState(int index, NDArray weight)
        {
            var result = new NDArrayDict("momentum", "prev_weight");

            // A momentum buffer is only allocated when momentum is actually in use.
            result["momentum"] = Momentum == 0
                ? null
                : nd.Zeros(weight.Shape, weight.Context, weight.DataType);

            // The previous weight is stored as a copy in both cases.
            result["prev_weight"] = weight.Copy();

            return result;
        }