// Example #1
        /// <summary>
        /// Collects the native handles of the head arrays together with the handles of their
        /// (optional) head gradients, ready to be passed to the autograd C API.
        /// </summary>
        /// <param name="heads">Arrays whose gradients are being computed.</param>
        /// <param name="head_grads">
        /// Optional head gradients, one per head. When null, a zero (null) handle is supplied for every head.
        /// </param>
        /// <returns>A tuple of (head handles, head-gradient handles).</returns>
        /// <exception cref="ArgumentException">
        /// Thrown when <paramref name="head_grads"/> is supplied but its length differs from <paramref name="heads"/>.
        /// </exception>
        private static (IntPtr[], IntPtr[]) ParseHead(NDArrayList heads, NDArrayList head_grads)
        {
            var headPtrs = MxUtil.GetNDArrayHandles(heads);

            if (head_grads != null)
            {
                if (heads.Length != head_grads.Length)
                {
                    throw new ArgumentException("heads and head_grads must be lists of the same length");
                }

                return (headPtrs, MxUtil.GetNDArrayHandles(head_grads));
            }

            // No explicit head gradients: a freshly allocated IntPtr[] already holds IntPtr.Zero
            // in every slot, which is what gets passed for each head.
            var gradPtrs = new IntPtr[heads.Length];
            return (headPtrs, gradPtrs);
        }
// Example #2
        /// <summary>
        /// Marks NDArrays as autograd variables and attaches a gradient buffer to each one.
        /// </summary>
        /// <param name="variables">Arrays to mark as variables.</param>
        /// <param name="gradients">Gradient buffers, one per variable, in the same order.</param>
        /// <param name="grad_reqs">Gradient request applied to every variable (defaults to <see cref="OpGradReq.Write"/>).</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="variables"/> or <paramref name="gradients"/> is null.</exception>
        /// <exception cref="ArgumentException">Thrown when the two lists differ in length.</exception>
        public static void MarkVariables(NDArrayList variables, NDArrayList gradients,
                                         OpGradReq grad_reqs = OpGradReq.Write)
        {
            if (variables == null)
            {
                throw new ArgumentNullException(nameof(variables));
            }

            if (gradients == null)
            {
                throw new ArgumentNullException(nameof(gradients));
            }

            // The native call reads one gradient handle per variable, so the lists must line up.
            if (variables.Length != gradients.Length)
            {
                throw new ArgumentException("variables and gradients must be lists of the same length");
            }

            // Honour the caller's request instead of always using Write.
            var gradReqs = new int[variables.Length];
            for (var i = 0; i < gradReqs.Length; i++)
            {
                gradReqs[i] = (int)grad_reqs;
            }

            NativeMethods.MXAutogradMarkVariables(variables.Length, MxUtil.GetNDArrayHandles(variables), gradReqs,
                                                  MxUtil.GetNDArrayHandles(gradients));
        }
// Example #3
        /// <summary>
        /// Invokes this cached operator on <paramref name="args"/> and wraps every native output handle.
        /// </summary>
        /// <param name="args">Input arrays passed to the cached op.</param>
        /// <returns>
        /// One array per output, each passed through <c>ToSType</c> with the storage type reported by the native call.
        /// </returns>
        public NDArrayList Call(NDArrayList args)
        {
            NativeMethods.MXInvokeCachedOpEx(handle, args.Length, MxUtil.GetNDArrayHandles(args), out var outputCount,
                                             out var outputHandles, out var storageTypes);

            var wrapped = new NDArrayList();
            for (var idx = 0; idx < outputCount; idx++)
            {
                var raw = new NDArray(outputHandles[idx]);
                var stype = (StorageStype)storageTypes[idx];
                wrapped.Add(raw.ToSType(stype));
            }

            return wrapped.ToArray();
        }
