Example #1
0
 /// <summary>
 /// Convenience overload: forwards every argument unchanged to the overload that additionally
 /// takes an explicit per-argument gradient-request map, supplying an empty map.
 /// </summary>
 public void InferExecutorArrays(Context context,
                                 NDArrayList argArrays,
                                 NDArrayList gradArrays,
                                 IList <OpGradReq> gradReqs,
                                 NDArrayList auxArrays,
                                 NDArrayDict argsMap,
                                 NDArrayDict argGradStore)
 {
     // No explicit requests: the target overload decides the request for each argument itself.
     var noExplicitGradReqs = new Dictionary <string, OpGradReq>();
     InferExecutorArrays(context, argArrays, gradArrays, gradReqs,
                         auxArrays, argsMap, argGradStore, noExplicitGradReqs);
 }
Example #2
0
        /// <summary>
        /// Runs the backward pass of this executor through the native <c>MXExecutorBackward</c> call.
        /// </summary>
        /// <param name="headGrads">
        /// Head gradients. When the list is empty, a count of 0 and a null handle array are passed.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="headGrads"/> is null.</exception>
        public void Backward(NDArrayList headGrads)
        {
            if (headGrads == null)
            {
                throw new ArgumentNullException(nameof(headGrads));
            }

            var handles = headGrads.Select(grad => grad.GetHandle()).ToArray();
            var count   = (uint)handles.Length;

            // An empty head-gradient list is sent as (0, null) rather than as an empty array.
            NativeMethods.MXExecutorBackward(Handle, count, count > 0 ? handles : null);
        }
Example #3
0
        /// <summary>
        /// Wraps an already-created native executor handle and collects its output arrays.
        /// </summary>
        /// <param name="h">Native executor handle; must not be <see cref="IntPtr.Zero"/>.</param>
        /// <param name="context">Not referenced by this constructor.</param>
        /// <param name="gradReqs">Not referenced by this constructor.</param>
        /// <param name="groupToCtx">Not referenced by this constructor.</param>
        /// <exception cref="ArgumentException"><paramref name="h"/> is <see cref="IntPtr.Zero"/>.</exception>
        public Executor(ExecutorHandle h, Context context,
                        List <OpGradReq> gradReqs,
                        Dictionary <string, Context> groupToCtx)
        {
            if (h == IntPtr.Zero)
            {
                throw new ArgumentException("Can not pass IntPtr.Zero", nameof(h));
            }

            Outputs = new NDArrayList();
            Handle  = h;

            // Ask the backend for the output handle array, then wrap each handle in an NDArray.
            var status = NativeMethods.MXExecutorOutputs(Handle, out var outputCount, out var outputPtr);
            Logging.CHECK_EQ(status, 0);

            var outputHandles = InteropHelper.ToPointerArray(outputPtr, outputCount);
            for (uint index = 0; index < outputCount; index++)
            {
                Outputs.Add(new NDArray(outputHandles[index]));
            }
        }
Example #4
0
        /// <summary>
        /// Binds this symbol to the supplied arrays by constructing a new <see cref="Executor"/>.
        /// </summary>
        /// <param name="context">Default device context.</param>
        /// <param name="argArrays">Argument arrays.</param>
        /// <param name="gradArrays">Gradient arrays.</param>
        /// <param name="gradReqs">Gradient requests.</param>
        /// <param name="auxArrays">Auxiliary-state arrays.</param>
        /// <param name="groupToCtx">Context-group to device mapping.</param>
        /// <param name="sharedExec">Optional executor forwarded as the shared executor; not null-checked here.</param>
        /// <returns>The newly bound executor.</returns>
        /// <exception cref="ArgumentNullException">Any argument other than <paramref name="sharedExec"/> is null.</exception>
        public Executor Bind(Context context,
                             NDArrayList argArrays,
                             NDArrayList gradArrays,
                             IList <OpGradReq> gradReqs,
                             NDArrayList auxArrays,
                             IDictionary <string, Context> groupToCtx,
                             Executor sharedExec)
        {
            // Validated in declaration order; sharedExec is optional and intentionally unchecked.
            if (context == null)    { throw new ArgumentNullException(nameof(context)); }
            if (argArrays == null)  { throw new ArgumentNullException(nameof(argArrays)); }
            if (gradArrays == null) { throw new ArgumentNullException(nameof(gradArrays)); }
            if (gradReqs == null)   { throw new ArgumentNullException(nameof(gradReqs)); }
            if (auxArrays == null)  { throw new ArgumentNullException(nameof(auxArrays)); }
            if (groupToCtx == null) { throw new ArgumentNullException(nameof(groupToCtx)); }

            var executor = new Executor(this, context, argArrays, gradArrays, gradReqs, auxArrays, groupToCtx, sharedExec);
            return executor;
        }
