public virtual NDArrayDict PackWeights(NDArrayDict args)
{
    // Nothing to fuse when this cell exposes no per-gate parameters.
    if (this.GateNames == null)
        return args;

    // For each parameter group (input-to-hidden, hidden-to-hidden),
    // concatenate the per-gate weight/bias arrays into one fused array.
    foreach (var groupName in new List<string> { "i2h", "h2h" })
    {
        var weights = new List<NDArray>();
        var biases = new List<NDArray>();
        foreach (var gate in this.GateNames)
        {
            weights.Add(args[$"{_prefix}{groupName}{gate}_weight"]);
            biases.Add(args[$"{_prefix}{groupName}{gate}_bias"]);
        }

        args[$"{_prefix}{groupName}_weight"] = nd.Concat(weights);
        args[$"{_prefix}{groupName}_bias"] = nd.Concat(biases);
    }

    return args;
}
public virtual void LoadParams(string fname)
{
    // Checkpoint keys have the form "arg:<name>" or "aux:<name>";
    // route each array into the matching parameter dictionary.
    var save_dict = NDArray.Load(fname);
    var arg_params = new NDArrayDict();
    var aux_params = new NDArrayDict();

    foreach (var item in save_dict)
    {
        var parts = item.Key.Split(':');
        var arg_type = parts[0];
        var name = parts[1];

        switch (arg_type)
        {
            case "arg":
                arg_params[name] = item.Value;
                break;
            case "aux":
                aux_params[name] = item.Value;
                break;
            default:
                throw new Exception("Invalid param file: " + fname);
        }
    }

    SetParams(arg_params, aux_params);
}
/// <summary>
/// Applies one optimizer step: gradient (rescaled, optionally clipped) plus
/// weight decay plus a Lamda-scaled second-order correction term that uses
/// the previous weight snapshot, optionally smoothed by momentum.
/// </summary>
public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
{
    UpdateCount(index);
    var lr = GetLr(index);
    var wd = GetWd(index);

    grad = grad * RescaleGrad;
    if (ClipGradient.HasValue)
    {
        grad = nd.Clip(grad, -ClipGradient.Value, ClipGradient.Value);
    }

    // The step direction for this iteration (does not depend on momentum).
    var delta = -lr * (grad + wd * weight + Lamda * grad * grad * (weight - state["prev_weight"]));

    if (state["momentum"] != null)
    {
        state["momentum"] *= Momentum;
        state["momentum"] += delta;
    }
    else
    {
        // BUG FIX: the original did `state["momentum"] += delta` here, which
        // dereferences a null momentum buffer (CreateState leaves it null
        // when Momentum == 0). With momentum disabled the step is just delta.
        state["momentum"] = delta;
    }

    state["prev_weight"] = weight;
    weight += state["momentum"];
}
public static NDArrayDict InitData(NDArrayList data, bool allow_empty, string default_name)
{
    // Maps raw arrays to named entries: a single array keeps the default
    // name, multiple arrays get index-qualified names ("_<i>_<name>").
    var result = new NDArrayDict();
    if (data == null)
        return result;

    if (data.Length == 0 && !allow_empty)
        throw new Exception("Data cannot be empty when allow_empty is false");

    if (data.Length == 1)
    {
        result.Add(default_name, data[0]);
        return result;
    }

    for (var i = 0; i < data.Length; i++)
        result.Add($"_{i}_{default_name}", data[i]);

    return result;
}
/// <summary>
/// Runs a forward pass over <paramref name="x"/> in mini-batches and stitches
/// the per-batch outputs into a single (num_samples, -1) NDArray.
/// </summary>
/// <param name="x">Input samples; first axis is the sample dimension.</param>
/// <param name="batchSize">Batch size; defaults to all rows in one batch.</param>
public NDArray Predict(NDArray x, uint? batchSize = null)
{
    // Removed unused locals `result` and `inputShape` from the original.
    List<float> preds = new List<float>();
    NDArrayIter dataIter = new NDArrayIter(new NDArray[] { x }, null);
    if (!batchSize.HasValue)
    {
        batchSize = x.Shape[0];
    }

    NDArrayDict predictArgs = new NDArrayDict();
    Model.InferArgsMap(mx.Device, predictArgs, args);
    predictArgs["X"] = new NDArray(x.Shape);
    predictArgs["label"] = new NDArray(new Shape(batchSize.Value));

    using (var exec = Model.SimpleBind(mx.Device, predictArgs))
    {
        dataIter.BatchSize = batchSize.Value;
        dataIter.Reset();
        while (dataIter.IterNext())
        {
            var batch = dataIter.Next();
            // Feed the batch into the bound input and run inference only
            // (no gradient bookkeeping).
            batch.Data[0].CopyTo(predictArgs["X"]);
            exec.Forward(false);
            preds.AddRange(exec.Output.GetValues<float>());
        }
    }

    return new NDArray(preds.ToArray()).Reshape((int)x.Shape[0], -1);
}
// AdaGrad step: accumulates squared gradients in state["history"] and scales
// the effective learning rate by 1/sqrt(history + Epsilon).
public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
{
    UpdateCount(index);
    var lr = GetLr(index);
    var wd = GetWd(index);
    var is_sparse = grad.SType == StorageStype.RowSparse;
    var history = state["history"];
    if (is_sparse)
    {
        // Fused sparse kernel performs rescale/clip/weight-decay internally;
        // -1 signals "no clipping" to the native op.
        nd.SparseAdagradUpdate(weight, grad, history, lr, Epsilon, wd, RescaleGrad, ClipGradient.HasValue ? ClipGradient.Value : -1);
    }
    else
    {
        grad = grad * RescaleGrad;
        if (ClipGradient.HasValue)
        {
            grad = nd.Clip(grad, -ClipGradient.Value, ClipGradient.Value);
        }
        // NOTE(review): `history +=` and `weight +=` below are assumed to
        // mutate the arrays referenced via `state` and the caller's weight —
        // confirm the NDArray compound operators are in-place in this binding.
        history += nd.Square(grad);
        var div = grad / nd.Sqrt(history + Epsilon);
        weight += (div + weight * wd) * -lr;
    }
}
// Dispatches to the fused NAG/SGD kernels, choosing the multi-precision
// variant (fp32 master weights in state["weight32"]) when requested.
private void _update_impl(int index, NDArray weight, NDArray grad, NDArrayDict state, bool multi_precision = false)
{
    UpdateCount(index);
    var lr = GetLr(index);
    var wd = GetWd(index);
    var clip = ClipGradient.HasValue ? ClipGradient.Value : -1;
    var momentum = state["momentum"];

    if (multi_precision)
    {
        weight = momentum != null
            ? nd.MPNAGMomUpdate(weight, grad, momentum, state["weight32"], lr, Momentum, wd, RescaleGrad, clip)
            : nd.MpSgdUpdate(weight, grad, state["weight32"], lr, wd, RescaleGrad, clip);
    }
    else
    {
        weight = momentum != null
            ? nd.NAGMomUpdate(weight, grad, momentum, lr, Momentum, wd, RescaleGrad, clip)
            : nd.SgdUpdate(weight, grad, lr, wd, RescaleGrad, clip);
    }
}
/// <summary>
/// Splits each fused parameter group back into per-gate slices of
/// `_num_hidden` rows, storing independent copies under the gate names.
/// </summary>
public virtual NDArrayDict UnpackWeights(NDArrayDict args)
{
    if (GateNames == null)
        return args;

    var h = _num_hidden;
    foreach (var group_name in new string[] { "i2h", "h2h" })
    {
        var weight = args[$"{_prefix}{group_name}_weight"];
        var bias = args[$"{_prefix}{group_name}_bias"];
        for (int j = 0; j < GateNames.Length; j++)
        {
            var gate = GateNames[j];
            var slice = $"{j * h}:{(j + 1) * h}";
            args[$"{_prefix}{group_name}{gate}_weight"] = weight[slice].Copy();
            // BUG FIX: the original sliced `weight` for the bias too, so the
            // unpacked biases contained weight data. Slice `bias` instead.
            args[$"{_prefix}{group_name}{gate}_bias"] = bias[slice].Copy();
        }
    }

    return args;
}
// SGD optimizer state: an optional fp32 master copy of fp16 weights plus an
// optional momentum buffer (allocated only when Momentum != 0).
public override NDArrayDict CreateState(int index, NDArray weight)
{
    var state = new NDArrayDict();
    state["weight_master_copy"] = null;
    state["momentum"] = null;

    var isFloat16 = weight.DataType.Name == DType.Float16.Name;

    if (MultiPrecision && isFloat16)
    {
        // Keep a float32 master copy for numerically stable accumulation.
        state["weight_master_copy"] = weight.AsType(DType.Float32);
        if (Momentum != 0)
            state["momentum"] = nd.Zeros(weight.Shape, weight.Context, weight.DataType).ToSType(weight.SType);

        return state;
    }

    if (!MultiPrecision && isFloat16)
    {
        Logger.Warning("Accumulating with float16 in optimizer can lead to " +
                       "poor accuracy or slow convergence. " +
                       "Consider using multi_precision=True option of the " +
                       "SGD optimizer");
    }

    if (Momentum != 0)
        state["momentum"] = nd.Zeros(weight.Shape, weight.Context, weight.DataType).ToSType(weight.SType);

    return state;
}
// Nadam step: Adam with Nesterov momentum and a decaying momentum schedule.
public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
{
    UpdateCount(index);
    var lr = GetLr(index);
    var wd = GetWd(index);
    var t = index_update_count[index]; // per-weight step count
    grad = grad * RescaleGrad + wd * weight;
    if (ClipGradient.HasValue)
    {
        grad = nd.Clip(grad, -ClipGradient.Value, ClipGradient.Value);
    }
    // Momentum schedule mu_t and mu_{t+1}.
    var momentum_t = Beta1 * (1 - 0.5f * (float)Math.Pow(0.96, t * ScheduleDecay));
    var momentum_t_1 = Beta1 * (1 - 0.5f * (float)Math.Pow(0.96, (t + 1) * ScheduleDecay));
    // Running product of mu_1..mu_t, and the one-step-ahead product.
    MSchedule = MSchedule * momentum_t;
    var m_schedule_next = MSchedule * momentum_t_1;
    var m_t = state["mean"];
    var v_t = state["variance"];
    // NOTE(review): the compound operators below are assumed to mutate the
    // arrays held in `state` in place — confirm for this NDArray binding.
    m_t *= Beta1;
    m_t += (1 - Beta1) * grad;
    v_t *= Beta2;
    v_t += (1 - Beta2) * grad * grad;
    // Bias-corrected gradient and moment estimates.
    var grad_prime = grad / (1 - MSchedule);
    var m_t_prime = m_t / (1 - m_schedule_next);
    var v_t_prime = v_t / (1 - (float)Math.Pow(Beta2, t));
    // Nesterov blend of current gradient and first moment.
    var m_t_bar = (1 - momentum_t) * grad_prime + momentum_t_1 * m_t_prime;
    weight -= lr * m_t_bar / (nd.Sqrt(v_t_prime) + Epsilon);
}
// Epoch-end callback: persists an RNN checkpoint every `_period` epochs.
// Epochs are 0-based, hence the +1 in the period test.
public void Invoke(int epoch, Symbol symbol, NDArrayDict arg_params, NDArrayDict aux_params)
{
    var atPeriodBoundary = (epoch + 1) % _period == 0;
    if (atPeriodBoundary)
        RNN.SaveRNNCheckPoint(_cells, _prefix, epoch, symbol, arg_params, aux_params);
}
// AdaGrad keeps a single state array: the running sum of squared gradients,
// allocated with the weight's shape, device, dtype and storage type.
public override NDArrayDict CreateState(int index, NDArray weight)
{
    var history = nd.Zeros(weight.Shape, weight.Context, weight.DataType).ToSType(weight.SType);
    var state = new NDArrayDict("history");
    state["history"] = history;
    return state;
}
// Epoch-end callback: saves a model checkpoint every `_period` epochs.
// The period is clamped to at least 1 to avoid division by zero.
public void Invoke(int epoch, Symbol symbol, NDArrayDict arg_params, NDArrayDict aux_params)
{
    _period = Math.Max(1, _period);
    if ((epoch + 1) % _period != 0)
        return;

    // Checkpoints are numbered 1-based while epochs are 0-based.
    Model.SaveCheckpoint(_prefix, epoch + 1, symbol, arg_params, aux_params);
}
// LBSGD update: accumulates gradients until BatchScale of them have been
// folded in, then applies one warmup-scaled SGD/momentum step.
public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
{
    var lr = GetLr(index);
    var wd = GetWd(index);
    UpdateCount(index);
    // cgrad.Grad is the accumulated gradient; cgrad.Nums counts contributions.
    var cgrad = _cumulate_gradient(grad, index);
    float lbmult = 0;
    if (cgrad.Nums % BatchScale == 0)
    {
        // A full large batch has accumulated: average it and take a step.
        grad = cgrad.Grad / BatchScale;
        if (WarmupStrategy == "lars")
        {
            lbmult = _get_lars(weight, grad, wd); // layer-wise adaptive rate
        }
        else
        {
            lbmult = _get_lbmult(cgrad.Nums); // warmup multiplier
        }
        lr = lr * lbmult;
        var use_multi_precision = state["weight_master_copy"] != null;
        if (!use_multi_precision)
        {
            if (state["momentum"] != null)
            {
                weight = nd.SgdMomUpdate(weight, grad, state["momentum"], lr, Momentum, wd, RescaleGrad, ClipGradient.HasValue ? ClipGradient.Value : -1);
            }
            else
            {
                weight = nd.SgdUpdate(weight, grad, lr, wd, RescaleGrad, ClipGradient.HasValue ? ClipGradient.Value : -1);
            }
        }
        else
        {
            // Multi-precision path: updates flow through the fp32 master copy.
            if (state["momentum"] != null)
            {
                weight = nd.MpSgdMomUpdate(weight, grad, state["momentum"], state["weight_master_copy"], lr, Momentum, wd, RescaleGrad, ClipGradient.HasValue ? ClipGradient.Value : -1);
            }
            else
            {
                weight = nd.MpSgdUpdate(weight, grad, state["weight_master_copy"], lr, wd, RescaleGrad, ClipGradient.HasValue ? ClipGradient.Value : -1);
            }
        }
    }
    else
    {
        // Not at a large-batch boundary: a zero-learning-rate update.
        // NOTE(review): presumably intended as a no-op step — confirm the
        // fused kernel has no lr-independent side effects (e.g. weight decay).
        lr = 0;
        weight = nd.SgdUpdate(weight, grad, lr, wd, RescaleGrad);
    }
}
// Adam state: zero-initialized first ("mean") and second ("variance") moment
// accumulators. Lazy updates keep the weight's storage type; otherwise dense
// storage is forced.
public override NDArrayDict CreateState(int index, NDArray weight)
{
    var storage = LazyUpdate ? weight.SType : StorageStype.Default;
    var state = new NDArrayDict("mean", "variance");
    foreach (var key in new[] { "mean", "variance" })
        state[key] = nd.Zeros(weight.Shape, weight.Context, weight.DataType).ToSType(storage);

    return state;
}
// FTML state: three dense accumulators, zero-initialized to the weight's
// shape, device and dtype.
public override NDArrayDict CreateState(int index, NDArray weight)
{
    var state = new NDArrayDict();
    foreach (var key in new[] { "prev_d", "prev_v", "prev_z" })
        state[key] = nd.Zeros(weight.Shape, weight.Context, weight.DataType);

    return state;
}
// Delegates the entire FTML step to the fused kernel; -1 disables clipping.
public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
{
    UpdateCount(index);
    var learningRate = GetLr(index);
    var weightDecay = GetWd(index);
    var step = index_update_count[index];
    var clip = ClipGradient.HasValue ? ClipGradient.Value : -1;

    weight = nd.FtmlUpdate(weight, grad, state["prev_d"], state["prev_v"], state["prev_z"],
                           learningRate, step, Beta1, Beta2, Epsilon, weightDecay, RescaleGrad, clip);
}
/// <summary>
/// One FTRL step via the fused kernel; -1 disables gradient clipping.
/// </summary>
public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
{
    UpdateCount(index);
    var lr = GetLr(index);
    var wd = GetWd(index);
    // Removed the unused local `t` (index_update_count lookup) — FtrlUpdate
    // does not take a step count.
    weight = nd.FtrlUpdate(weight, grad, state["z"], state["n"], lr, Lamda1, Beta, wd, RescaleGrad,
                           ClipGradient.HasValue ? ClipGradient.Value : -1);
}
// Builds a new dict whose entries are row-shuffled copies of the input arrays.
public static NDArrayDict GetDataByIdx(NDArrayDict data)
{
    var shuffled = new NDArrayDict();
    foreach (var pair in data)
        shuffled.Add(pair.Key, nd.Shuffle(pair.Value));

    return shuffled;
}
/// <summary>
/// Folds PackWeights over all cells; each cell fuses its own per-gate
/// parameters into <paramref name="args"/> in turn.
/// </summary>
public static NDArrayDict CellsPackWeights(BaseRNNCell[] cells, NDArrayDict args)
{
    // Removed an unused `SymbolList ret` local from the original.
    foreach (var cell in cells)
    {
        args = cell.PackWeights(args);
    }

    return args;
}
// Allocates a momentum buffer only when momentum is enabled; otherwise the
// "momentum" slot stays unset.
public override NDArrayDict CreateState(int index, NDArray weight)
{
    var state = new NDArrayDict("momentum");
    if (Momentum == 0)
        return state;

    state["momentum"] = nd.Zeros(weight.Shape, weight.Context, weight.DataType);
    return state;
}
// True when any array in the dict has the given dtype (compared by name).
public static bool HasInstance(NDArrayDict data, DType dtype)
{
    foreach (var pair in data)
    {
        if (pair.Value.DataType.Name != dtype.Name)
            continue;

        return true;
    }

    return false;
}
// Reduces each collected parameter to a single NDArray and persists the set
// to `filename` under the parameters' prefixed names.
public void SaveParameters(string filename)
{
    var arg_dict = new NDArrayDict();
    foreach (var pair in CollectParamsWithPrefix().Items())
        arg_dict[pair.Key] = pair.Value.Reduce();

    NDArray.Save(filename, arg_dict);
}
// Training entry point matching the module Fit API. Not supported on this
// type: the stub documents the intended signature and always throws.
public void Fit(DataIter train_data, DataIter eval_data = null, string eval_metric = "acc",
    IEpochEndCallback[] epoch_end_callback = null, IBatchEndCallback[] batch_end_callback = null,
    string kvstore = "local", string optimizer = "sgd",
    Dictionary<string, object> optimizer_params = null,
    IBatchEndCallback[] eval_end_callback = null,
    IBatchEndCallback[] eval_batch_end_callback = null,
    Initializer initializer = null, NDArrayDict arg_params = null, NDArrayDict aux_params = null,
    bool allow_missing = false, bool force_rebind = false, bool force_init = false,
    int begin_epoch = 0, int? num_epoch = null, EvalMetric validation_metric = null,
    Monitor monitor = null, Func<DataBatch, NDArrayDict> sparse_row_id_fn = null)
{
    throw new NotImplementedException();
}
// Assembles one batch from `data_source`, honoring the configured
// last_batch_handle policy: "roll_over" reuses data cached from the previous
// epoch; "pad" wraps the final short batch around to the start of the data.
private NDArrayList _batchify(NDArrayDict data_source)
{
    if (Cursor > num_data)
    {
        throw new Exception("DataIter need reset");
    }
    // roll_over: a cursor in (-BatchSize, 0) means the head of this batch was
    // cached at the end of the previous epoch.
    if (last_batch_handle == "roll_over" && -BatchSize < Cursor && Cursor < 0)
    {
        if (_cache_data == null && _cache_label == null)
        {
            throw new Exception("Next epoch should have cached data");
        }
        var cache_data = _cache_data != null ? _cache_data : _cache_label;
        // Remaining (Cursor + BatchSize) rows come from the front of the data.
        var second_data = _getdata(data_source, end: Cursor + BatchSize);
        // Consume whichever cache was used.
        if (_cache_data != null)
        {
            _cache_data = null;
        }
        else
        {
            _cache_label = null;
        }
        return _concat(cache_data, second_data);
    }
    // pad: top up the trailing short batch with rows from the beginning.
    if (last_batch_handle == "pad" && Cursor + BatchSize > num_data)
    {
        var pad = BatchSize - num_data + Cursor; // rows still needed
        var first_data = _getdata(data_source, Cursor);
        var second_data = _getdata(data_source, end: pad);
        return _concat(first_data, second_data);
    }
    // Plain case: a full slice, or a truncated final slice.
    var end_idx = 0;
    if (Cursor + BatchSize < num_data)
    {
        end_idx = Cursor + BatchSize;
    }
    else
    {
        end_idx = num_data;
    }
    return _getdata(data_source, Cursor, end_idx);
}
// Adam step with bias correction folded into the learning rate:
// lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t).
public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
{
    UpdateCount(index);
    var lr = GetLr(index);
    var wd = GetWd(index);
    var t = index_update_count[index];

    var biasCorrection1 = 1 - (float)Math.Pow(Beta1, t);
    var biasCorrection2 = 1 - (float)Math.Pow(Beta2, t);
    lr *= (float)Math.Sqrt(biasCorrection2) / biasCorrection1;

    var clip = ClipGradient.HasValue ? ClipGradient.Value : -1;
    weight = nd.AdamUpdate(weight, grad, state["mean"], state["variance"],
                           lr, Beta1, Beta2, Epsilon, wd, RescaleGrad, clip, LazyUpdate);
}
// SGLD: half a gradient-descent step plus Gaussian noise with std sqrt(lr).
public override void Update(int index, NDArray weight, NDArray grad, NDArrayDict state)
{
    UpdateCount(index);
    var lr = GetLr(index);
    var wd = GetWd(index);

    var g = grad * RescaleGrad;
    if (ClipGradient.HasValue)
        g = nd.Clip(g, -ClipGradient.Value, ClipGradient.Value);

    var noise = nd.Random.Normal(0, (float)Math.Sqrt(lr), weight.Shape, dtype: weight.DataType, ctx: weight.Context);
    weight += -lr / 2 * (g + wd * weight);
    weight += noise;
}
// Unpacks the fused RNN parameter blob back into per-layer/per-gate arrays.
public override NDArrayDict UnpackWeights(NDArrayDict args)
{
    var arr = args[this._parameter.Name];
    var b = this._directions.Count; // 1 (unidirectional) or 2 (bidirectional)
    var m = this.NumGates;
    var h = this._num_hidden;
    // Back out the input size from the total parameter count; the subtracted
    // terms account for the stacked layers' weights and biases.
    // NOTE(review): this mirrors FusedRNNCell.unpack_weights in mxnet —
    // verify against that reference if the slice sizes ever mismatch.
    var num_input = arr.Size / b / h / m - (this._num_layers - 1) * (h + b * h + 2) - h - 2;
    var nargs = this.SliceWeight(arr, num_input, this._num_hidden);
    // Copy each slice so callers own independent arrays, then merge into args.
    var newargs = nargs.ToDictionary(_tup_1 => _tup_1.Key, _tup_1 => _tup_1.Value.Copy());
    foreach (var item in newargs)
    {
        args[item.Key] = item.Value;
    }
    return args;
}
// Wraps in-memory arrays as a batched data iterator.
public NDArrayIter(NDArrayList data, NDArrayList label = null, int batch_size = 1, bool shuffle = false, string last_batch_handle = "pad", string data_name = "data", string label_name = "softmax_label")
{
    this.data = IOUtils.InitData(data, false, data_name);
    // NOTE(review): labels are initialized with allow_empty=false; the Python
    // NDArrayIter passes allow_empty=True for labels — confirm an empty (but
    // non-null) label list should really throw here.
    this.label = IOUtils.InitData(label, false, label_name);
    BatchSize = batch_size;
    // NOTE(review): the reference implementation starts the cursor at
    // -batch_size; here it starts at +batch_size and relies on Reset()
    // (called below) to re-seat it — verify Reset() does so.
    Cursor = batch_size;
    // Number of samples, taken from the first data array's leading axis.
    num_data = data[0].Shape[0];
    this.last_batch_handle = last_batch_handle;
    this.shuffle = shuffle;
    Reset();
    data_list.Add(data);
    data_list.Add(label);
    _cache_data = null;
    _cache_label = null;
}
// Always snapshots the current weight into "prev_weight"; allocates a
// momentum buffer only when momentum is enabled.
public override NDArrayDict CreateState(int index, NDArray weight)
{
    var state = new NDArrayDict("momentum", "prev_weight");
    state["momentum"] = Momentum == 0
        ? null
        : nd.Zeros(weight.Shape, weight.Context, weight.DataType);
    state["prev_weight"] = weight.Copy();
    return state;
}