Exemplo n.º 1
0
        /// <summary>
        /// Creates a <see cref="TorchTensor"/> from a managed array by dispatching on the element type
        /// to the matching typed factory (<c>ByteTensor.from</c>, <c>Int32Tensor.from</c>, ...).
        /// </summary>
        /// <typeparam name="T">
        /// Element type. Supported: byte, sbyte, short, int, long, double, float and bool.
        /// </typeparam>
        /// <param name="rawArray">Source data for the tensor.</param>
        /// <param name="dimensions">Tensor shape, forwarded unchanged to the typed factory.</param>
        /// <param name="doCopy">
        /// When true, a shallow clone of <paramref name="rawArray"/> is handed to the factory instead of
        /// the caller's array.
        /// </param>
        /// <param name="requiresGrad">Forwarded unchanged to the typed factory.</param>
        /// <returns>The tensor produced by the element-type-specific factory.</returns>
        /// <exception cref="NotImplementedException">
        /// Thrown when <typeparamref name="T"/> is not one of the supported element types.
        /// </exception>
        public static TorchTensor ToTorchTensor <T>(this T[] rawArray, long[] dimensions, bool doCopy = false, bool requiresGrad = false)
        {
            // Clone up front when requested so the factory never receives the caller's own array instance.
            // NOTE(review): whether the factories retain a reference to the array is not visible here.
            var array = doCopy ? (T[])rawArray.Clone() : rawArray;

            // typeof(T) is not a compile-time constant, so it cannot be used as a case label. A plain
            // if-chain is clearer than the switch (true) / when-guard trick; the JIT folds these
            // comparisons for value-type T. Array type patterns (array is byte[]) are deliberately
            // avoided: the CLR lets sbyte[] pass an `is byte[]` check (and vice versa).
            if (typeof(T) == typeof(byte))
            {
                return ByteTensor.from(array as byte[], dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(sbyte))
            {
                return Int8Tensor.from(array as sbyte[], dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(short))
            {
                return Int16Tensor.from(array as short[], dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(int))
            {
                return Int32Tensor.from(array as int[], dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(long))
            {
                return Int64Tensor.from(array as long[], dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(double))
            {
                return Float64Tensor.from(array as double[], dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(float))
            {
                return Float32Tensor.from(array as float[], dimensions, requiresGrad);
            }
            if (typeof(T) == typeof(bool))
            {
                return BoolTensor.from(array as bool[], dimensions, requiresGrad);
            }

            // TODO: System.Numerics.Complex arrays are not supported yet and fall through to here.
            throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
        }
Exemplo n.º 2
0
        /// <summary>
        /// Creates a <see cref="TorchTensor"/> from a single scalar value by dispatching on the value's
        /// type to the matching typed factory (<c>ByteTensor.from</c>, <c>Float32Tensor.from</c>, ...).
        /// </summary>
        /// <typeparam name="T">
        /// Scalar type. Supported: byte, sbyte, short, int, long, double and float.
        /// </typeparam>
        /// <param name="scalar">The value to wrap.</param>
        /// <param name="device">Target device, forwarded unchanged to the typed factory (may be null).</param>
        /// <param name="requiresGrad">
        /// Forwarded to the typed factory; only allowed when <typeparamref name="T"/> is float or double.
        /// </param>
        /// <returns>The tensor produced by the type-specific factory.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="requiresGrad"/> is true and <typeparamref name="T"/> is neither float nor double.
        /// </exception>
        /// <exception cref="NotImplementedException">
        /// <typeparamref name="T"/> is not one of the supported scalar types.
        /// </exception>
        public static TorchTensor ToTorchTensor <T>(this T scalar, Device? device = null, bool requiresGrad = false) where T : struct
        {
            if (requiresGrad && typeof(T) != typeof(float) && typeof(T) != typeof(double))
            {
                // ArgumentException's constructor is (message, paramName). Passing them in this order
                // keeps the human-readable text in Message and "requiresGrad" in ParamName.
                throw new ArgumentException("Only floating point types support gradients.", nameof(requiresGrad));
            }

            // The (X)(object)scalar double cast is the standard way to convert a generic T to a concrete
            // primitive once typeof(T) has been checked; the JIT removes the boxing for value-type T.
            if (typeof(T) == typeof(byte))
            {
                return ByteTensor.from((byte)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(sbyte))
            {
                return Int8Tensor.from((sbyte)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(short))
            {
                return Int16Tensor.from((short)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(int))
            {
                return Int32Tensor.from((int)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(long))
            {
                return Int64Tensor.from((long)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(double))
            {
                return Float64Tensor.from((double)(object)scalar, device, requiresGrad);
            }
            if (typeof(T) == typeof(float))
            {
                return Float32Tensor.from((float)(object)scalar, device, requiresGrad);
            }

            throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
        }