Example #5
0
        /// <summary>
        /// Pairs each name with the array at the same position.
        /// </summary>
        /// <param name="names">Keys; every duplicate is reported through <c>Logging.CHECK</c>.</param>
        /// <param name="arrays">Arrays in the same order as <paramref name="names"/>.</param>
        /// <returns>A dictionary mapping <c>names[i]</c> to <c>arrays[i]</c>.</returns>
        private static NDArrayDict GetDictionary(IList <string> names, NDArrayList arrays)
        {
            var result = new NDArrayDict();
            var seen   = new HashSet <string>();

            // Every name must be unique before it is used as a key.
            foreach (var name in names)
            {
                var isDuplicate = seen.Contains(name);
                Logging.CHECK(!isDuplicate, $"Duplicate names detected, {name}");
                seen.Add(name);
            }

            Logging.CHECK_EQ(seen.Count, arrays.Length, "names size not equal to arrays size");

            var index = 0;
            while (index < names.Count)
            {
                result[names[index]] = arrays[index];
                index++;
            }

            return result;
        }
Example #6
0
        /// <summary>
        /// Computes the gradients of <paramref name="heads"/> with respect to <paramref name="variables"/>
        /// through the native <c>MXAutogradBackwardEx</c> call.
        /// </summary>
        /// <param name="heads">Output arrays to differentiate (passed to <c>ParseHead</c>).</param>
        /// <param name="variables">Arrays whose gradients are returned.</param>
        /// <param name="head_grads">Optional head gradients (passed to <c>ParseHead</c>).</param>
        /// <param name="retain_graph">Forwarded to the backend as 0/1.</param>
        /// <param name="create_graph">Forwarded to the backend as 0/1.</param>
        /// <param name="train_mode">Forwarded to the backend as 0/1.</param>
        /// <returns>One new <see cref="NDArray"/> per gradient handle returned by the backend.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="variables"/> is null.</exception>
        public static NDArrayList Grad(NDArrayList heads, NDArrayList variables, NDArrayList head_grads = null,
                                       bool retain_graph = false, bool create_graph = true, bool train_mode = true)
        {
            if (variables == null)
            {
                throw new ArgumentNullException(nameof(variables));
            }

            var(head_handles, head_grads_handles) = ParseHead(heads, head_grads);

            // Check the native status like the other native calls do; otherwise a backend failure would
            // silently leave grad_handles in an undefined state. Storage types are not used here.
            Logging.CHECK_EQ(NativeMethods.MXAutogradBackwardEx(head_handles.Length, head_handles, head_grads_handles, variables.Length,
                                                                MxUtil.GetNDArrayHandles(variables), Convert.ToInt32(retain_graph),
                                                                Convert.ToInt32(create_graph), Convert.ToInt32(train_mode),
                                                                out var grad_handles, out _), NativeMethods.OK);

            var result = new NDArrayList();
            foreach (var item in grad_handles)
            {
                result.Add(new NDArray(item));
            }

            return result;
        }
        /// <summary>
        /// Updates <paramref name="metric"/> once per training executor, using that executor's outputs
        /// and the labels belonging to its slice.
        /// </summary>
        /// <param name="metric">Metric to update.</param>
        /// <param name="labels">
        /// Full-batch labels, sliced per executor via <c>slices</c>; when <paramref name="pre_sliced"/> is true,
        /// <c>labels[i]</c> is used as-is for the i-th executor.
        /// </param>
        /// <param name="pre_sliced">Whether <paramref name="labels"/> is already split per executor.</param>
        /// <exception cref="ArgumentNullException"><paramref name="metric"/> or <paramref name="labels"/> is null.</exception>
        public void UpdateMetric(EvalMetric metric, NDArrayList labels, bool pre_sliced = false)
        {
            if (metric == null)
            {
                throw new ArgumentNullException(nameof(metric));
            }
            if (labels == null)
            {
                throw new ArgumentNullException(nameof(labels));
            }

            var i = 0;

            // Zip is lazily evaluated: enumerate it with foreach so every executor is actually visited
            // (the previous code built the Zip sequence and discarded it, so nothing ran).
            foreach (var (executor, slice) in train_execs.Zip(slices, (e, s) => (e, s)))
            {
                // Fresh list per executor so labels of earlier executors do not leak into later updates.
                var labels_slice = new NDArrayList();
                if (!pre_sliced)
                {
                    foreach (var label in labels)
                    {
                        labels_slice.Add(label.Slice(slice.Begin, slice.End.Value));
                    }
                }
                else
                {
                    labels_slice.Add(labels[i]);
                }

                metric.Update(labels_slice.ToArray(), executor.Outputs.ToArray());
                i++;
            }
        }
