Example #1
0
        /// <summary>
        /// Converts the head arrays (and optional head gradients) into native handle arrays
        /// suitable for the autograd backward call.
        /// </summary>
        /// <param name="heads">Output arrays to differentiate.</param>
        /// <param name="head_grads">
        /// Optional gradients for each head; when <c>null</c>, every head-gradient handle is <see cref="IntPtr.Zero"/>.
        /// </param>
        /// <returns>A tuple of (head handles, head-gradient handles).</returns>
        /// <exception cref="ArgumentException"><paramref name="head_grads"/> is non-null and differs in length from <paramref name="heads"/>.</exception>
        private static (IntPtr[], IntPtr[]) ParseHead(NDArrayList heads, NDArrayList head_grads)
        {
            IntPtr[] handles = MxUtil.GetNDArrayHandles(heads);

            if (head_grads != null)
            {
                if (heads.Length != head_grads.Length)
                {
                    throw new ArgumentException("heads and head_grads must be lists of the same length");
                }

                return (handles, MxUtil.GetNDArrayHandles(head_grads));
            }

            // A newly allocated IntPtr[] is already filled with IntPtr.Zero, which stands for
            // "no explicit head gradient" for every head.
            return (handles, new IntPtr[heads.Length]);
        }
Example #2
0
        /// <summary>
        /// Registers <paramref name="variables"/> with autograd, pairing each one with the
        /// corresponding buffer in <paramref name="gradients"/>.
        /// </summary>
        /// <param name="variables">Arrays to mark.</param>
        /// <param name="gradients">Gradient buffers, one per entry of <paramref name="variables"/>.</param>
        /// <param name="grad_reqs">Gradient request applied to every variable (defaults to <see cref="OpGradReq.Write"/>).</param>
        /// <exception cref="ArgumentException"><paramref name="variables"/> and <paramref name="gradients"/> differ in length.</exception>
        public static void MarkVariables(NDArrayList variables, NDArrayList gradients,
                                         OpGradReq grad_reqs = OpGradReq.Write)
        {
            // The native call is told there are variables.Length entries and reads that many
            // gradient handles, so a shorter gradients list would be read out of bounds.
            if (variables.Length != gradients.Length)
            {
                throw new ArgumentException("variables and gradients must be lists of the same length");
            }

            // Honour the caller's grad_reqs (the request was previously hard-coded to Write).
            var gradReqs = new int[variables.Length];
            for (var i = 0; i < gradReqs.Length; i++)
            {
                gradReqs[i] = (int)grad_reqs;
            }

            NativeMethods.MXAutogradMarkVariables(variables.Length, MxUtil.GetNDArrayHandles(variables), gradReqs,
                                                  MxUtil.GetNDArrayHandles(gradients));
        }
Example #3
0
        /// <summary>
        /// Invokes this cached operator on <paramref name="args"/> via <c>MXInvokeCachedOpEx</c>.
        /// </summary>
        /// <param name="args">Input arrays passed to the operator.</param>
        /// <returns>
        /// One array per native output, each converted with <c>ToSType</c> to the storage type
        /// reported by the native call.
        /// </returns>
        public NDArrayList Call(NDArrayList args)
        {
            var inputCount   = args.Length;
            var inputHandles = MxUtil.GetNDArrayHandles(args);

            NativeMethods.MXInvokeCachedOpEx(handle, inputCount, inputHandles, out var outputCount,
                                             out var outputHandles, out var outputStypes);

            var outputs = new NDArrayList();
            var index   = 0;
            while (index < outputCount)
            {
                var raw = new NDArray(outputHandles[index]);
                outputs.Add(raw.ToSType((StorageStype)outputStypes[index]));
                index++;
            }

            return outputs.ToArray();
        }
Example #4
0
        /// <summary>
        /// Runs <c>MXAutogradBackwardEx</c> for <paramref name="heads"/> with respect to
        /// <paramref name="variables"/> and wraps each returned gradient handle in an <see cref="NDArray"/>.
        /// </summary>
        /// <param name="heads">Output arrays to differentiate.</param>
        /// <param name="variables">Arrays whose gradients are requested.</param>
        /// <param name="head_grads">Optional gradients for each head; must match <paramref name="heads"/> in length when given.</param>
        /// <param name="retain_graph">Passed to the native call as 1/0.</param>
        /// <param name="create_graph">Passed to the native call as 1/0 (note: defaults to <c>true</c> here).</param>
        /// <param name="train_mode">Passed to the native call as 1/0.</param>
        /// <returns>One array per gradient handle returned by the native call.</returns>
        /// <exception cref="ArgumentException">Thrown by head parsing when head/head_grads lengths differ.</exception>
        public static NDArrayList Grad(NDArrayList heads, NDArrayList variables, NDArrayList head_grads = null,
                                       bool retain_graph = false, bool create_graph = true, bool train_mode = true)
        {
            var (headHandles, headGradHandles) = ParseHead(heads, head_grads);

            var variableCount   = variables.Length;
            var variableHandles = MxUtil.GetNDArrayHandles(variables);

            // NOTE(review): the native call also reports per-gradient storage types, which are
            // discarded here; the wrapped arrays are built from the raw handles only.
            NativeMethods.MXAutogradBackwardEx(headHandles.Length, headHandles, headGradHandles,
                                               variableCount, variableHandles,
                                               retain_graph ? 1 : 0,
                                               create_graph ? 1 : 0,
                                               train_mode ? 1 : 0,
                                               out var gradHandles, out _);

            var gradients = new NDArrayList();
            foreach (var gradHandle in gradHandles)
            {
                gradients.Add(new NDArray(gradHandle));
            }

            return gradients.ToArray();
        }
Example #5
0
 /// <summary>
 /// Creates a cached operator for <paramref name="sym"/> through <c>MXCreateCachedOpEx</c>,
 /// storing the resulting native handle in <c>handle</c>.
 /// </summary>
 /// <param name="sym">Symbol to build the cached operator from.</param>
 /// <param name="flags">Flag entries; keys and values are passed to the native call in matching order.</param>
 public CachedOp(Symbol sym, NDArrayDict flags)
 {
     handle = IntPtr.Zero;

     // Arguments are computed in the same order the native call's argument list evaluates them.
     var symbolHandle = sym.GetHandle();
     var flagCount    = flags.Count;
     var flagKeys     = flags.Keys.ToArray();
     // NOTE(review): the upstream MXNet C API declares the flag values as strings, but here they
     // are NDArray handles; confirm this matches the NativeMethods binding.
     var flagValues   = MxUtil.GetNDArrayHandles(flags.Values.ToArray());

     NativeMethods.MXCreateCachedOpEx(symbolHandle, flagCount, flagKeys, flagValues, out handle);
 }