예제 #1
0
        /// <summary>
        /// Loads a saved model checkpoint.
        /// The symbol is read from "{prefix}-symbol.json" and the parameters from
        /// "{prefix}-{epoch}.params", where the epoch is zero-padded to four digits.
        /// </summary>
        /// <param name="prefix">Path prefix shared by the symbol file and the parameter file.</param>
        /// <param name="epoch">Epoch number used to build the parameter file name.</param>
        /// <returns>
        /// The loaded symbol, the entries stored under the "arg:" prefix and the entries stored
        /// under the "aux:" prefix. The prefix is removed from each key. Both dictionaries are
        /// empty when the parameter file yields nothing.
        /// </returns>
        public static (Symbol, NDArrayDict, NDArrayDict) LoadCheckpoint(string prefix, int epoch)
        {
            const string argPrefix = "arg:";
            const string auxPrefix = "aux:";

            Symbol      sym        = Symbol.Load($"{prefix}-symbol.json");
            string      param_name = $"{prefix}-{epoch.ToString("D4")}.params";
            var         save_dict  = NDArray.Load(param_name);
            NDArrayDict arg_params = new NDArrayDict();
            NDArrayDict aux_params = new NDArrayDict();

            if (save_dict == null)
            {
                Logger.Warning($"Params file '{param_name}' is empty");
                return(sym, arg_params, aux_params);
            }

            foreach (var item in save_dict)
            {
                // Keys are identifiers, so compare ordinally. Strip only the leading prefix:
                // string.Replace would also remove "arg:"/"aux:" occurring later in the name.
                if (item.Key.StartsWith(argPrefix, System.StringComparison.Ordinal))
                {
                    arg_params.Add(item.Key.Substring(argPrefix.Length), item.Value);
                }
                else if (item.Key.StartsWith(auxPrefix, System.StringComparison.Ordinal))
                {
                    aux_params.Add(item.Key.Substring(auxPrefix.Length), item.Value);
                }
                else
                {
                    Logger.Warning($"Params file '{param_name}' contains unknown param '{item.Key}'");
                }
            }

            return(sym, arg_params, aux_params);
        }
 /// <summary>
 /// Pushes the given parameter dictionaries into every training executor of the
 /// executor group by calling CopyFromParams on each of them.
 /// </summary>
 /// <param name="arg_params">Argument parameters handed to each executor.</param>
 /// <param name="aux_params">Auxiliary parameters handed to each executor.</param>
 public void SetParams(NDArrayDict arg_params, NDArrayDict aux_params)
 {
     foreach (var trainExecutor in execgrp.train_execs)
     {
         trainExecutor.CopyFromParams(arg_params, aux_params);
     }
 }
예제 #3
0
        /// <summary>
        /// Writes a model checkpoint: the symbol (when one is supplied) to "{prefix}-symbol.json"
        /// and all parameters to "{prefix}-{epoch}.params", the epoch zero-padded to four digits.
        /// </summary>
        /// <param name="prefix">Path prefix shared by both output files.</param>
        /// <param name="epoch">Epoch number used in the parameter file name.</param>
        /// <param name="symbol">Symbol to save; skipped when null.</param>
        /// <param name="arg_params">Argument parameters, stored under keys prefixed with "arg:".</param>
        /// <param name="aux_params">Auxiliary parameters, stored under keys prefixed with "aux:".</param>
        /// <param name="remove_amp_cast">Forwarded unchanged to Symbol.Save.</param>
        public static void SaveCheckpoint(string prefix, int epoch, Symbol symbol, NDArrayDict arg_params,
                                          NDArrayDict aux_params, bool remove_amp_cast = true)
        {
            if (symbol != null)
            {
                symbol.Save($"{prefix}-symbol.json", remove_amp_cast);
            }

            NDArrayDict combined = new NDArrayDict();

            // Adds every entry of source to the combined dictionary under "<tag><name>".
            void AppendTagged(string tag, NDArrayDict source)
            {
                foreach (var entry in source)
                {
                    combined.Add(tag + entry.Key, entry.Value);
                }
            }

            AppendTagged("arg:", arg_params);
            AppendTagged("aux:", aux_params);

            string paramsFile = $"{prefix}-{epoch.ToString("D4")}.params";
            NDArray.Save(paramsFile, combined);

            Logger.Info($"Saved checkpoint to \"{paramsFile}\"");
        }
예제 #4
0
        /// <summary>
        /// Fills <paramref name="argsMap"/> with one array per symbol argument. Arguments present
        /// in <paramref name="knownArgs"/> are reused as-is; the rest are newly allocated with the
        /// inferred shape and filled from nd.Random.Uniform(0, 1, shape).
        /// </summary>
        /// <param name="context">Must be non-null; not otherwise used in this method.</param>
        /// <param name="argsMap">Destination dictionary, keyed by argument name.</param>
        /// <param name="knownArgs">Arrays already available, keyed by argument name.</param>
        /// <exception cref="ArgumentNullException">Any of the three arguments is null.</exception>
        public void InferArgsMap(Context context,
                                 NDArrayDict argsMap,
                                 NDArrayDict knownArgs)
        {
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }

            if (argsMap == null)
            {
                throw new ArgumentNullException(nameof(argsMap));
            }

            if (knownArgs == null)
            {
                throw new ArgumentNullException(nameof(knownArgs));
            }

            ThrowIfDisposed();

            var knownShapes   = new Dictionary <string, Shape>();
            var argumentNames = ListArguments();

            // Seed shape inference with the shapes of the arrays we already have.
            foreach (var name in argumentNames)
            {
                var known = knownArgs[name];
                if (known != null)
                {
                    knownShapes[name] = known.Shape;
                }
            }

            // Only the first tuple element (the argument shapes) is needed here.
            var(inferredShapes, _, _) = InferShape(knownShapes);

            for (var index = 0; index < inferredShapes.Length; ++index)
            {
                var name  = argumentNames[index];
                var known = knownArgs[name];
                if (known != null)
                {
                    argsMap[name] = known;
                    continue;
                }

                var fresh = new NDArray(inferredShapes[index], false);
                argsMap[name] = fresh;
                nd.Random.Uniform(0, 1, fresh.Shape).CopyTo(fresh);
            }
        }