Example #8
0
 /// <summary>
 /// Backward pass; abstract, so every concrete subclass must supply the implementation.
 /// NOTE(review): the parameter names suggest MXNet's CustomOp.backward contract (req = gradient request
 /// per input, out_grad = gradients w.r.t. outputs, in_grad = where input gradients are written) — confirm
 /// against the implementations before relying on this.
 /// </summary>
 public abstract void Backward(OpGradReq[] req, NDArrayList out_grad, NDArrayList in_data, NDArrayList out_data, NDArrayList in_grad, NDArrayList aux);
Example #9
0
        /// <summary>
        /// Binds this symbol through the native <c>MXExecutorSimpleBindEx</c> call, which allocates the
        /// argument, gradient and auxiliary arrays from the provided shapes, types and storage types.
        /// </summary>
        /// <param name="ctx">Default device context.</param>
        /// <param name="grad_req">Optional per-argument gradient requests (sent as lower-case enum names); null sends none.</param>
        /// <param name="type_dict">Optional per-argument data types (sent as <c>DType.Index</c>).</param>
        /// <param name="stype_dict">Optional per-argument storage types.</param>
        /// <param name="group2ctx">Optional context-group to device mapping.</param>
        /// <param name="shared_arg_names">Optional names of arguments to share; null sends an empty list.</param>
        /// <param name="shared_exec">Optional executor whose handle is passed as the shared executor.</param>
        /// <param name="shared_buffer">
        /// Optional shared data buffer. When it has entries it is sent to the backend and then updated in place
        /// with the names/arrays the backend returns.
        /// </param>
        /// <param name="kwargs">Optional argument shapes; only positive entries of <c>Shape.Data</c> are sent. Null sends none.</param>
        /// <returns>The bound executor with its argument, gradient and auxiliary arrays populated.</returns>
        public Executor SimpleBind(Context ctx, Dictionary <string, OpGradReq> grad_req = null, Dictionary <string, DType> type_dict = null, Dictionary <string, StorageStype> stype_dict = null, Dictionary <string, Context> group2ctx = null, string[] shared_arg_names = null, Executor shared_exec = null, NDArrayDict shared_buffer = null, DataDesc[] kwargs = null)
        {
            int num_provided_arg_types = 0;

            string[] provided_arg_type_names = null;
            int[]    provided_arg_type_data  = null;
            if (type_dict != null)
            {
                provided_arg_type_names = type_dict.Keys.ToArray();
                provided_arg_type_data  = type_dict.Values.Select(x => x.Index).ToArray();
                num_provided_arg_types  = type_dict.Count;
            }

            int num_provided_arg_stypes = 0;

            string[] provided_arg_stype_names = null;
            int[]    provided_arg_stype_data  = null;
            if (stype_dict != null)
            {
                provided_arg_stype_names = stype_dict.Keys.ToArray();
                provided_arg_stype_data  = stype_dict.Values.Select(x => (int)x).ToArray();
                num_provided_arg_stypes  = stype_dict.Count;
            }

            // Shapes are flattened into one data list plus an offset list (CSR-style): the dims of the
            // i-th named argument are provided_arg_shape_data[idx[i] .. idx[i + 1]).
            List <int> provided_arg_shape_data = new List <int>();
            List <int> provided_arg_shape_idx  = new List <int>()
            {
                0
            };
            List <string> provided_arg_shape_names = new List <string>();

            // kwargs defaults to null; treat that as "no shapes provided" instead of throwing.
            if (kwargs != null)
            {
                foreach (var desc in kwargs)
                {
                    provided_arg_shape_names.Add(desc.Name);
                    provided_arg_shape_data.AddRange(desc.Shape.Data.Where(x => x > 0));
                    provided_arg_shape_idx.Add(provided_arg_shape_data.Count);
                }
            }

            int provided_req_type_list_len = 0;

            string[] provided_grad_req_names = new string[0];
            string[] provided_grad_req_types = new string[0];
            if (grad_req != null)
            {
                provided_grad_req_names = grad_req.Keys.ToArray();
                // Invariant lower-casing: a culture such as tr-TR would otherwise map 'I' to a dotless 'ı'
                // and corrupt the request names sent to the backend.
                provided_grad_req_types    = grad_req.Values.Select(x => Enum.GetName(x.GetType(), x).ToLowerInvariant()).ToArray();
                provided_req_type_list_len = grad_req.Count;
            }

            int num_ctx_map_keys = 0;

            string[] ctx_map_keys      = new string[0];
            int[]    ctx_map_dev_types = new int[0];
            int[]    ctx_map_dev_ids   = new int[0];
            if (group2ctx != null)
            {
                ctx_map_keys      = group2ctx.Keys.ToArray();
                ctx_map_dev_types = group2ctx.Values.Select(x => (int)x.GetDeviceType()).ToArray();
                ctx_map_dev_ids   = group2ctx.Values.Select(x => x.GetDeviceId()).ToArray();
                num_ctx_map_keys  = group2ctx.Count;
            }

            string[] shared_arg_name_list = new string[0];
            if (shared_arg_names != null)
            {
                shared_arg_name_list = shared_arg_names;
            }

            // shared_buffer defaults to null; only use it when it actually has entries.
            var use_shared_buffer = shared_buffer != null && shared_buffer.Count > 0;

            unsafe
            {
                // -1 is sent as the buffer length when no shared buffer is supplied.
                int      shared_start          = -1;
                int      shared_buffer_len     = shared_start;
                string[] shared_buffer_names   = null;
                IntPtr[] shared_buffer_handles = null;
                if (use_shared_buffer)
                {
                    shared_buffer_len     = shared_buffer.Count;
                    shared_buffer_names   = shared_buffer.Keys;
                    shared_buffer_handles = shared_buffer.Values.Handles;
                }

                var shared_exec_handle = shared_exec != null ? shared_exec.Handle : new ExecutorHandle();

                char **        updated_shared_buffer_names;
                SymbolHandle * updated_shared_buffer_handles;
                int            num_in_args;
                SymbolHandle * in_arg_handles;
                SymbolHandle * arg_grad_handles;
                SymbolHandle * aux_state_handles;
                ExecutorHandle exe_handle;
                int            num_aux_states;

                // Fail loudly on a native error instead of reading uninitialised output pointers below.
                Logging.CHECK_EQ(NativeMethods.MXExecutorSimpleBindEx(NativePtr, (int)ctx.GetDeviceType(), ctx.GetDeviceId(), num_ctx_map_keys, ctx_map_keys,
                                                                      ctx_map_dev_types, ctx_map_dev_ids, provided_req_type_list_len, provided_grad_req_names, provided_grad_req_types,
                                                                      provided_arg_shape_names.Count, provided_arg_shape_names.ToArray(), provided_arg_shape_data.ToArray(),
                                                                      provided_arg_shape_idx.ToArray(), num_provided_arg_types, provided_arg_type_names, provided_arg_type_data,
                                                                      num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stype_data, shared_arg_name_list.Length,
                                                                      shared_arg_name_list, &shared_buffer_len, shared_buffer_names, shared_buffer_handles, &updated_shared_buffer_names,
                                                                      &updated_shared_buffer_handles, &num_in_args, &in_arg_handles, &arg_grad_handles,
                                                                      &num_aux_states, &aux_state_handles, shared_exec_handle, &exe_handle),
                                 NativeMethods.OK);

                if (use_shared_buffer)
                {
                    int l = shared_buffer_len;
                    for (int i = 0; i < l; i++)
                    {
                        // The C API returns these names as 8-bit (const char*) strings. C#'s char is 16 bits,
                        // so new string(char*) would misread them; read them through an sbyte* instead.
                        string  k = new string((sbyte*)updated_shared_buffer_names[i]);
                        NDArray v = new NDArray(updated_shared_buffer_handles[i]);
                        shared_buffer[k] = v;
                    }
                }

                NDArrayList arg_arrays = new NDArrayList();
                for (int i = 0; i < num_in_args; i++)
                {
                    arg_arrays.Add(new NDArray(in_arg_handles[i]));
                }

                // NOTE(review): arguments without a gradient handle are skipped, so GradientArrays is NOT
                // index-aligned with ArgmentArrays on this path — confirm callers do not pair them by index.
                NDArrayList grad_arrays = new NDArrayList();
                for (int i = 0; i < num_in_args; i++)
                {
                    if (arg_grad_handles[i] == IntPtr.Zero)
                    {
                        continue;
                    }
                    grad_arrays.Add(new NDArray(arg_grad_handles[i]));
                }

                NDArrayList aux_arrays = new NDArrayList();
                for (int i = 0; i < num_aux_states; i++)
                {
                    aux_arrays.Add(new NDArray(aux_state_handles[i]));
                }

                // grad_req defaults to null; pass an empty request list in that case.
                var req_list = grad_req != null ? grad_req.Values.ToList() : new List <OpGradReq>();

                var executor = new Executor(exe_handle, ctx, req_list, group2ctx);
                executor.ArgmentArrays   = arg_arrays;
                executor.GradientArrays  = grad_arrays;
                executor.AuxiliaryArrays = aux_arrays;
                executor._Symbol         = this;
                return(executor);
            }
        }
