Exemplo n.º 1
0
        /// <summary>
        /// Stacks the given tensors by passing their native handles to <c>THSTensor_stack</c>
        /// (presumably mirroring <c>torch.stack</c> — confirm against the native side).
        /// </summary>
        /// <param name="tensors">Tensors whose handles are forwarded to the native call.</param>
        /// <param name="dimension">Dimension index forwarded unchanged to the native call.</param>
        /// <returns>A <see cref="TorchTensor"/> wrapping the handle returned by <c>THSTensor_stack</c>.</returns>
        public static TorchTensor Stack(this TorchTensor[] tensors, long dimension)
        {
            // PinnedArray is disposable; release it deterministically once the native call has
            // consumed the pointer instead of leaking it.
            using (var parray = new PinnedArray<IntPtr>()) {
                IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());

                return new TorchTensor(THSTensor_stack(tensorsRef, parray.Array.Length, dimension));
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Runs the native <c>THSJIT_forward</c> call on this object's <c>handle</c> with the
        /// handles of the supplied input tensors (presumably a scripted-module forward pass).
        /// </summary>
        /// <param name="tensors">Input tensors whose handles are forwarded to the native call.</param>
        /// <returns>A <see cref="TorchTensor"/> wrapping the handle returned by <c>THSJIT_forward</c>.</returns>
        public TorchTensor Forward(params TorchTensor[] tensors)
        {
            // Dispose the pinned handle array after the native call returns; previously it was
            // never disposed.
            using (var parray = new PinnedArray<IntPtr>()) {
                IntPtr tensorRefs = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());

                return new TorchTensor(THSJIT_forward(handle, tensorRefs, parray.Array.Length));
            }
        }
Exemplo n.º 3
0
        /// <summary>
        /// Hands the handles of <paramref name="tensors"/> to the native
        /// <c>THSTensor_clip_grad_norm_</c> call and returns the value it produces
        /// (presumably the total gradient norm, as in <c>torch.nn.utils.clip_grad_norm_</c> — confirm).
        /// </summary>
        /// <param name="tensors">Tensors whose handles are passed to the native call.</param>
        /// <param name="max_norm">Forwarded unchanged to the native call.</param>
        /// <param name="norm_type">Forwarded unchanged to the native call; defaults to 2.0.</param>
        /// <returns>The <c>double</c> returned by <c>THSTensor_clip_grad_norm_</c>.</returns>
        public static double clip_grad_norm(this IList <TorchTensor> tensors, double max_norm, double norm_type = 2.0)
        {
            using (var pinned = new PinnedArray<IntPtr>())
            {
                var handles = tensors.Select(t => t.Handle).ToArray();
                IntPtr handlesPtr = pinned.CreateArray(handles);
                var count = pinned.Array.Length;

                return THSTensor_clip_grad_norm_(handlesPtr, count, max_norm, norm_type);
            }
        }
Exemplo n.º 4
0
        /// <summary>
        /// Passes the handles of <paramref name="tensors"/> to the native <c>THSTensor_dstack</c>
        /// call and wraps the returned handle in a <see cref="TorchTensor"/>.
        /// </summary>
        /// <param name="tensors">Tensors whose handles are forwarded to the native call.</param>
        /// <returns>A <see cref="TorchTensor"/> wrapping the native result.</returns>
        public static TorchTensor dstack(this IList <TorchTensor> tensors)
        {
            using (var pinned = new PinnedArray<IntPtr>())
            {
                var handles = tensors.Select(t => t.Handle).ToArray();
                IntPtr handlesPtr = pinned.CreateArray(handles);

                var result = THSTensor_dstack(handlesPtr, pinned.Array.Length);

                // A zero handle means the native call failed; Torch.CheckForErrors is invoked to
                // surface that (presumably by throwing — confirm).
                if (result == IntPtr.Zero) { Torch.CheckForErrors(); }

                return new TorchTensor(result);
            }
        }
Exemplo n.º 5
0
        /// <summary>
        /// Concatenates the given tensors via the native <c>THSTensor_cat</c> call
        /// (presumably mirroring <c>torch.cat</c> — confirm against the native side).
        /// </summary>
        /// <param name="tensors">Tensors to concatenate; must be non-null and non-empty.</param>
        /// <param name="dimension">Dimension index forwarded unchanged to the native call.</param>
        /// <returns>
        /// The single input tensor itself when exactly one is supplied; otherwise a
        /// <see cref="TorchTensor"/> wrapping the handle returned by <c>THSTensor_cat</c>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="tensors"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="tensors"/> is empty.</exception>
        public static TorchTensor Cat(this TorchTensor[] tensors, long dimension)
        {
            if (tensors == null)
            {
                throw new ArgumentNullException(nameof(tensors));
            }
            if (tensors.Length == 0)
            {
                // The single-string ArgumentException constructor takes a *message*; pass the
                // parameter name through the dedicated paramName argument instead.
                throw new ArgumentException("At least one tensor is required.", nameof(tensors));
            }
            if (tensors.Length == 1)
            {
                // NOTE(review): this returns the input instance itself, so the result aliases
                // tensors[0] — confirm callers never dispose one while still using the other.
                return tensors[0];
            }

            // Dispose the pinned handle array once the native call has returned.
            using (var parray = new PinnedArray<IntPtr>()) {
                IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());

                return new TorchTensor(THSTensor_cat(tensorsRef, parray.Array.Length, dimension));
            }
        }
Exemplo n.º 6
0
        /// <summary>
        /// Creates the native module by passing each parameter's name (marshaled as an ANSI
        /// string), tensor handle and <c>WithGrad</c> flag to <c>THSNN_new_module</c>, and
        /// stores the returned handle in <c>handle</c>.
        /// </summary>
        /// <param name="parameters">Parameters to register with the native module.</param>
        protected Module(params Parameter[] parameters)
        {
            var @params   = parameters.Select(p => p.Tensor.Handle).ToArray();
            var withGrads = parameters.Select(p => p.WithGrad).ToArray();
            var names     = new IntPtr[parameters.Length];

            try {
                for (int i = 0; i < parameters.Length; i++) {
                    names[i] = Marshal.StringToHGlobalAnsi(parameters[i].Name);
                }

                // Pinned buffers only need to live for the duration of the native call.
                using (var namesPinned  = new PinnedArray<IntPtr>())
                using (var paramsPinned = new PinnedArray<IntPtr>())
                using (var wGradPinned  = new PinnedArray<bool>()) {
                    var nparray = namesPinned.CreateArray(names);
                    var pparray = paramsPinned.CreateArray(@params);
                    var gparray = wGradPinned.CreateArray(withGrads);

                    handle = new HType(THSNN_new_module(nparray, pparray, gparray, names.Length), true);
                }
            }
            finally {
                // StringToHGlobalAnsi memory is unmanaged and must be freed explicitly; this
                // runs even if marshaling or the native call throws.
                // NOTE(review): assumes THSNN_new_module copies the name strings (e.g. into
                // std::string) and does not retain these pointers — confirm on the native side.
                foreach (var name in names) {
                    if (name != IntPtr.Zero) {
                        Marshal.FreeHGlobal(name);
                    }
                }
            }
        }