예제 #5
0
 /// <summary>
 /// Convenience overload: forwards every argument to the overload that also takes a
 /// gradient store, supplying a new, empty NDArrayDict for it.
 /// </summary>
 public void InferExecutorArrays(Context context,
                                 NDArrayList argArrays,
                                 NDArrayList gradArrays,
                                 IList <OpGradReq> gradReqs,
                                 NDArrayList auxArrays,
                                 NDArrayDict argsMap)
 {
     var emptyGradStore = new NDArrayDict();

     InferExecutorArrays(context, argArrays, gradArrays, gradReqs, auxArrays, argsMap, emptyGradStore);
 }
        /// <summary>
        /// Copies the manager's current parameter and auxiliary values into the supplied
        /// dictionaries, pairing <c>param_names</c> with <c>ParamArrays</c> and
        /// <c>aux_names</c> with <c>AuxArrays</c> by position.
        /// </summary>
        /// <remarks>
        /// Enumerable.Zip is lazily evaluated. Building the zipped sequence without enumerating
        /// it runs nothing, so the pairs are iterated explicitly here to make the copy happen.
        /// </remarks>
        /// <param name="arg_params">Destination for parameter values, looked up by parameter name.</param>
        /// <param name="aux_params">Destination for auxiliary values, looked up by auxiliary name.</param>
        public void CopyTo(NDArrayDict arg_params, NDArrayDict aux_params)
        {
            // ToDo: Revisit code
            // NOTE(review): each block is reduced with Sum() into a one-element CPU array before
            // being copied into the destination. Confirm this is the intended value (as opposed
            // to an element-wise copy or average) now that the loop actually executes.
            foreach (var pair in param_names.Zip(ParamArrays, (name, block) => new { name, block }))
            {
                var w = new NDArray(new[] { pair.block.Sum() }, Context.Cpu());
                w.AsType(arg_params[pair.name].DataType).CopyTo(arg_params[pair.name]);
            }

            foreach (var pair in aux_names.Zip(AuxArrays, (name, block) => new { name, block }))
            {
                var w = new NDArray(new[] { pair.block.Sum() }, Context.Cpu());
                w.AsType(aux_params[pair.name].DataType).CopyTo(aux_params[pair.name]);
            }
        }
예제 #7
0
 /// <summary>
 /// Convenience overload: forwards every argument to the overload that also takes a
 /// per-argument gradient-request map, supplying a new, empty dictionary for it.
 /// </summary>
 public void InferExecutorArrays(Context context,
                                 NDArrayList argArrays,
                                 NDArrayList gradArrays,
                                 IList <OpGradReq> gradReqs,
                                 NDArrayList auxArrays,
                                 NDArrayDict argsMap,
                                 NDArrayDict argGradStore)
 {
     var noExplicitGradReqs = new Dictionary <string, OpGradReq>();

     InferExecutorArrays(context, argArrays, gradArrays, gradReqs, auxArrays, argsMap, argGradStore,
                         noExplicitGradReqs);
 }
예제 #8
0
        /// <summary>
        /// Builds the MNIST dataset dictionary by passing the label and image archive URLs
        /// (plus the counts 60000 and 10000) to <c>read_data</c>.
        /// </summary>
        /// <returns>
        /// A dictionary with the keys "train_data", "train_label", "test_data" and "test_label".
        /// </returns>
        public static NDArrayDict GetMNIST()
        {
            const string baseUrl = "http://data.mxnet.io/data/mnist/";

            // read_data returns (labels, images) in that order.
            var(trainLabels, trainImages) = read_data(baseUrl + "train-labels-idx1-ubyte.gz",
                                                      baseUrl + "train-images-idx3-ubyte.gz",
                                                      60000);
            var(testLabels, testImages) = read_data(baseUrl + "t10k-labels-idx1-ubyte.gz",
                                                    baseUrl + "t10k-images-idx3-ubyte.gz",
                                                    10000);

            var result = new NDArrayDict();
            result.Add("train_data", trainImages);
            result.Add("train_label", trainLabels);
            result.Add("test_data", testImages);
            result.Add("test_label", testLabels);
            return(result);
        }
예제 #9
0
        /// <summary>
        /// Pairs each name with the array at the same position and returns the result as a
        /// dictionary. Duplicate names and a name/array count mismatch are reported through
        /// Logging.CHECK / Logging.CHECK_EQ.
        /// </summary>
        /// <param name="names">Names to use as keys; expected to be unique.</param>
        /// <param name="arrays">Arrays to store, in the same order as <paramref name="names"/>.</param>
        /// <returns>A dictionary mapping names[i] to arrays[i].</returns>
        private static NDArrayDict GetDictionary(IList <string> names, NDArrayList arrays)
        {
            var seen = new HashSet <string>();

            foreach (var name in names)
            {
                // HashSet.Add returns false when the name was already present.
                var isFirstOccurrence = seen.Add(name);
                Logging.CHECK(isFirstOccurrence, $"Duplicate names detected, {name}");
            }

            Logging.CHECK_EQ(seen.Count, arrays.Length, "names size not equal to arrays size");

            var result   = new NDArrayDict();
            var position = 0;
            foreach (var name in names)
            {
                result[name] = arrays[position];
                position++;
            }

            return(result);
        }
