Exemplo n.º 1
0
        private static void _update_params(
            List <List <NDArray> > param_arrays,
            List <List <NDArray> > grad_arrays,
            Action <int, NDArray, NDArray> updater,
            int num_device, KVStore kvstore = null)
        {
            for (int index = 0; index < param_arrays.Count; index++)
            {
                var arg_list  = param_arrays[index];
                var grad_list = grad_arrays[index];
                if (grad_list[0] == null)
                {
                    continue;
                }
                if (kvstore != null)
                {
                    //push gradient, priority is negative index
                    kvstore.Push(index, grad_list, priority: -index);
                    //pull back the weights
                    kvstore.Pull(index, arg_list, priority: -index);
                }

                for (int k = 0; k < arg_list.Count; k++)
                {
                    var w = arg_list[k];
                    var g = grad_list[k];


                    updater(index * num_device + k, g, w);
                }
            }
        }
Exemplo n.º 2
0
        private static void _initialize_kvstore(KVStore kvstore,
                                                List <List <NDArray> > param_arrays,
                                                Dictionary <string, NDArray> arg_params,
                                                List <string> param_names,
                                                bool update_on_kvstore)
        {
            for (int idx = 0; idx < param_arrays.Count; idx++)
            {
                var param_on_devs = param_arrays[idx];
                kvstore.Init(idx, arg_params[param_names[idx]]);

                if (update_on_kvstore)
                {
                    kvstore.Pull(idx, param_on_devs, priority: -idx);
                }
            }
        }
Exemplo n.º 3
0
 private static void _update_params_on_kvstore(
     List <List <NDArray> > param_arrays,
     List <List <NDArray> > grad_arrays,
     KVStore kvstore)
 {
     for (int index = 0; index < param_arrays.Count; index++)
     {
         var arg_list  = param_arrays[index];
         var grad_list = grad_arrays[index];
         if (grad_list[0] == null)
         {
             continue;
         }
         //push gradient, priority is negative index
         kvstore.Push(index, grad_list, priority: -index);
         //pull back the weights
         kvstore.Pull(index, arg_list, priority: -index);
     }
 }