// Example #4
        /// <summary>
        /// Computes the gradients of <paramref name="heads"/> with respect to <paramref name="variables"/>
        /// through <c>MXAutogradBackwardEx</c> and returns them as new arrays.
        /// </summary>
        /// <param name="heads">Output arrays to differentiate.</param>
        /// <param name="variables">Arrays to take gradients with respect to.</param>
        /// <param name="head_grads">Optional head gradients; see <c>ParseHead</c>.</param>
        /// <param name="retain_graph">Passed to the native call as 1/0.</param>
        /// <param name="create_graph">Passed to the native call as 1/0.</param>
        /// <param name="train_mode">Passed to the native call as 1/0.</param>
        /// <returns>One array per gradient handle returned by the native call.</returns>
        /// <exception cref="ArgumentException">Propagated from <c>ParseHead</c> on a heads/head_grads length mismatch.</exception>
        public static NDArrayList Grad(NDArrayList heads, NDArrayList variables, NDArrayList head_grads = null,
                                       bool retain_graph = false, bool create_graph = true, bool train_mode = true)
        {
            var (headPtrs, headGradPtrs) = ParseHead(heads, head_grads);

            // NOTE(review): the storage types reported by the native call are discarded here,
            // whereas CachedOp.Call applies them via ToSType — confirm whether sparse gradients need that.
            NativeMethods.MXAutogradBackwardEx(headPtrs.Length, headPtrs, headGradPtrs, variables.Length,
                                               MxUtil.GetNDArrayHandles(variables),
                                               retain_graph ? 1 : 0,
                                               create_graph ? 1 : 0,
                                               train_mode ? 1 : 0,
                                               out var gradPtrs, out _);

            var gradients = new NDArrayList();
            foreach (var ptr in gradPtrs)
            {
                gradients.Add(new NDArray(ptr));
            }

            return gradients.ToArray();
        }