예제 #10
0
        /// <summary>
        /// Copies parameter values into this executor's argument and auxiliary arrays. Each
        /// value is converted to the destination array's data type before the copy.
        /// </summary>
        /// <param name="arg_params">Values for the executor's arguments, keyed by name.</param>
        /// <param name="aux_params">Values for the auxiliary states, keyed by name; skipped when null.</param>
        /// <param name="allow_extra_params">
        /// When false, a name that the executor does not know raises an exception;
        /// when true, such entries are ignored.
        /// </param>
        /// <exception cref="Exception">
        /// A name is not among the executor's arguments / auxiliary states and
        /// <paramref name="allow_extra_params"/> is false.
        /// </exception>
        public void CopyFromParams(NDArrayDict arg_params, NDArrayDict aux_params = null,
                                   bool allow_extra_params = false)
        {
            var argTargets = ArgmentDictionary();
            var auxTargets = AuxiliaryDictionary();

            foreach (var entry in arg_params)
            {
                if (!argTargets.Contains(entry.Key))
                {
                    if (allow_extra_params)
                    {
                        continue;
                    }

                    throw new Exception($"Find name \"{entry.Key}\" that is not in the arguments");
                }

                var target = argTargets[entry.Key];
                entry.Value.AsType(target.DataType).CopyTo(target);
            }

            if (aux_params != null)
            {
                foreach (var entry in aux_params)
                {
                    if (!auxTargets.Contains(entry.Key))
                    {
                        if (allow_extra_params)
                        {
                            continue;
                        }

                        throw new Exception($"Find name \"{entry.Key}\" that is not in the auxiliary states");
                    }

                    var target = auxTargets[entry.Key];
                    entry.Value.AsType(target.DataType).CopyTo(target);
                }
            }
        }
예제 #11
0
        /// <summary>
        /// Decides whether parameter updates should run on the KVStore.
        /// The default is true; it can be overridden through the MXNET_UPDATE_ON_KVSTORE
        /// environment variable and is always false when no KVStore is supplied.
        /// </summary>
        /// <param name="kvstore">The KVStore to use, or null for none. Returned unchanged.</param>
        /// <param name="num_device">Not used by this implementation.</param>
        /// <param name="arg_params">Not used by this implementation.</param>
        /// <returns>The given KVStore and the resulting update-on-kvstore flag.</returns>
        /// <exception cref="FormatException">
        /// MXNET_UPDATE_ON_KVSTORE is set to a value that is neither an integer nor "true"/"false".
        /// </exception>
        internal static (KVStore, bool) CreateKVStore(KVStore kvstore, int num_device, NDArrayDict arg_params)
        {
            var update_on_kvstore = true;

            // Read the variable once so the test and the parse see the same value.
            var env_value = Environment.GetEnvironmentVariable("MXNET_UPDATE_ON_KVSTORE");
            if (!string.IsNullOrWhiteSpace(env_value))
            {
                // The variable is normally set to 0 or 1, which Convert.ToBoolean(string)
                // rejects with a FormatException. Treat any integer as "non-zero means true"
                // and keep accepting "true"/"false" as before.
                if (int.TryParse(env_value, System.Globalization.NumberStyles.Integer,
                                 System.Globalization.CultureInfo.InvariantCulture, out var numeric))
                {
                    update_on_kvstore = numeric != 0;
                }
                else
                {
                    update_on_kvstore = Convert.ToBoolean(env_value);
                }
            }

            if (kvstore == null)
            {
                update_on_kvstore = false;
            }

            return(kvstore, update_on_kvstore);
        }