Example #10
0
        /// <summary>
        /// Appends one entry per argument to <paramref name="argArrays"/>, <paramref name="gradArrays"/> and
        /// <paramref name="gradReqs"/>, and one entry per auxiliary state to <paramref name="auxArrays"/>.
        /// Entries come from the supplied maps when present and are otherwise created.
        /// </summary>
        /// <param name="context">Must not be null; not otherwise referenced here.</param>
        /// <param name="argArrays">Receives argument arrays in <c>ListArguments()</c> order.</param>
        /// <param name="gradArrays">Receives gradient arrays in the same order.</param>
        /// <param name="gradReqs">Receives gradient requests in the same order.</param>
        /// <param name="auxArrays">Receives auxiliary arrays in <c>ListAuxiliaryStates()</c> order.</param>
        /// <param name="argsMap">
        /// Provided argument arrays; their shapes seed shape inference. Missing arguments get a new array
        /// filled from <c>nd.Random.Uniform(0, 1, …)</c>.
        /// </param>
        /// <param name="argGradStore">Provided gradient arrays; missing ones get <c>new NDArray(shape, false)</c>.</param>
        /// <param name="gradReqType">
        /// Explicit requests. Without an entry, arguments whose names end in "data" or "label" get
        /// <see cref="OpGradReq.Null"/> and all others <see cref="OpGradReq.Write"/>.
        /// </param>
        /// <param name="auxMap">Provided auxiliary arrays; missing ones are created and filled like missing arguments.</param>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public void InferExecutorArrays(Context context,
                                        NDArrayList argArrays,
                                        NDArrayList gradArrays,
                                        IList <OpGradReq> gradReqs,
                                        NDArrayList auxArrays,
                                        NDArrayDict argsMap,
                                        NDArrayDict argGradStore,
                                        IDictionary <string, OpGradReq> gradReqType,
                                        NDArrayDict auxMap)
        {
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }
            if (argArrays == null)
            {
                throw new ArgumentNullException(nameof(argArrays));
            }
            if (gradArrays == null)
            {
                throw new ArgumentNullException(nameof(gradArrays));
            }
            if (gradReqs == null)
            {
                throw new ArgumentNullException(nameof(gradReqs));
            }
            if (auxArrays == null)
            {
                throw new ArgumentNullException(nameof(auxArrays));
            }
            if (argsMap == null)
            {
                throw new ArgumentNullException(nameof(argsMap));
            }
            if (argGradStore == null)
            {
                throw new ArgumentNullException(nameof(argGradStore));
            }
            if (gradReqType == null)
            {
                throw new ArgumentNullException(nameof(gradReqType));
            }
            if (auxMap == null)
            {
                throw new ArgumentNullException(nameof(auxMap));
            }

            ThrowIfDisposed();

            var argNameList = ListArguments();
            var argShapes   = new Dictionary <string, Shape>();

            // Shapes of the provided arguments drive inference of all remaining shapes.
            foreach (var argName in argNameList)
            {
                var provided = argsMap[argName];
                if (provided != null)
                {
                    argShapes[argName] = provided.Shape;
                }
            }

            var(inShapes, auxShapes, outShapes) = InferShape(argShapes);

            for (var i = 0; i < inShapes.Length; ++i)
            {
                var shape   = inShapes[i];
                var argName = argNameList[i];

                var providedArg = argsMap[argName];
                if (providedArg != null)
                {
                    argArrays.Add(providedArg);
                }
                else
                {
                    var argArr = new NDArray(shape, false);
                    argArrays.Add(argArr);
                    nd.Random.Uniform(0, 1, argArr.Shape).CopyTo(argArr);
                }

                var providedGrad = argGradStore[argName];
                if (providedGrad != null)
                {
                    gradArrays.Add(providedGrad);
                }
                else
                {
                    gradArrays.Add(new NDArray(shape, false));
                }

                // Ordinal EndsWith replaces `LastIndexOf(suffix) == Length - suffix.Length`, which also matched
                // any name shorter than the suffix (-1 == -1) and so silently disabled gradients for them.
                if (gradReqType.TryGetValue(argName, out var explicitReq))
                {
                    gradReqs.Add(explicitReq);
                }
                else if (argName.EndsWith("data", StringComparison.Ordinal) ||
                         argName.EndsWith("label", StringComparison.Ordinal))
                {
                    gradReqs.Add(OpGradReq.Null);
                }
                else
                {
                    gradReqs.Add(OpGradReq.Write);
                }
            }

            var auxNameList = ListAuxiliaryStates();

            for (var i = 0; i < auxShapes.Length; ++i)
            {
                var shape   = auxShapes[i];
                var auxName = auxNameList[i];

                var providedAux = auxMap[auxName];
                if (providedAux != null)
                {
                    auxArrays.Add(providedAux);
                }
                else
                {
                    var aux = new NDArray(shape, false);
                    auxArrays.Add(aux);
                    nd.Random.Uniform(0, 1, aux.Shape).CopyTo(aux);
                }
            }
        }
