Exemplo n.º 1
0
 /// <summary>
 /// Builds a byte tensor from a managed array by pinning the array and handing its
 /// start address to the pointer-based <c>ByteTensor.From(IntPtr, long[])</c> factory.
 /// </summary>
 /// <param name="rawArray">Source bytes; pinned only while the factory call runs.</param>
 /// <param name="dimensions">Shape forwarded unchanged to the factory.</param>
 /// <returns>The tensor returned by the pointer-based factory.</returns>
 public static TorchTensor From(byte[] rawArray, long[] dimensions)
 {
     unsafe
     {
         fixed (byte* pinnedStart = rawArray)
         {
             // NOTE(review): the pin is released as soon as this block exits. This is only
             // safe if ByteTensor.From copies the bytes (or keeps the array alive) — confirm.
             var address = new IntPtr(pinnedStart);
             return ByteTensor.From(address, dimensions);
         }
     }
 }
Exemplo n.º 2
0
 /// <summary>
 /// Builds a byte tensor from a managed array by pinning the array and handing its
 /// start address, the shape and the gradient flag to the pointer-based
 /// <c>ByteTensor.From(IntPtr, long[], bool)</c> factory.
 /// </summary>
 /// <param name="rawArray">Source bytes; pinned only while the factory call runs.</param>
 /// <param name="dimensions">Shape forwarded unchanged to the factory.</param>
 /// <param name="requiresGrad">Forwarded unchanged to the factory.</param>
 /// <returns>The tensor returned by the pointer-based factory.</returns>
 public static TorchTensor From(byte[] rawArray, long[] dimensions, bool requiresGrad = false)
 {
     unsafe
     {
         fixed (byte* pinnedStart = rawArray)
         {
             // NOTE(review): the pin is released as soon as this block exits. This is only
             // safe if ByteTensor.From copies the bytes (or keeps the array alive) — confirm.
             var address = new IntPtr(pinnedStart);
             return ByteTensor.From(address, dimensions, requiresGrad);
         }
     }
 }