예제 #12
0
        /// <summary>
        /// Runs the data-parallel training loop from <paramref name="begin_epoch"/> (inclusive) to
        /// <paramref name="end_epoch"/> (exclusive). Each epoch trains on <paramref name="train_data"/>,
        /// optionally copies the parameters back and invokes the epoch-end callback, then evaluates
        /// on <paramref name="eval_data"/> when it is supplied.
        /// </summary>
        /// <remarks>
        /// <para>
        /// When <paramref name="epoch_size"/> is null an epoch is one pass over the training
        /// iterator. When it has a value, the iterator is reset and re-read as often as needed
        /// until that many batches have been processed.
        /// </para>
        /// <para>
        /// NOTE(review): <paramref name="eval_metric"/> defaults to null but is dereferenced at the
        /// start of every epoch, and <paramref name="kvstore"/> is dereferenced whenever
        /// <paramref name="update_on_kvstore"/> is true; callers are expected to pass both.
        /// </para>
        /// </remarks>
        /// <param name="arg_params">Initial argument parameters; also receives the trained values via CopyTo.</param>
        /// <param name="aux_params">Initial auxiliary parameters; also receives the trained values via CopyTo.</param>
        /// <param name="epoch_size">Number of batches per epoch, or null for one full pass.</param>
        /// <param name="update_on_kvstore">True to update parameters on the KVStore, false to use a local updater.</param>
        internal static void TrainMultiDevice(Symbol symbol, Context[] ctx, string[] arg_names, string[] param_names,
                                              string[] aux_names, NDArrayDict arg_params, NDArrayDict aux_params,
                                              int begin_epoch, int end_epoch, int? epoch_size, Optimizer optimizer,
                                              KVStore kvstore, bool update_on_kvstore, DataIter train_data,
                                              DataIter eval_data = null, EvalMetric eval_metric = null,
                                              IEpochEndCallback epoch_end_callback = null,
                                              IBatchEndCallback batch_end_callback = null,
                                              int[] work_load_list = null, Monitor monitor = null,
                                              IEvalEndCallback eval_end_callback = null,
                                              IEvalBatchEndCallback eval_batch_end_callback = null,
                                              Func <int, Symbol> sym_gen = null)
        {
            var executor_manager = new DataParallelExecutorManager(symbol: symbol,
                                                                   ctx: ctx,
                                                                   train_data: train_data,
                                                                   arg_names: arg_names,
                                                                   param_names: param_names,
                                                                   aux_names: aux_names,
                                                                   work_load_list: work_load_list,
                                                                   sym_gen: sym_gen);

            if (monitor != null)
            {
                executor_manager.InstallMonitor(monitor);
            }

            executor_manager.SetParams(arg_params, aux_params);
            Updater updater = null;

            if (!update_on_kvstore)
            {
                updater = Optimizer.GetUpdater(optimizer);
            }
            else
            {
                kvstore.SetOptimizer(optimizer);
            }

            if (kvstore != null)
            {
                InitializeKVStore(kvstore: kvstore,
                                  param_arrays: new List <NDArrayList>()
                {
                    executor_manager.ParamArrays
                },
                                  arg_params: arg_params,
                                  param_names: executor_manager.param_names,
                                  update_on_kvstore: update_on_kvstore);
            }

            train_data.Reset();
            for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
            {
                var tic = DateTime.Now;
                eval_metric.Reset();
                int nbatch = 0;

                // Keep reading (and resetting) the iterator until this epoch is complete.
                while (true)
                {
                    bool do_reset = true;
                    while (!train_data.End())
                    {
                        var data_batch = train_data.Next();
                        executor_manager.LoadDataBatch(data_batch);
                        if (monitor != null)
                        {
                            monitor.Tic();
                        }

                        executor_manager.Forward(true);
                        executor_manager.Backward();

                        var param_arrays = new List <NDArrayList>()
                        {
                            executor_manager.ParamArrays
                        };
                        var grad_arrays = new List <NDArrayList>()
                        {
                            executor_manager.GradArrays
                        };

                        if (update_on_kvstore)
                        {
                            if (kvstore.Type.Contains("nccl"))
                            {
                                UpdateParamsOnKVStoreNCCL(param_arrays, grad_arrays, kvstore,
                                                          executor_manager.param_names);
                            }
                            else
                            {
                                UpdateParamsOnKVStore(param_arrays, grad_arrays, kvstore,
                                                      executor_manager.param_names);
                            }
                        }
                        else
                        {
                            UpdateParams(param_arrays, grad_arrays, updater, ctx.Length, kvstore,
                                         executor_manager.param_names);
                        }

                        if (monitor != null)
                        {
                            monitor.TocPrint();
                        }

                        executor_manager.UpdateMetric(eval_metric, data_batch.Label);
                        nbatch++;
                        if (batch_end_callback != null)
                        {
                            MultipleCallbacks(batch_end_callback, epoch, nbatch, eval_metric);
                        }

                        // Epoch quota reached mid-pass: stop without resetting so the next
                        // epoch continues from the current iterator position.
                        if (epoch_size.HasValue && nbatch >= epoch_size.Value)
                        {
                            do_reset = false;
                            break;
                        }
                    }

                    if (do_reset)
                    {
                        Logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                        train_data.Reset();
                    }

                    // The epoch is done after one full pass when no epoch size is given, or
                    // once the requested number of batches has been processed. Otherwise the
                    // freshly reset iterator is read again.
                    if (!epoch_size.HasValue || nbatch >= epoch_size.Value)
                    {
                        break;
                    }
                }

                var toc = DateTime.Now;
                Logger.Info($"Epoch[{epoch}] Time cost={(toc - tic).TotalSeconds}");

                // Sync parameters back only when someone will look at them: a callback is
                // registered, or this is the final epoch.
                if (epoch_end_callback != null || epoch + 1 == end_epoch)
                {
                    executor_manager.CopyTo(arg_params, aux_params);
                }

                MultipleCallbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params);

                if (eval_data != null)
                {
                    eval_metric.Reset();
                    eval_data.Reset();
                    int eval_batch_index = 0;
                    while (!eval_data.End())
                    {
                        var eval_batch = eval_data.Next();
                        executor_manager.LoadDataBatch(eval_batch);
                        executor_manager.Forward();
                        executor_manager.UpdateMetric(eval_metric, eval_batch.Label);
                        if (eval_batch_end_callback != null)
                        {
                            MultipleCallbacks(eval_batch_end_callback, epoch, eval_batch_index, eval_metric);
                        }

                        // Advance the zero-based index reported to the per-batch callback.
                        eval_batch_index++;
                    }

                    if (eval_end_callback != null)
                    {
                        MultipleCallbacks(eval_end_callback, epoch, eval_metric);
                    }
                }
            }
        }