// Example #5
        /// <summary>
        /// Binds <paramref name="sym"/> into an executor for bucketing. Data arrays can be shared across
        /// buckets, and parameter, gradient and auxiliary arrays can be borrowed from an existing executor.
        /// </summary>
        /// <param name="sym">Symbol to bind.</param>
        /// <param name="ctx">Device context for newly allocated arrays.</param>
        /// <param name="input_shapes">Shapes of the data/label inputs, keyed by argument name.</param>
        /// <param name="param_names">Names of model parameters; every other argument is treated as data.</param>
        /// <param name="need_grad">
        /// When true, every argument not present in <paramref name="input_shapes"/> receives a gradient buffer
        /// with a Write request.
        /// </param>
        /// <param name="base_exec">Optional executor whose parameter, gradient and auxiliary arrays are reused.</param>
        /// <param name="shared_data_arrays">
        /// Optional pool of data arrays shared across buckets. It is updated in place whenever a new or larger
        /// array has to be allocated.
        /// </param>
        /// <param name="input_types">Input types; when null, Float32 is used for every key of <paramref name="input_shapes"/>.</param>
        /// <param name="logger">Not used by this method; warnings go through the static <c>Logger.Warning</c>.</param>
        /// <returns>The bound executor.</returns>
        /// <exception cref="ArgumentNullException">Thrown when shape or type inference returns no argument info.</exception>
        /// <exception cref="ArgumentException">Thrown when a shared or borrowed array does not match the inferred shape/type.</exception>
        internal static Executor BindExec(Symbol sym, Context ctx, Dictionary <string, Shape> input_shapes,
                                          string[] param_names, bool need_grad = false,
                                          Executor base_exec = null, NDArrayDict shared_data_arrays    = null,
                                          Dictionary <string, DType> input_types = null, Logger logger = null)
        {
            var(arg_shape, _, aux_shape) = sym.InferShape(input_shapes);
            if (arg_shape == null)
            {
                throw new ArgumentNullException("arg_shape");
            }

            if (input_types == null)
            {
                input_types = new Dictionary <string, DType>();
                foreach (var item in input_shapes.Keys)
                {
                    input_types.Add(item, DType.Float32);
                }
            }

            var(arg_types, _, aux_types) = sym.InferType(input_types);
            if (arg_types == null)
            {
                throw new ArgumentNullException("arg_types");
            }

            var arg_names = sym.ListArguments();

            // Arguments that get a gradient: every argument that is not a declared data input.
            var needGradSet = need_grad
                                  ? new HashSet<string>(arg_names.Where(n => !input_shapes.ContainsKey(n)))
                                  : new HashSet<string>();

            var arg_arrays = new NDArrayList();
            var aux_arrays = new NDArrayList();

            // Binding is positional: exactly one gradient slot and one request per argument,
            // in the same order as arg_arrays.
            var grad_arrays = new List<NDArray>();
            var grad_reqs   = new List<OpGradReq>();

            for (var i = 0; i < arg_names.Count; i++)
            {
                var     name     = arg_names[i];
                NDArray arg_arr;
                NDArray grad_arr = null;

                if (!param_names.Contains(name))
                {
                    // Data or label input, possibly shared across buckets.
                    arg_arr = BindExecDataArray(name, arg_shape[i], arg_types[i], ctx, shared_data_arrays);
                }
                else if (base_exec == null)
                {
                    // New model parameter, plus a fresh gradient buffer if one is needed.
                    arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                    if (needGradSet.Contains(name))
                    {
                        grad_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                    }
                }
                else
                {
                    // Borrow the parameter (and its gradient) from the base executor.
                    arg_arr = base_exec.ArgmentDictionary()[name];
                    if (arg_arr.Shape != arg_shape[i])
                    {
                        throw new ArgumentException("arg_arr.Shape != arg_shape[i]");
                    }

                    // Compare by name, as the shared-data path does, rather than by reference.
                    if (arg_arr.DataType.Name != arg_types[i].Name)
                    {
                        throw new ArgumentException("arg_arr.DataType != arg_types[i]");
                    }

                    if (needGradSet.Contains(name))
                    {
                        grad_arr = base_exec.GradientDictionary()[name];
                    }
                }

                arg_arrays.Add(arg_arr);

                if (grad_arr != null)
                {
                    grad_arrays.Add(grad_arr);
                    grad_reqs.Add(OpGradReq.Write);
                }
                else
                {
                    // Placeholder that keeps the gradient list aligned with the arguments.
                    // It is paired with a Null request, so no gradient is requested for it.
                    grad_arrays.Add(nd.Zeros(arg_shape[i], ctx, arg_types[i]));
                    grad_reqs.Add(OpGradReq.Null);
                }
            }

            if (base_exec == null)
            {
                for (var i = 0; i < aux_shape.Length; i++)
                {
                    aux_arrays.Add(nd.Zeros(aux_shape[i], ctx, aux_types[i]));
                }
            }
            else
            {
                foreach (var item in base_exec.AuxiliaryDictionary())
                {
                    aux_arrays.Add(item.Value);
                }
            }

            var executor = sym.Bind(ctx, arg_arrays, grad_arrays, grad_reqs, aux_arrays);

            return(executor);
        }

        /// <summary>
        /// Returns the array used for a data/label argument of <c>BindExec</c>.
        /// </summary>
        /// <remarks>
        /// If <paramref name="shared"/> already holds a large-enough entry for <paramref name="name"/>, that entry
        /// is reshaped and reused. Otherwise a new zero array is allocated, and it is recorded in
        /// <paramref name="shared"/> when a pool is supplied.
        /// </remarks>
        private static NDArray BindExecDataArray(string name, Shape shape, DType dtype, Context ctx,
                                                 NDArrayDict shared)
        {
            if (shared == null || !shared.Contains(name))
            {
                var fresh = nd.Zeros(shape, ctx, dtype);
                if (shared != null)
                {
                    shared[name] = fresh;
                }

                return fresh;
            }

            var existing = shared[name];
            if (np.prod(existing.Shape.Data) >= np.prod(shape.Data))
            {
                // Enough memory already allocated: share it under the new shape.
                if (dtype.Name != existing.DataType.Name)
                {
                    throw new ArgumentException("arg_type and arg_arr datatype mismatch");
                }

                return existing.Reshape(shape);
            }

            var logmsg = new StringBuilder();
            logmsg.AppendFormat("bucketing: data \"{0}\" has a shape {1}", name, shape);
            logmsg.AppendFormat(", which is larger than already allocated ");
            logmsg.AppendFormat("shape {0}", existing.Shape);
            logmsg.AppendFormat(". Need to re-allocate. Consider putting default_bucket_key " +
                                "to be the bucket taking the largest input for better memory sharing.");

            Logger.Warning(logmsg.ToString());

            // Replace the pooled array, because the new one is larger.
            var larger = nd.Zeros(shape, ctx, dtype);
            shared[name] = larger;
            return larger;
        }
// Example #6
 /// <summary>
 /// Creates a native cached operator for <paramref name="sym"/>, configured by <paramref name="flags"/>.
 /// </summary>
 /// <param name="sym">Symbol to turn into a cached op.</param>
 /// <param name="flags">Flag names (keys) and values forwarded to <c>MXCreateCachedOpEx</c>.</param>
 public CachedOp(Symbol sym, NDArrayDict flags)
 {
     // Start from an empty handle; the native call writes the created op into it.
     handle = IntPtr.Zero;

     var symbolHandle = sym.GetHandle();
     var flagNames    = flags.Keys.ToArray();
     // NOTE(review): the MXNet C API declares flag values as strings (const char**), but NDArray
     // handles are passed here — confirm against the NativeMethods binding.
     var flagValues   = MxUtil.GetNDArrayHandles(flags.Values.ToArray());

     NativeMethods.MXCreateCachedOpEx(symbolHandle, flags.Count, flagNames, flagValues, out handle);
 }