Example #11
0
 /// <summary>
 /// Wraps every array of <paramref name="source"/> in its own <see cref="NDArrayOrSymbol"/>, preserving order.
 /// </summary>
 /// <param name="source">Arrays to wrap.</param>
 /// <returns>One wrapper per element of <paramref name="source"/>.</returns>
 public static NDArrayOrSymbol[] ToNDArrayOrSymbols(this NDArrayList source)
     => source.Select(array => new NDArrayOrSymbol(array)).ToArray();
Example #12
0
 /// <summary>
 /// Forwards the data arrays of <paramref name="batch"/> together with <paramref name="targets"/> to <c>LoadGeneral</c>.
 /// </summary>
 internal static void LoadData(DataBatch batch, NDArrayList targets) => LoadGeneral(batch.Data, targets);
Example #13
0
 /// <summary>
 /// Creates the NDArray variant of this union from the given arrays.
 /// </summary>
 /// <param name="x">Arrays to hold; stored by reference.</param>
 public NDArrayOrSymbol(params NDArray[] x)
 {
     // Keep the payload first, then flag which side of the union is populated.
     ndx       = x;
     IsSymbol  = false;
     IsNDArray = true;
 }
 /// <summary>
 /// Delegates the metric update to the current executor group (<c>curr_execgrp</c>), passing all arguments through.
 /// </summary>
 public void UpdateMetric(EvalMetric eval_metric, NDArrayList labels, bool pre_sliced = false)
     => curr_execgrp.UpdateMetric(eval_metric, labels, pre_sliced);
