internal static void InitializeKVStore(KVStore kvstore, List <NDArrayList> param_arrays, NDArrayDict arg_params, string[] param_names, bool update_on_kvstore) { for (int i = 0; i < param_arrays.Count; i++) { if (param_arrays[i].Length == 0) { continue; } if (param_arrays[i][0] == null) { continue; } var name = param_names[i]; var param_on_devs = param_arrays[i]; kvstore.Init(name, arg_params[name]); if (update_on_kvstore) { kvstore.Pull(name, param_on_devs, -i); } } }
internal static void UpdateParams(List <NDArrayList> param_arrays, List <NDArrayList> grad_arrays, Updater updater, int num_device, KVStore kvstore, string[] param_names) { Dictionary <int, List <(int, NDArray, NDArray)> > updates = new Dictionary <int, List <(int, NDArray, NDArray)> >(); for (int i = 0; i < num_device; i++) { updates.Add(i, new List <(int, NDArray, NDArray)>()); } for (int i = 0; i < param_arrays.Count; i++) { var arg_list = param_arrays[i]; var grad_list = grad_arrays[i]; if (grad_list.Length == 0) { continue; } if (grad_list[0] == null) { continue; } int index = i; if (kvstore != null) { string name = param_names[index]; kvstore.Push(name, grad_list, -index); kvstore.Pull(name, arg_list, -index); } for (int j = 0; j < arg_list.Length; j++) { var w = arg_list[j]; var g = grad_list[j]; updates[i].Add((index * num_device + j, w, g)); } foreach (var dev_updates in updates.Values) { foreach (var item in dev_updates) { var(idx, w, g) = item; updater.Call(idx, w, g); } } } }
internal static void UpdateParamsOnKVStoreNCCL(List <NDArrayList> param_arrays, List <NDArrayList> grad_arrays, KVStore kvstore, string[] param_names) { List <int> valid_indices = new List <int>(); int i = 0; grad_arrays.ForEach((x) => { valid_indices.Add(i); i++; }); var valid_grad_arrays = valid_indices.Select(x => (grad_arrays[x])).ToArray(); var valid_param_arrays = valid_indices.Select(x => (param_arrays[x])).ToArray(); var valid_param_names = valid_indices.Select(x => (param_names[x])).ToArray(); int size = valid_grad_arrays.Length; int start = 0; int batch = 16; if (!string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("MXNET_UPDATE_AGGREGATION_SIZE"))) { batch = Convert.ToInt32(Environment.GetEnvironmentVariable("MXNET_UPDATE_AGGREGATION_SIZE")); } while (start < size) { int end = start + batch < size ? start + batch : size; var name_batch_list = valid_param_names.Skip(start).Take(end - start).ToArray(); var grad_batch_list = valid_grad_arrays.Skip(start).Take(end - start).ToArray(); var param_batch_list = valid_grad_arrays.Skip(start).Take(end - start).ToArray(); for (int kvi = 0; kvi < name_batch_list.Length; kvi++) { kvstore.Push(valid_param_names[kvi], valid_grad_arrays[kvi], -start); kvstore.Pull(valid_param_names[kvi], param_batch_list[kvi], -start); } start = end; } }
internal static void UpdateParamsOnKVStore(List <NDArrayList> param_arrays, List <NDArrayList> grad_arrays, KVStore kvstore, string[] param_names) { for (int index = 0; index < param_arrays.Count; index++) { var arg_list = param_arrays[index]; var grad_list = grad_arrays[index]; if (grad_list.Length == 0) { continue; } if (grad_list[0] == null) { continue; } string name = param_names[index]; kvstore.Push(name, grad_list, -index); kvstore.Pull(name, arg_list, -index); } }