Пример #1
0
        /// <summary>
        /// Converts a single scalar value into a tensor by dispatching on the
        /// generic type argument to the matching typed <c>From</c> factory.
        /// </summary>
        /// <typeparam name="T">
        /// Element type of the scalar. Handled types: byte, short, int, long,
        /// double and float.
        /// </typeparam>
        /// <param name="scalar">The value to convert.</param>
        /// <param name="requiresGrad">
        /// Forwarded to the typed factory. May only be <c>true</c> when
        /// <typeparamref name="T"/> is float or double.
        /// </param>
        /// <returns>The tensor returned by the selected typed factory.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="requiresGrad"/> is <c>true</c> and
        /// <typeparamref name="T"/> is neither float nor double.
        /// </exception>
        /// <exception cref="NotImplementedException">
        /// <typeparamref name="T"/> is not one of the handled types.
        /// </exception>
        public static TorchTensor ToTorchTensor<T>(this T scalar, bool requiresGrad = false)
        {
            if (requiresGrad && typeof(T) != typeof(float) && typeof(T) != typeof(double))
            {
                // ArgumentException(string message, string paramName): the message comes
                // first. The previous argument order reported the parameter name as the
                // message and the explanatory text as the parameter name.
                throw new ArgumentException("Only floating point types support gradients.", nameof(requiresGrad));
            }

            // `switch (true)` with case guards is used as a type dispatch on T; the
            // cast through object is needed because T is unconstrained.
            switch (true)
            {
            case bool _ when typeof(T) == typeof(byte):
            {
                return ByteTensor.From((byte)(object)scalar, requiresGrad);
            }

            case bool _ when typeof(T) == typeof(short):
            {
                return ShortTensor.From((short)(object)scalar, requiresGrad);
            }

            case bool _ when typeof(T) == typeof(int):
            {
                return IntTensor.From((int)(object)scalar, requiresGrad);
            }

            case bool _ when typeof(T) == typeof(long):
            {
                return LongTensor.From((long)(object)scalar, requiresGrad);
            }

            case bool _ when typeof(T) == typeof(double):
            {
                return DoubleTensor.From((double)(object)scalar, requiresGrad);
            }

            case bool _ when typeof(T) == typeof(float):
            {
                return FloatTensor.From((float)(object)scalar, requiresGrad);
            }

            default: throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
            }
        }
Пример #2
0
        /// <summary>
        /// Builds a tensor from one scalar value, choosing the typed <c>From</c>
        /// factory that corresponds to <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">
        /// Scalar type. byte, short, int, long, double and float are handled.
        /// </typeparam>
        /// <param name="scalar">The value to wrap.</param>
        /// <returns>The tensor produced by the matching typed factory.</returns>
        /// <exception cref="NotImplementedException">
        /// <typeparamref name="T"/> is not one of the handled types.
        /// </exception>
        public static TorchTensor ToTorchTensor<T>(this T scalar)
        {
            Type elementType = typeof(T);

            // T is unconstrained, so each branch unboxes through object to reach
            // the concrete primitive type.
            if (elementType == typeof(byte))
            {
                return ByteTensor.From((byte)(object)scalar);
            }

            if (elementType == typeof(short))
            {
                return ShortTensor.From((short)(object)scalar);
            }

            if (elementType == typeof(int))
            {
                return IntTensor.From((int)(object)scalar);
            }

            if (elementType == typeof(long))
            {
                return LongTensor.From((long)(object)scalar);
            }

            if (elementType == typeof(double))
            {
                return DoubleTensor.From((double)(object)scalar);
            }

            if (elementType == typeof(float))
            {
                return FloatTensor.From((float)(object)scalar);
            }

            throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
        }