예제 #13
0
        /// <summary>
        /// Binds this symbol to <paramref name="ctx"/> through the native
        /// MXExecutorSimpleBindEx call and wraps the returned handles in an <see cref="Executor"/>.
        /// </summary>
        /// <remarks>
        /// All optional arguments may be left null: a null <paramref name="kwargs"/> contributes no
        /// shapes, a null or empty <paramref name="shared_buffer"/> is passed to the native call with
        /// a length of -1, and a null <paramref name="grad_req"/> yields an empty request list on the
        /// returned executor.
        /// NOTE(review): if MXExecutorSimpleBindEx returns a status code it is not inspected here.
        /// </remarks>
        /// <param name="ctx">Device context to bind on.</param>
        /// <param name="grad_req">Gradient request per argument name; values are sent as lower-cased enum names.</param>
        /// <param name="type_dict">Data type per argument name.</param>
        /// <param name="stype_dict">Storage type per argument name.</param>
        /// <param name="group2ctx">Context per group name.</param>
        /// <param name="shared_arg_names">Names of arguments to share; treated as empty when null.</param>
        /// <param name="shared_exec">Executor to share with, if any.</param>
        /// <param name="shared_buffer">Shared buffer; updated in place with the arrays the native call returns.</param>
        /// <param name="kwargs">Input descriptors supplying argument names and shapes.</param>
        /// <returns>The bound executor with its argument, gradient and auxiliary arrays populated.</returns>
        public Executor SimpleBind(Context ctx, Dictionary <string, OpGradReq> grad_req = null,
                                   Dictionary <string, DType> type_dict = null,
                                   Dictionary <string, StorageStype> stype_dict = null,
                                   Dictionary <string, Context> group2ctx = null, string[] shared_arg_names = null,
                                   Executor shared_exec = null, NDArrayDict shared_buffer = null,
                                   DataDesc[] kwargs = null)
        {
            int num_provided_arg_types = 0;

            string[] provided_arg_type_names = null;
            int[]    provided_arg_type_data  = null;
            if (type_dict != null)
            {
                provided_arg_type_names = type_dict.Keys.ToArray();
                provided_arg_type_data  = type_dict.Values.Select(x => x.Index).ToArray();
                num_provided_arg_types  = type_dict.Count;
            }

            int num_provided_arg_stypes = 0;

            string[] provided_arg_stype_names = null;
            int[]    provided_arg_stype_data  = null;
            if (stype_dict != null)
            {
                provided_arg_stype_names = stype_dict.Keys.ToArray();
                provided_arg_stype_data  = stype_dict.Values.Select(x => (int)x).ToArray();
                num_provided_arg_stypes  = stype_dict.Count;
            }

            // Shapes are flattened into one data list; provided_arg_shape_idx holds the running
            // end offset of each shape within it, starting from 0.
            List <int> provided_arg_shape_data = new List <int>();
            List <int> provided_arg_shape_idx  = new List <int>()
            {
                0
            };
            List <string> provided_arg_shape_names = new List <string>();

            // kwargs defaults to null; treat that as "no shapes provided".
            if (kwargs != null)
            {
                foreach (var desc in kwargs)
                {
                    provided_arg_shape_names.Add(desc.Name);
                    provided_arg_shape_data.AddRange(desc.Shape.Data.Where(x => x > 0).ToList());
                    provided_arg_shape_idx.Add(provided_arg_shape_data.Count);
                }
            }

            int provided_req_type_list_len = 0;

            string[] provided_grad_req_names = new string[0];
            string[] provided_grad_req_types = new string[0];
            if (grad_req != null)
            {
                provided_grad_req_names    = grad_req.Keys.ToArray();
                // Enum names are machine-readable tokens, so lower-case them culture-invariantly.
                provided_grad_req_types    = grad_req.Values.Select(x => Enum.GetName(x.GetType(), x).ToLowerInvariant()).ToArray();
                provided_req_type_list_len = grad_req.Count;
            }

            int num_ctx_map_keys = 0;

            string[] ctx_map_keys      = new string[0];
            int[]    ctx_map_dev_types = new int[0];
            int[]    ctx_map_dev_ids   = new int[0];
            if (group2ctx != null)
            {
                ctx_map_keys      = group2ctx.Keys.ToArray();
                ctx_map_dev_types = group2ctx.Values.Select(x => (int)x.GetDeviceType()).ToArray();
                ctx_map_dev_ids   = group2ctx.Values.Select(x => x.GetDeviceId()).ToArray();
                num_ctx_map_keys  = group2ctx.Count;
            }

            string[] shared_arg_name_list = new string[0];
            if (shared_arg_names != null)
            {
                shared_arg_name_list = shared_arg_names;
            }

            // shared_buffer defaults to null; only a non-empty buffer is handed to the native call.
            bool has_shared_buffer = shared_buffer != null && shared_buffer.Count > 0;

            unsafe
            {
                int      shared_start          = -1;
                int      shared_buffer_len     = shared_start;
                string[] shared_buffer_names   = null;
                IntPtr[] shared_buffer_handles = null;
                if (has_shared_buffer)
                {
                    shared_buffer_len     = shared_buffer.Count;
                    shared_buffer_names   = shared_buffer.Keys;
                    shared_buffer_handles = shared_buffer.Values.Handles;
                }

                var shared_exec_handle = shared_exec != null ? shared_exec.Handle : new ExecutorHandle();

                char **        updated_shared_buffer_names;
                SymbolHandle * updated_shared_buffer_handles;
                int            num_in_args;
                SymbolHandle * in_arg_handles;
                SymbolHandle * arg_grad_handles;
                SymbolHandle * aux_state_handles;
                ExecutorHandle exe_handle;
                int            num_aux_states;

                NativeMethods.MXExecutorSimpleBindEx(NativePtr, (int)ctx.GetDeviceType(), ctx.GetDeviceId(), num_ctx_map_keys, ctx_map_keys,
                                                     ctx_map_dev_types, ctx_map_dev_ids, provided_req_type_list_len, provided_grad_req_names, provided_grad_req_types,
                                                     provided_arg_shape_names.Count, provided_arg_shape_names.ToArray(), provided_arg_shape_data.ToArray(),
                                                     provided_arg_shape_idx.ToArray(), num_provided_arg_types, provided_arg_type_names, provided_arg_type_data,
                                                     num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stype_data, shared_arg_name_list.Length,
                                                     shared_arg_name_list, &shared_buffer_len, shared_buffer_names, shared_buffer_handles, &updated_shared_buffer_names,
                                                     &updated_shared_buffer_handles, &num_in_args, &in_arg_handles, &arg_grad_handles,
                                                     &num_aux_states, &aux_state_handles, shared_exec_handle, &exe_handle);

                // Write the arrays reported back by the native call into the caller's buffer.
                if (has_shared_buffer)
                {
                    int l = shared_buffer_len;
                    for (int i = 0; i < l; i++)
                    {
                        string  k = new string(updated_shared_buffer_names[i]);
                        NDArray v = new NDArray(updated_shared_buffer_handles[i]);
                        shared_buffer[k] = v;
                    }
                }

                NDArrayList arg_arrays = new NDArrayList();
                for (int i = 0; i < num_in_args; i++)
                {
                    arg_arrays.Add(new NDArray(in_arg_handles[i]));
                }

                // Arguments without a gradient come back as a zero handle and are skipped.
                NDArrayList grad_arrays = new NDArrayList();
                for (int i = 0; i < num_in_args; i++)
                {
                    if (arg_grad_handles[i] == IntPtr.Zero)
                    {
                        continue;
                    }
                    grad_arrays.Add(new NDArray(arg_grad_handles[i]));
                }

                NDArrayList aux_arrays = new NDArrayList();
                for (int i = 0; i < num_aux_states; i++)
                {
                    aux_arrays.Add(new NDArray(aux_state_handles[i]));
                }

                // grad_req defaults to null; hand the executor an empty list in that case.
                var grad_req_list = grad_req != null ? grad_req.Values.ToList() : new List <OpGradReq>();

                var executor = new Executor(exe_handle, ctx, grad_req_list, group2ctx);
                executor.ArgmentArrays   = arg_arrays;
                executor.GradientArrays  = grad_arrays;
                executor.AuxiliaryArrays = aux_arrays;
                executor._Symbol         = this;
                return(executor);
            }
        }