Example #15
0
        /// <summary>
        /// Binds <paramref name="symbol"/> through the native <c>MXExecutorBindEX</c> call and collects the
        /// resulting output arrays.
        /// </summary>
        /// <param name="symbol">Symbol to bind.</param>
        /// <param name="context">Default device context.</param>
        /// <param name="argmentArrays">Argument arrays; their handles are passed in order.</param>
        /// <param name="gradientArrays">Gradient arrays; every element's handle is passed in order.</param>
        /// <param name="gradReqs">Gradient requests, passed as their numeric values.</param>
        /// <param name="auxiliaryArrays">Auxiliary-state arrays.</param>
        /// <param name="groupToCtx">Context-group to device mapping.</param>
        /// <param name="sharedExec">Optional executor whose handle is shared; null sends <see cref="IntPtr.Zero"/>.</param>
        /// <exception cref="ArgumentNullException">Any argument other than <paramref name="sharedExec"/> is null.</exception>
        public Executor(Symbol symbol,
                        Context context,
                        NDArrayList argmentArrays,
                        NDArrayList gradientArrays,
                        IList <OpGradReq> gradReqs,
                        NDArrayList auxiliaryArrays,
                        IDictionary <string, Context> groupToCtx,
                        Executor sharedExec)
        {
            if (symbol == null)          { throw new ArgumentNullException(nameof(symbol)); }
            if (context == null)         { throw new ArgumentNullException(nameof(context)); }
            if (argmentArrays == null)   { throw new ArgumentNullException(nameof(argmentArrays)); }
            if (gradientArrays == null)  { throw new ArgumentNullException(nameof(gradientArrays)); }
            if (gradReqs == null)        { throw new ArgumentNullException(nameof(gradReqs)); }
            if (auxiliaryArrays == null) { throw new ArgumentNullException(nameof(auxiliaryArrays)); }
            if (groupToCtx == null)      { throw new ArgumentNullException(nameof(groupToCtx)); }

            ArgmentArrays   = argmentArrays;
            GradientArrays  = gradientArrays;
            AuxiliaryArrays = auxiliaryArrays;
            _Symbol         = symbol;

            var argHandles   = argmentArrays.Select(a => a.GetHandle()).ToArray();
            var gradHandles  = gradientArrays.Select(g => g.GetHandle()).ToArray();
            var auxHandles   = auxiliaryArrays.Select(x => x.GetHandle()).ToArray();
            var gradReqsUint = gradReqs.Select(req => (uint)req).ToArray();

            // Flatten the group map into the three parallel arrays the native API expects.
            var groupCount = groupToCtx.Count;
            var mapKeys    = new string[groupCount];
            var devTypes   = new int[groupCount];
            var devIds     = new int[groupCount];
            var slot       = 0;
            foreach (var entry in groupToCtx)
            {
                mapKeys[slot]  = entry.Key;
                devTypes[slot] = (int)entry.Value.GetDeviceType();
                devIds[slot]   = entry.Value.GetDeviceId();
                slot++;
            }

            var sharedExecHandle = sharedExec?.Handle ?? IntPtr.Zero;

            var bindStatus = NativeMethods.MXExecutorBindEX(symbol.GetHandle(),
                                                            (int)context.GetDeviceType(),
                                                            context.GetDeviceId(),
                                                            (uint)groupToCtx.Count,
                                                            mapKeys,
                                                            devTypes,
                                                            devIds,
                                                            (uint)argHandles.Length,
                                                            argHandles,
                                                            gradHandles,
                                                            gradReqsUint,
                                                            (uint)auxHandles.Length,
                                                            auxHandles,
                                                            sharedExecHandle,
                                                            out var handle);
            Logging.CHECK_EQ(bindStatus, NativeMethods.OK);
            Handle = handle;

            // Wrap each output handle reported by the backend.
            Outputs = new NDArrayList();
            var outputsStatus = NativeMethods.MXExecutorOutputs(Handle, out var outSize, out var outArray);
            Logging.CHECK_EQ(outputsStatus, 0);

            var outputHandles = InteropHelper.ToPointerArray(outArray, outSize);
            for (uint k = 0; k < outSize; k++)
            {
                Outputs.Add(new NDArray(outputHandles[k]));
            }
        }