Пример #3
0
        /// <summary>
        /// Builds a tensor from an array by handing it, together with
        /// <paramref name="dimensions"/>, to the typed <c>From</c> factory that
        /// matches <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">
        /// Array element type. byte, short, int, long, double and float are handled.
        /// </typeparam>
        /// <param name="rawArray">Source array, passed to the factory as its concrete array type.</param>
        /// <param name="dimensions">
        /// Forwarded unchanged to the factory (presumably the tensor shape — confirm
        /// against the factory's documentation).
        /// </param>
        /// <returns>The tensor produced by the matching typed factory.</returns>
        /// <exception cref="NotImplementedException">
        /// <typeparamref name="T"/> is not one of the handled types.
        /// </exception>
        public static TorchTensor ToTorchTensor<T>(this T[] rawArray, long[] dimensions)
        {
            Type elementType = typeof(T);

            if (elementType == typeof(byte))
            {
                return ByteTensor.From(rawArray as byte[], dimensions);
            }

            if (elementType == typeof(short))
            {
                return ShortTensor.From(rawArray as short[], dimensions);
            }

            if (elementType == typeof(int))
            {
                return IntTensor.From(rawArray as int[], dimensions);
            }

            if (elementType == typeof(long))
            {
                return LongTensor.From(rawArray as long[], dimensions);
            }

            if (elementType == typeof(double))
            {
                return DoubleTensor.From(rawArray as double[], dimensions);
            }

            if (elementType == typeof(float))
            {
                return FloatTensor.From(rawArray as float[], dimensions);
            }

            throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
        }
Пример #4
0
        /// <summary>
        /// Builds a tensor from an array via the typed <c>From</c> factory that
        /// matches <typeparamref name="T"/>, optionally returning a clone of the
        /// created tensor instead of the tensor itself.
        /// </summary>
        /// <typeparam name="T">
        /// Array element type. byte, short, int, long, double and float are handled.
        /// </typeparam>
        /// <param name="rawArray">Source array, passed to the factory as its concrete array type.</param>
        /// <param name="dimensions">
        /// Forwarded unchanged to the factory (presumably the tensor shape — confirm
        /// against the factory's documentation).
        /// </param>
        /// <param name="doCopy">
        /// When <c>true</c>, the result of calling <c>Clone()</c> on the created
        /// tensor is returned; otherwise the created tensor is returned directly.
        /// </param>
        /// <param name="requiresGrad">Forwarded unchanged to the typed factory.</param>
        /// <returns>The created tensor, or its clone when <paramref name="doCopy"/> is set.</returns>
        /// <exception cref="NotImplementedException">
        /// <typeparamref name="T"/> is not one of the handled types.
        /// </exception>
        public static TorchTensor ToTorchTensor<T>(this T[] rawArray, long[] dimensions, bool doCopy = false, bool requiresGrad = false)
        {
            Type elementType = typeof(T);
            TorchTensor tensor;

            // Step 1: create the tensor with the factory that matches T.
            if (elementType == typeof(byte))
            {
                tensor = ByteTensor.From(rawArray as byte[], dimensions, requiresGrad);
            }
            else if (elementType == typeof(short))
            {
                tensor = ShortTensor.From(rawArray as short[], dimensions, requiresGrad);
            }
            else if (elementType == typeof(int))
            {
                tensor = IntTensor.From(rawArray as int[], dimensions, requiresGrad);
            }
            else if (elementType == typeof(long))
            {
                tensor = LongTensor.From(rawArray as long[], dimensions, requiresGrad);
            }
            else if (elementType == typeof(double))
            {
                tensor = DoubleTensor.From(rawArray as double[], dimensions, requiresGrad);
            }
            else if (elementType == typeof(float))
            {
                tensor = FloatTensor.From(rawArray as float[], dimensions, requiresGrad);
            }
            else
            {
                throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
            }

            // Step 2: hand back either the tensor itself or a clone of it.
            // NOTE(review): when cloning, the intermediate tensor is not disposed
            // here — confirm whether that is intended.
            if (!doCopy)
            {
                return tensor;
            }

            return tensor.Clone();
        }