Example #1
0
        /// <summary>
        /// Applies gradients to weights locally through <paramref name="updater"/>; the optimizer is not
        /// run on the kvstore. When <paramref name="kvstore"/> is non-null it is only used to aggregate
        /// gradients: each gradient list is pushed, and the aggregated value is pulled back into the same
        /// gradient arrays before the local update runs (mirrors MXNet's <c>_update_params</c>).
        /// Parameters without a gradient (empty list or null first entry) are skipped.
        /// </summary>
        /// <param name="param_arrays">Weight arrays per parameter; presumably one entry per device copy.</param>
        /// <param name="grad_arrays">Gradient arrays, indexed in parallel with <paramref name="param_arrays"/>.</param>
        /// <param name="updater">Invoked once per (parameter, device copy) with a distinct integer index.</param>
        /// <param name="num_device">Number of device copies per parameter; must be at least the length of every weight list.</param>
        /// <param name="kvstore">Optional store used for gradient aggregation; may be null.</param>
        /// <param name="param_names">kvstore keys indexed like <paramref name="param_arrays"/>; only read when <paramref name="kvstore"/> is non-null.</param>
        internal static void UpdateParams(List <NDArrayList> param_arrays, List <NDArrayList> grad_arrays,
                                          Updater updater, int num_device, KVStore kvstore, string[] param_names)
        {
            // One bucket per device. Buckets are filled in parameter order and flushed only after every
            // parameter has been visited, so each (weight, gradient) pair is applied exactly once.
            var updates = new List <(int, NDArray, NDArray)>[num_device];
            for (int d = 0; d < num_device; d++)
            {
                updates[d] = new List <(int, NDArray, NDArray)>();
            }

            for (int i = 0; i < param_arrays.Count; i++)
            {
                var arg_list  = param_arrays[i];
                var grad_list = grad_arrays[i];

                // Skip parameters that did not receive a gradient.
                if (grad_list.Length == 0 || grad_list[0] == null)
                {
                    continue;
                }

                int index = i;
                if (kvstore != null)
                {
                    string name = param_names[index];
                    // Push gradients (priority is the negated index) and pull the aggregated gradients
                    // back into the gradient arrays; the weights themselves are updated below.
                    kvstore.Push(name, grad_list, -index);
                    kvstore.Pull(name, grad_list, -index);
                }

                // Pair weights with gradients up to the shorter of the two lists.
                int copies = Math.Min(arg_list.Length, grad_list.Length);
                for (int j = 0; j < copies; j++)
                {
                    // Bucket by device j. The index is unique per (parameter, device) pair, presumably so
                    // the updater keeps separate optimizer state for each device copy.
                    updates[j].Add((index * num_device + j, arg_list[j], grad_list[j]));
                }
            }

            // Apply all collected updates once, device by device.
            foreach (var dev_updates in updates)
            {
                foreach (var (idx, w, g) in dev_updates)
                {
                    updater.Call(idx, w, g);
                }
            }
        }
Example #2
0
        /// <summary>
        /// Pushes gradients to, and pulls weights back from, an NCCL-backed <paramref name="kvstore"/>,
        /// processing parameters in batches. The batch size is read from the
        /// MXNET_UPDATE_AGGREGATION_SIZE environment variable (default 16). Every push and pull in a
        /// batch uses priority <c>-start</c>, where start is the batch's offset among the valid parameters.
        /// Parameters without a gradient (empty list or null first entry) are skipped.
        /// </summary>
        /// <param name="param_arrays">Weight arrays per parameter; the pulled values are written here.</param>
        /// <param name="grad_arrays">Gradient arrays, indexed in parallel with <paramref name="param_arrays"/>.</param>
        /// <param name="kvstore">Store that receives the pushes and serves the pulls.</param>
        /// <param name="param_names">kvstore keys, indexed like <paramref name="param_arrays"/>.</param>
        /// <exception cref="FormatException">The environment variable is set but is not an integer.</exception>
        /// <exception cref="OverflowException">The environment variable does not fit in an <see cref="int"/>.</exception>
        /// <exception cref="InvalidOperationException">The environment variable is less than 1.</exception>
        internal static void UpdateParamsOnKVStoreNCCL(List <NDArrayList> param_arrays, List <NDArrayList> grad_arrays,
                                                       KVStore kvstore, string[] param_names)
        {
            // Indices of parameters that actually received a gradient (same rule as the other update paths).
            var valid_indices = new List <int>();
            for (int i = 0; i < grad_arrays.Count; i++)
            {
                var grad_list = grad_arrays[i];
                if (grad_list.Length == 0 || grad_list[0] == null)
                {
                    continue;
                }

                valid_indices.Add(i);
            }

            int    batch         = 16;
            string batch_setting = Environment.GetEnvironmentVariable("MXNET_UPDATE_AGGREGATION_SIZE");
            if (!string.IsNullOrWhiteSpace(batch_setting))
            {
                batch = Convert.ToInt32(batch_setting);
            }

            // A non-positive batch would never advance `start`, so the loop below would never end.
            if (batch < 1)
            {
                throw new InvalidOperationException(
                    $"MXNET_UPDATE_AGGREGATION_SIZE must be a positive integer, got {batch}.");
            }

            int size  = valid_indices.Count;
            int start = 0;
            while (start < size)
            {
                // Compared as `size - start` so a very large batch cannot overflow `start + batch`.
                int end = size - start > batch ? start + batch : size;

                // Push the batch's gradients, then pull the batch's weights back into the parameter arrays.
                for (int k = start; k < end; k++)
                {
                    int index = valid_indices[k];
                    kvstore.Push(param_names[index], grad_arrays[index], -start);
                }

                for (int k = start; k < end; k++)
                {
                    int index = valid_indices[k];
                    kvstore.Pull(param_names[index], param_arrays[index], -start);
                }

                start = end;
            }
        }
Example #3
0
        /// <summary>
        /// Synchronises every parameter that has a gradient through <paramref name="kvstore"/>.
        /// The parameter's gradient list is pushed under its name, and the stored value for that name is
        /// then pulled back into the parameter's weight arrays. The pulled value is presumably the updated
        /// weights when an optimizer is set on the kvstore; confirm against the caller. Both operations use
        /// the negated parameter index as their priority.
        /// </summary>
        /// <param name="param_arrays">Weight arrays per parameter; the pulled values are written here.</param>
        /// <param name="grad_arrays">Gradient arrays, indexed in parallel with <paramref name="param_arrays"/>.</param>
        /// <param name="kvstore">Store that receives the pushes and serves the pulls.</param>
        /// <param name="param_names">kvstore keys, indexed like <paramref name="param_arrays"/>.</param>
        internal static void UpdateParamsOnKVStore(List <NDArrayList> param_arrays, List <NDArrayList> grad_arrays,
                                                   KVStore kvstore, string[] param_names)
        {
            int paramCount = param_arrays.Count;
            for (int p = 0; p < paramCount; p++)
            {
                NDArrayList gradients = grad_arrays[p];

                // Nothing to synchronise when no gradient was produced for this parameter.
                if (gradients.Length == 0 || gradients[0] == null)
                {
                    continue;
                }

                string key      = param_names[p];
                int    priority = -p;

                kvstore.Push(key, gradients, priority);
                kvstore.Pull(key, param_arrays[p], priority);
            }
        }