Example #16
0
 /// <summary>
 /// Forward pass; abstract, so every concrete subclass must supply the implementation.
 /// NOTE(review): the parameter names suggest MXNet's CustomOp.forward contract (is_train = training mode,
 /// req = write request per output, out_data = destination for outputs). What the returned NDArray represents
 /// is not visible here — confirm against the implementations.
 /// </summary>
 public abstract NDArray Forward(bool is_train, OpGradReq[] req, NDArrayList in_data, NDArrayList out_data, NDArrayList aux);
Example #17
0
 /// <summary>
 /// Forwards the label arrays of <paramref name="batch"/> together with <paramref name="targets"/> to <c>LoadGeneral</c>.
 /// </summary>
 internal static void LoadLabel(DataBatch batch, NDArrayList targets) => LoadGeneral(batch.Label, targets);
Example #18
0
        /// <summary>
        /// Creates an executor for <paramref name="sym"/> on <paramref name="ctx"/>, allocating or borrowing the
        /// argument, gradient and auxiliary arrays it needs.
        /// </summary>
        /// <param name="sym">Symbol to bind.</param>
        /// <param name="ctx">Device context used for new arrays and for binding.</param>
        /// <param name="input_shapes">Shapes of the data/label inputs; used for shape inference.</param>
        /// <param name="param_names">Names of model parameters; every other argument is treated as data/label.</param>
        /// <param name="need_grad">When true, every argument not listed in <paramref name="input_shapes"/> requests a gradient.</param>
        /// <param name="base_exec">
        /// Optional executor whose parameter, gradient and auxiliary arrays are reused; it is also passed as the
        /// shared executor when binding.
        /// </param>
        /// <param name="shared_data_arrays">
        /// Optional pool of data/label arrays shared across bindings. Entries are reused (reshaped) when large
        /// enough and replaced by a new, larger array otherwise.
        /// </param>
        /// <param name="input_types">Optional input data types; defaults to Float32 for every key of <paramref name="input_shapes"/>.</param>
        /// <param name="logger">Not referenced by this method.</param>
        /// <returns>The bound executor.</returns>
        internal static Executor BindExec(Symbol sym, Context ctx, Dictionary <string, Shape> input_shapes,
                                          string[] param_names, bool need_grad = false,
                                          Executor base_exec = null, NDArrayDict shared_data_arrays    = null,
                                          Dictionary <string, DType> input_types = null, Logger logger = null)
        {
            var(arg_shape, _, aux_shape) = sym.InferShape(input_shapes);
            if (arg_shape == null)
            {
                throw new ArgumentNullException("arg_shape");
            }

            if (input_types == null)
            {
                input_types = new Dictionary <string, DType>();
                foreach (var item in input_shapes.Keys)
                {
                    input_types.Add(item, DType.Float32);
                }
            }

            var(arg_types, _, aux_types) = sym.InferType(input_types);
            if (arg_types == null)
            {
                throw new ArgumentNullException("arg_types");
            }

            var arg_names = sym.ListArguments();

            // With need_grad, every argument that is not a data/label input (not in input_shapes) gets a gradient.
            var needGradSet = new HashSet <string>();
            if (need_grad)
            {
                foreach (var item in arg_names)
                {
                    if (!input_shapes.ContainsKey(item))
                    {
                        needGradSet.Add(item);
                    }
                }
            }

            var arg_arrays = new NDArrayList();
            var aux_arrays = new NDArrayList();

            // Binding consumes one gradient array and one request per argument, index-aligned with arg_arrays.
            // Arguments without a gradient get a zero placeholder paired with OpGradReq.Null.
            var grad_arrays = new NDArrayList();
            var grad_reqs   = new List <OpGradReq>();

            for (var i = 0; i < arg_names.Count; i++)
            {
                var     name     = arg_names[i];
                NDArray arg_arr;
                NDArray grad_arr = null;

                if (!param_names.Contains(name))
                {
                    // Data / label input: reuse the shared buffer when it is large enough.
                    if (shared_data_arrays != null && shared_data_arrays.Contains(name))
                    {
                        arg_arr = shared_data_arrays[name];
                        if (np.prod(arg_arr.Shape.Data) >= np.prod(arg_shape[i].Data))
                        {
                            if (arg_types[i].Name != arg_arr.DataType.Name)
                            {
                                throw new ArgumentException("arg_type and arg_arr datatype mismatch");
                            }

                            arg_arr = arg_arr.Reshape(arg_shape[i]);
                        }
                        else
                        {
                            var logmsg = new StringBuilder();
                            logmsg.AppendFormat("bucketing: data \"{0}\" has a shape {1}", name, arg_shape[i]);
                            logmsg.AppendFormat(", which is larger than already allocated ");
                            logmsg.AppendFormat("shape {0}", arg_arr.Shape);
                            logmsg.AppendFormat(". Need to re-allocate. Consider putting default_bucket_key " +
                                                "to be the bucket taking the largest input for better memory sharing.");

                            Logger.Warning(logmsg.ToString());

                            // Replace the shared entry because the new array is bigger.
                            arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                            shared_data_arrays[name] = arg_arr;
                        }
                    }
                    else
                    {
                        arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                        if (shared_data_arrays != null)
                        {
                            shared_data_arrays[name] = arg_arr;
                        }
                    }
                }
                else if (base_exec == null)
                {
                    // Model parameter without a base executor: allocate it (and its gradient if requested).
                    arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                    if (needGradSet.Contains(name))
                    {
                        grad_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                    }
                }
                else
                {
                    // Model parameter with a base executor: borrow its parameter (and gradient) arrays.
                    arg_arr = base_exec.ArgmentDictionary()[name];
                    if (arg_arr.Shape != arg_shape[i])
                    {
                        throw new ArgumentException("arg_arr.Shape != arg_shape[i]");
                    }

                    if (arg_arr.DataType != arg_types[i])
                    {
                        throw new ArgumentException("arg_arr.DataType != arg_types[i]");
                    }

                    if (needGradSet.Contains(name))
                    {
                        grad_arr = base_exec.GradientDictionary()[name];
                    }
                }

                arg_arrays.Add(arg_arr);

                if (grad_arr != null)
                {
                    grad_arrays.Add(grad_arr);
                    grad_reqs.Add(OpGradReq.Write);
                }
                else
                {
                    grad_arrays.Add(nd.Zeros(arg_shape[i], ctx, arg_types[i]));
                    grad_reqs.Add(OpGradReq.Null);
                }
            }

            // Auxiliary states: allocate fresh ones, or reuse those of the base executor.
            if (base_exec == null)
            {
                for (var i = 0; i < aux_shape.Length; i++)
                {
                    aux_arrays.Add(nd.Zeros(aux_shape[i], ctx, aux_types[i]));
                }
            }
            else
            {
                foreach (var item in base_exec.AuxiliaryDictionary())
                {
                    aux_arrays.Add(item.Value);
                }
            }

            // Following MXNet's _bind_exec, base_exec (possibly null) is also passed as the shared executor.
            var executor = sym.Bind(ctx, arg_arrays, grad_arrays, grad_reqs, aux_arrays,
                                    new Dictionary <string, Context>(), base_exec);

            return(executor);
        }
Example #19
0
 /// <summary>
 /// Collects the native handle of every array in <paramref name="list"/>, preserving order.
 /// </summary>
 /// <param name="list">Arrays whose handles are needed.</param>
 /// <returns>The handles, one per element of <paramref name="list"/>.</returns>
 public static IntPtr[] GetNDArrayHandles(NDArrayList list)
     => list.Select(array => array.GetHandle()).ToArray();