예제 #14
0
        /// <summary>
        /// Populates the argument, gradient, gradient-request and auxiliary lists needed to bind
        /// an executor, in the order reported by ListArguments / ListAuxiliaryStates.
        /// </summary>
        /// <remarks>
        /// For every argument: the array from <paramref name="argsMap"/> is used when present,
        /// otherwise a new array of the inferred shape is allocated and filled from
        /// nd.Random.Uniform(0, 1, shape). The gradient array comes from
        /// <paramref name="argGradStore"/> when present, otherwise a new array is allocated. The
        /// gradient request comes from <paramref name="gradReqType"/> when present; otherwise it is
        /// Null for names ending in "data" or "label" and Write for everything else.
        /// Auxiliary states are handled like arguments, using <paramref name="auxMap"/>.
        /// </remarks>
        /// <param name="context">Must be non-null; not otherwise used in this method.</param>
        /// <param name="argArrays">Receives one array per argument.</param>
        /// <param name="gradArrays">Receives one gradient array per argument.</param>
        /// <param name="gradReqs">Receives one gradient request per argument.</param>
        /// <param name="auxArrays">Receives one array per auxiliary state.</param>
        /// <param name="argsMap">Known argument arrays, keyed by name.</param>
        /// <param name="argGradStore">Known gradient arrays, keyed by name.</param>
        /// <param name="gradReqType">Explicit gradient requests, keyed by argument name.</param>
        /// <param name="auxMap">Known auxiliary arrays, keyed by name.</param>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public void InferExecutorArrays(Context context,
                                        NDArrayList argArrays,
                                        NDArrayList gradArrays,
                                        IList <OpGradReq> gradReqs,
                                        NDArrayList auxArrays,
                                        NDArrayDict argsMap,
                                        NDArrayDict argGradStore,
                                        IDictionary <string, OpGradReq> gradReqType,
                                        NDArrayDict auxMap)
        {
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }
            if (argArrays == null)
            {
                throw new ArgumentNullException(nameof(argArrays));
            }
            if (gradArrays == null)
            {
                throw new ArgumentNullException(nameof(gradArrays));
            }
            if (gradReqs == null)
            {
                throw new ArgumentNullException(nameof(gradReqs));
            }
            if (auxArrays == null)
            {
                throw new ArgumentNullException(nameof(auxArrays));
            }
            if (argsMap == null)
            {
                throw new ArgumentNullException(nameof(argsMap));
            }
            if (argGradStore == null)
            {
                throw new ArgumentNullException(nameof(argGradStore));
            }
            if (gradReqType == null)
            {
                throw new ArgumentNullException(nameof(gradReqType));
            }
            if (auxMap == null)
            {
                throw new ArgumentNullException(nameof(auxMap));
            }

            ThrowIfDisposed();

            var argNameList = ListArguments();
            var argShapes   = new Dictionary <string, Shape>();

            foreach (var argName in argNameList)
            {
                if (argsMap[argName] != null)
                {
                    argShapes[argName] = argsMap[argName].Shape;
                }
            }

            // NOTE(review): other callers of InferShape take the tuple as (argument, output,
            // auxiliary) shapes, so the auxiliary shapes are read from the third element here.
            // Confirm against InferShape's definition.
            var(inShapes, _, auxShapes) = InferShape(argShapes);

            for (var i = 0; i < inShapes.Length; ++i)
            {
                var shape   = inShapes[i];
                var argName = argNameList[i];
                if (argsMap[argName] != null)
                {
                    argArrays.Add(argsMap[argName]);
                }
                else
                {
                    argArrays.Add(new NDArray(shape, false));
                    var argArr = argArrays.Last();
                    nd.Random.Uniform(0, 1, argArr.Shape).CopyTo(argArr);
                }

                if (argGradStore[argName] != null)
                {
                    gradArrays.Add(argGradStore[argName]);
                }
                else
                {
                    gradArrays.Add(new NDArray(shape, false));
                }

                if (gradReqType.TryGetValue(argName, out var value3))
                {
                    gradReqs.Add(value3);
                }
                // Use EndsWith for the suffix test: comparing LastIndexOf with Length - k also
                // matches names that are one character shorter than the suffix and do not contain
                // it at all (-1 == -1), e.g. a 4-character name such as "bias" for "label".
                else if (argName.EndsWith("data", StringComparison.Ordinal) ||
                         argName.EndsWith("label", StringComparison.Ordinal))
                {
                    gradReqs.Add(OpGradReq.Null);
                }
                else
                {
                    gradReqs.Add(OpGradReq.Write);
                }
            }

            var auxNameList = ListAuxiliaryStates();

            for (var i = 0; i < auxShapes.Length; ++i)
            {
                var shape   = auxShapes[i];
                var auxName = auxNameList[i];
                if (auxMap[auxName] != null)
                {
                    auxArrays.Add(auxMap[auxName]);
                }
                else
                {
                    auxArrays.Add(new NDArray(shape, false));
                    var aux = auxArrays.Last();
                    nd.Random.Uniform(0, 1, aux.Shape).CopyTo(aux);
                }
            }
        }
