private static void _update_params( List <List <NDArray> > param_arrays, List <List <NDArray> > grad_arrays, Action <int, NDArray, NDArray> updater, int num_device, KVStore kvstore = null) { for (int index = 0; index < param_arrays.Count; index++) { var arg_list = param_arrays[index]; var grad_list = grad_arrays[index]; if (grad_list[0] == null) { continue; } if (kvstore != null) { //push gradient, priority is negative index kvstore.Push(index, grad_list, priority: -index); //pull back the weights kvstore.Pull(index, arg_list, priority: -index); } for (int k = 0; k < arg_list.Count; k++) { var w = arg_list[k]; var g = grad_list[k]; updater(index * num_device + k, g, w); } } }
private static void _initialize_kvstore(KVStore kvstore, List <List <NDArray> > param_arrays, Dictionary <string, NDArray> arg_params, List <string> param_names, bool update_on_kvstore) { for (int idx = 0; idx < param_arrays.Count; idx++) { var param_on_devs = param_arrays[idx]; kvstore.Init(idx, arg_params[param_names[idx]]); if (update_on_kvstore) { kvstore.Pull(idx, param_on_devs, priority: -idx); } } }
private static void _update_params_on_kvstore( List <List <NDArray> > param_arrays, List <List <NDArray> > grad_arrays, KVStore kvstore) { for (int index = 0; index < param_arrays.Count; index++) { var arg_list = param_arrays[index]; var grad_list = grad_arrays[index]; if (grad_list[0] == null) { continue; } //push gradient, priority is negative index kvstore.Push(index, grad_list, priority: -index); //pull back the weights kvstore.Pull(index, arg_list, priority: -index); } }