Exemplo n.º 3
0
        /// <summary>
        /// Creates a tensor from a managed array whose element type is one of the supported primitives.
        /// </summary>
        /// <typeparam name="T">Element type: byte, sbyte, short, int, long, double, float or bool.</typeparam>
        /// <param name="rawArray">Source data. Must not be null.</param>
        /// <param name="dimensions">Shape forwarded unchanged to the typed factory.</param>
        /// <param name="doCopy">
        /// When true, a shallow clone of <paramref name="rawArray"/> is passed to the factory
        /// instead of the caller's array.
        /// </param>
        /// <param name="requiresGrad">Forwarded unchanged to the typed factory.</param>
        /// <returns>The tensor produced by the typed factory selected for <typeparamref name="T"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="rawArray"/> is null.</exception>
        /// <exception cref="NotImplementedException"><typeparamref name="T"/> is not a supported element type.</exception>
        public static TorchTensor ToTorchTensor <T>(this T[] rawArray, long[] dimensions, bool doCopy = false, bool requiresGrad = false)
        {
            // Fail fast with a descriptive exception instead of a NullReferenceException from
            // Clone() (doCopy == true) or a null array reaching the typed factory (doCopy == false).
            if (rawArray is null)
            {
                throw new ArgumentNullException(nameof(rawArray));
            }

            var array = doCopy ? (T[])rawArray.Clone() : rawArray;

            // Dispatch on typeof(T), not on the runtime array type: the CLR's array-element
            // compatibility rules let e.g. an sbyte[] pass an `is byte[]` test.
            if (typeof(T) == typeof(byte))
            {
                return ByteTensor.from((byte[])(object)array, dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(sbyte))
            {
                return Int8Tensor.from((sbyte[])(object)array, dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(short))
            {
                return Int16Tensor.from((short[])(object)array, dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(int))
            {
                return Int32Tensor.from((int[])(object)array, dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(long))
            {
                return Int64Tensor.from((long[])(object)array, dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(double))
            {
                return Float64Tensor.from((double[])(object)array, dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(float))
            {
                return Float32Tensor.from((float[])(object)array, dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(bool))
            {
                return BoolTensor.from((bool[])(object)array, dimensions, requiresGrad);
            }

            // System.Numerics.Complex is not handled: the ComplexFloat64Tensor.from path is disabled.
            throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
        }
Exemplo n.º 4
0
        /// <summary>
        /// Wraps a single primitive value in a tensor using the typed factory for <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">One of byte, short, int, long, double or float.</typeparam>
        /// <param name="scalar">The value to wrap.</param>
        /// <param name="requiresGrad">Gradient flag; only allowed for float and double.</param>
        /// <returns>The tensor produced by the typed factory.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="requiresGrad"/> is true for a non-floating-point <typeparamref name="T"/>.
        /// </exception>
        /// <exception cref="NotImplementedException"><typeparamref name="T"/> is not a supported type.</exception>
        public static TorchTensor ToTorchTensor <T>(this T scalar, bool requiresGrad = false)
        {
            if (requiresGrad && typeof(T) != typeof(float) && typeof(T) != typeof(double))
            {
                // ArgumentException's constructor is (message, paramName).
                throw new ArgumentException("Only floating point types support gradients.", nameof(requiresGrad));
            }

            if (typeof(T) == typeof(byte))
            {
                return ByteTensor.From((byte)(object)scalar, requiresGrad);
            }
            if (typeof(T) == typeof(short))
            {
                return ShortTensor.From((short)(object)scalar, requiresGrad);
            }
            if (typeof(T) == typeof(int))
            {
                return IntTensor.From((int)(object)scalar, requiresGrad);
            }
            if (typeof(T) == typeof(long))
            {
                return LongTensor.From((long)(object)scalar, requiresGrad);
            }
            if (typeof(T) == typeof(double))
            {
                return DoubleTensor.From((double)(object)scalar, requiresGrad);
            }
            if (typeof(T) == typeof(float))
            {
                return FloatTensor.From((float)(object)scalar, requiresGrad);
            }

            throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
        }
Exemplo n.º 5
0
        /// <summary>
        /// Wraps a single primitive value in a tensor, optionally on a given device, using the
        /// typed factory for <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">One of byte, sbyte, short, int, long, double or float.</typeparam>
        /// <param name="scalar">The value to wrap.</param>
        /// <param name="device">Forwarded unchanged to the typed factory; may be null.</param>
        /// <param name="requiresGrad">Gradient flag; only allowed for float and double.</param>
        /// <returns>The tensor produced by the typed factory.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="requiresGrad"/> is true for a non-floating-point <typeparamref name="T"/>.
        /// </exception>
        /// <exception cref="NotImplementedException"><typeparamref name="T"/> is not a supported type.</exception>
        public static TorchTensor ToTorchTensor <T>(this T scalar, Device? device = null, bool requiresGrad = false) where T : struct
        {
            if (requiresGrad && typeof(T) != typeof(float) && typeof(T) != typeof(double))
            {
                // ArgumentException's constructor is (message, paramName).
                throw new ArgumentException("Only floating point types support gradients.", nameof(requiresGrad));
            }

            if (typeof(T) == typeof(byte))
            {
                return ByteTensor.from((byte)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(sbyte))
            {
                return Int8Tensor.from((sbyte)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(short))
            {
                return Int16Tensor.from((short)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(int))
            {
                return Int32Tensor.from((int)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(long))
            {
                return Int64Tensor.from((long)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(double))
            {
                return Float64Tensor.from((double)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(float))
            {
                return Float32Tensor.from((float)(object)scalar, device, requiresGrad);
            }

            throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
        }
Exemplo n.º 6
0
        /// <summary>
        /// Wraps a single primitive value in a tensor via the matching typed <c>From</c> factory.
        /// </summary>
        /// <typeparam name="T">One of byte, short, int, long, double or float.</typeparam>
        /// <param name="scalar">The value to wrap.</param>
        /// <returns>The tensor produced by the typed factory.</returns>
        /// <exception cref="NotImplementedException"><typeparamref name="T"/> is not a supported type.</exception>
        public static TorchTensor ToTorchTensor <T>(this T scalar)
        {
            var elementType = typeof(T);
            object boxed = scalar;

            if (elementType == typeof(byte))
            {
                return ByteTensor.From((byte)boxed);
            }
            else if (elementType == typeof(short))
            {
                return ShortTensor.From((short)boxed);
            }
            else if (elementType == typeof(int))
            {
                return IntTensor.From((int)boxed);
            }
            else if (elementType == typeof(long))
            {
                return LongTensor.From((long)boxed);
            }
            else if (elementType == typeof(double))
            {
                return DoubleTensor.From((double)boxed);
            }
            else if (elementType == typeof(float))
            {
                return FloatTensor.From((float)boxed);
            }

            throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
        }
Exemplo n.º 7
0
        /// <summary>
        /// Creates a tensor from a managed array via the typed <c>From</c> factory matching
        /// <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">One of byte, short, int, long, double or float.</typeparam>
        /// <param name="rawArray">Source data, forwarded to the factory.</param>
        /// <param name="dimensions">Shape forwarded unchanged to the factory.</param>
        /// <returns>The tensor produced by the typed factory.</returns>
        /// <exception cref="NotImplementedException"><typeparamref name="T"/> is not a supported type.</exception>
        public static TorchTensor ToTorchTensor <T>(this T[] rawArray, long[] dimensions)
        {
            var elementType = typeof(T);
            object source = rawArray;

            if (elementType == typeof(byte))
            {
                return ByteTensor.From(source as byte[], dimensions);
            }
            else if (elementType == typeof(short))
            {
                return ShortTensor.From(source as short[], dimensions);
            }
            else if (elementType == typeof(int))
            {
                return IntTensor.From(source as int[], dimensions);
            }
            else if (elementType == typeof(long))
            {
                return LongTensor.From(source as long[], dimensions);
            }
            else if (elementType == typeof(double))
            {
                return DoubleTensor.From(source as double[], dimensions);
            }
            else if (elementType == typeof(float))
            {
                return FloatTensor.From(source as float[], dimensions);
            }

            throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
        }
Exemplo n.º 8
0
        /// <summary>
        /// Creates a tensor from a managed array via the typed <c>From</c> factory matching
        /// <typeparamref name="T"/>, optionally returning a clone of that tensor.
        /// </summary>
        /// <typeparam name="T">One of byte, short, int, long, double or float.</typeparam>
        /// <param name="rawArray">Source data, forwarded to the factory.</param>
        /// <param name="dimensions">Shape forwarded unchanged to the factory.</param>
        /// <param name="doCopy">When true, the factory's tensor is cloned and the clone is returned.</param>
        /// <param name="requiresGrad">Forwarded unchanged to the factory.</param>
        /// <returns>The factory's tensor, or its clone when <paramref name="doCopy"/> is true.</returns>
        /// <exception cref="NotImplementedException"><typeparamref name="T"/> is not a supported type.</exception>
        public static TorchTensor ToTorchTensor <T>(this T[] rawArray, long[] dimensions, bool doCopy = false, bool requiresGrad = false)
        {
            TorchTensor created;

            if (typeof(T) == typeof(byte))
            {
                created = ByteTensor.From(rawArray as byte[], dimensions, requiresGrad);
            }
            else if (typeof(T) == typeof(short))
            {
                created = ShortTensor.From(rawArray as short[], dimensions, requiresGrad);
            }
            else if (typeof(T) == typeof(int))
            {
                created = IntTensor.From(rawArray as int[], dimensions, requiresGrad);
            }
            else if (typeof(T) == typeof(long))
            {
                created = LongTensor.From(rawArray as long[], dimensions, requiresGrad);
            }
            else if (typeof(T) == typeof(double))
            {
                created = DoubleTensor.From(rawArray as double[], dimensions, requiresGrad);
            }
            else if (typeof(T) == typeof(float))
            {
                created = FloatTensor.From(rawArray as float[], dimensions, requiresGrad);
            }
            else
            {
                throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
            }

            // NOTE(review): when doCopy is true the intermediate tensor is dropped without being
            // disposed; if TorchTensor owns a native handle this may leak until finalization — verify.
            return doCopy ? created.Clone() : created;
        }