예제 #15
0
        internal static Executor BindExec(Symbol sym, Context ctx, Dictionary <string, Shape> input_shapes,
                                          string[] param_names, bool need_grad = false,
                                          Executor base_exec = null, NDArrayDict shared_data_arrays    = null,
                                          Dictionary <string, DType> input_types = null, Logger logger = null)
        {
            var(arg_shape, _, aux_shape) = sym.InferShape(input_shapes);
            if (arg_shape == null)
            {
                throw new ArgumentNullException("arg_shape");
            }

            if (input_types == null)
            {
                input_types = new Dictionary <string, DType>();
                foreach (var item in input_shapes.Keys)
                {
                    input_types.Add(item, DType.Float32);
                }
            }

            var(arg_types, _, aux_types) = sym.InferType(input_types);

            if (arg_types == null)
            {
                throw new ArgumentNullException("arg_types");
            }

            var arg_arrays  = new NDArrayList();
            var aux_arrays  = new NDArrayList();
            var grad_arrays = need_grad ? new NDArrayDict() : null;

            var arg_names   = sym.ListArguments();
            var needGradSet = new List <string>();

            if (!need_grad)
            {
                needGradSet = new List <string>();
            }
            else
            {
                foreach (var item in arg_names)
                {
                    if (!input_shapes.ContainsKey(item))
                    {
                        needGradSet.Add(item);
                    }
                }

                needGradSet = MxUtil.Set(needGradSet);
            }

            var grad_req = new Dictionary <string, OpGradReq>();

            foreach (var item in arg_names)
            {
                if (needGradSet.Contains(item))
                {
                    grad_req.Add(item, OpGradReq.Write);
                }
            }

            for (var i = 0; i < arg_names.Count; i++)
            {
                var     name     = arg_names[i];
                NDArray arg_arr  = null;
                NDArray grad_arr = null;
                if (!param_names.Contains(name))
                {
                    if (shared_data_arrays != null && shared_data_arrays.Contains(name))
                    {
                        arg_arr = shared_data_arrays[name];
                        if (np.prod(arg_arr.Shape.Data) >= np.prod(arg_shape[i].Data))
                        {
                            if (arg_types[i].Name != arg_arr.DataType.Name)
                            {
                                throw new ArgumentException("arg_type and arg_arr datatype mismatch");
                            }

                            arg_arr = arg_arr.Reshape(arg_shape[i]);
                        }
                        else
                        {
                            var logmsg = new StringBuilder();
                            logmsg.AppendFormat("bucketing: data \"{0}\" has a shape {1}", name, arg_shape[i]);
                            logmsg.AppendFormat(", which is larger than already allocated ");
                            logmsg.AppendFormat("shape {0}", arg_arr.Shape);
                            logmsg.AppendFormat(". Need to re-allocate. Consider putting default_bucket_key " +
                                                "to be the bucket taking the largest input for better memory sharing.");

                            Logger.Warning(logmsg.ToString());

                            arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                            shared_data_arrays[name] = arg_arr;
                        }
                    }
                    else
                    {
                        arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                        if (shared_data_arrays != null)
                        {
                            shared_data_arrays[name] = arg_arr;
                        }
                    }

                    arg_arrays.Add(arg_arr);
                }
                else
                {
                    if (base_exec == null)
                    {
                        arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                        if (needGradSet.Contains(name))
                        {
                            grad_arr          = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                            grad_arrays[name] = grad_arr;
                        }
                        else
                        {
                            arg_arr = base_exec.ArgmentDictionary()[name];
                            if (arg_arr.Shape != arg_shape[i])
                            {
                                throw new ArgumentException("arg_arr.Shape != arg_shape[i]");
                            }

                            if (arg_arr.DataType != arg_types[i])
                            {
                                throw new ArgumentException("arg_arr.DataType != arg_types[i]");
                            }

                            if (needGradSet.Contains(name))
                            {
                                grad_arrays[name] = base_exec.GradientDictionary()[name];
                            }
                        }

                        arg_arrays.Add(arg_arr);
                    }
                }
            }

            if (base_exec != null)
            {
                for (var i = 0; i < aux_shape.Length; i++)
                {
                    var s = aux_shape[i];
                    var t = aux_types[i];
                    aux_arrays.Add(nd.Zeros(s, ctx, t));
                }
            }
            else
            {
                foreach (var item in base_exec.AuxiliaryDictionary())
                {
                    aux_arrays.Add(item.Value);
                }
            }

            var executor = sym.Bind(ctx, arg_arrays, grad_arrays.Values.ToList(), grad_req.Values.ToList(), aux_arrays);

            return(executor);
        }
예제 #16
0
        /// <summary>
        /// Creates the key-value store named by <paramref name="kvstore"/> and decides whether
        /// parameter updates should be performed on that store.
        /// </summary>
        /// <param name="kvstore">Store type name passed to <c>KVStoreBase.Create</c>; null means "no store".</param>
        /// <param name="num_device">Number of devices; with exactly one device and a non-"dist" type no store is created.</param>
        /// <param name="arg_params">Parameters whose sizes are inspected when the store type is "local".</param>
        /// <returns>
        /// The created store (null when none is used) and the update-on-kvstore flag.
        /// The flag is always false when the returned store is null.
        /// </returns>
        internal static (KVStore, bool) CreateKVStore(string kvstore, int num_device, NDArrayDict arg_params)
        {
            KVStore kV = null;
            var     update_on_kvstore = true;

            // A store is only created when a type was given and it is either a "dist"
            // type or more than one device is involved.
            if (kvstore != null && (num_device != 1 || kvstore.Contains("dist")))
            {
                kV = KVStoreBase.Create(kvstore);
                if (kvstore == "local" && arg_params != null)
                {
                    // DefaultIfEmpty keeps Max() from throwing on an empty parameter set.
                    var max_size = arg_params.Values.Select(x => x.Shape.Size).DefaultIfEmpty().Max();
                    if (max_size > 1024 * 1024 * 16)
                    {
                        update_on_kvstore = false;
                    }
                }
            }

            // Without a store there is nothing to update on; never report true with a null store.
            if (kV == null)
            {
                update_on_kvstore = false;
            }

            return(kV, update_on_kvstore);
        }
예제 #17
0
        /// <summary>
        /// Registers every parameter with the key-value store and, when requested, pulls the
        /// stored value back into the per-device arrays.
        /// </summary>
        /// <param name="kvstore">Store that receives the <c>Init</c> / <c>Pull</c> calls.</param>
        /// <param name="param_arrays">Per-parameter lists of arrays; entry <c>i</c> belongs to <c>param_names[i]</c>.</param>
        /// <param name="arg_params">Lookup from parameter name to the value used for initialization.</param>
        /// <param name="param_names">Parameter names, index-aligned with <paramref name="param_arrays"/>.</param>
        /// <param name="update_on_kvstore">When true, each initialized parameter is pulled from the store afterwards.</param>
        internal static void InitializeKVStore(KVStore kvstore, List <NDArrayList> param_arrays, NDArrayDict arg_params,
                                               string[] param_names, bool update_on_kvstore)
        {
            for (var idx = 0; idx < param_arrays.Count; ++idx)
            {
                var devices = param_arrays[idx];

                // Entries with no arrays are skipped.
                if (devices.Length == 0)
                {
                    continue;
                }

                // Entries whose first array is null are skipped as well.
                if (devices[0] == null)
                {
                    continue;
                }

                var key = param_names[idx];
                kvstore.Init(key, arg_params[key]);

                if (!update_on_kvstore)
                {
                    continue;
                }

                // The negated index is passed as the third argument to Pull.
                kvstore.Pull(key, devices, -idx);
            }
        }
예제 #18
0
 /// <summary>
 /// Creates a cached operator for <paramref name="sym"/> through the native
 /// <c>MXCreateCachedOpEx</c> entry point and stores the resulting handle.
 /// </summary>
 /// <param name="sym">Symbol whose native handle is passed to the native call.</param>
 /// <param name="flags">Flag entries; keys and the handles of the values are forwarded as parallel arrays.</param>
 public CachedOp(Symbol sym, NDArrayDict flags)
 {
     // Start from a null handle; the native call writes the real one through the out argument.
     handle = IntPtr.Zero;

     var symHandle   = sym.GetHandle();
     var flagCount   = flags.Count;
     var flagKeys    = flags.Keys.ToArray();
     var flagHandles = MxUtil.GetNDArrayHandles(flags.Values.ToArray());

     NativeMethods.MXCreateCachedOpEx(symHandle, flagCount, flagKeys, flagHandles, out handle);
 }