/// <summary>
/// Creates a tensor from a flat managed array, dispatching to the element-type-specific
/// <c>from</c> factory selected by <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Element type: byte, sbyte, short, int, long, double, float or bool.</typeparam>
/// <param name="rawArray">Source elements; must not be null.</param>
/// <param name="dimensions">Shape passed to the factory; must not be null.</param>
/// <param name="doCopy">
/// When true, <paramref name="rawArray"/> is cloned before being handed to the factory
/// (presumably so the tensor does not alias the caller's buffer — TODO confirm against the factories).
/// </param>
/// <param name="requiresGrad">Forwarded unchanged to the factory.</param>
/// <returns>The tensor produced by the matching factory.</returns>
/// <exception cref="ArgumentNullException"><paramref name="rawArray"/> or <paramref name="dimensions"/> is null.</exception>
/// <exception cref="NotImplementedException"><typeparamref name="T"/> is not one of the supported element types.</exception>
public static TorchTensor ToTorchTensor<T>(this T[] rawArray, long[] dimensions, bool doCopy = false, bool requiresGrad = false)
{
    // Fail fast with a clear exception instead of an NRE from Clone() or a null buffer reaching native code.
    if (rawArray == null) throw new ArgumentNullException(nameof(rawArray));
    if (dimensions == null) throw new ArgumentNullException(nameof(dimensions));

    var array = doCopy ? (T[])rawArray.Clone() : rawArray;
    var elementType = typeof(T);

    // Dispatch on typeof(T) rather than pattern-matching the array instance: the CLR treats
    // byte[]/sbyte[] (and similar same-size integer arrays) as mutually castable at runtime,
    // so an `is byte[]` test would also accept an sbyte[].
    if (elementType == typeof(byte)) return ByteTensor.from(array as byte[], dimensions, requiresGrad);
    if (elementType == typeof(sbyte)) return Int8Tensor.from(array as sbyte[], dimensions, requiresGrad);
    if (elementType == typeof(short)) return Int16Tensor.from(array as short[], dimensions, requiresGrad);
    if (elementType == typeof(int)) return Int32Tensor.from(array as int[], dimensions, requiresGrad);
    if (elementType == typeof(long)) return Int64Tensor.from(array as long[], dimensions, requiresGrad);
    if (elementType == typeof(double)) return Float64Tensor.from(array as double[], dimensions, requiresGrad);
    if (elementType == typeof(float)) return Float32Tensor.from(array as float[], dimensions, requiresGrad);
    if (elementType == typeof(bool)) return BoolTensor.from(array as bool[], dimensions, requiresGrad);

    // System.Numerics.Complex is not wired up here yet (no ComplexFloat64Tensor dispatch).
    throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
}
/// <summary>
/// Creates a scalar tensor from a single value, dispatching to the element-type-specific
/// <c>from</c> factory selected by <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Element type: byte, sbyte, short, int, long, double or float.</typeparam>
/// <param name="scalar">The value to wrap.</param>
/// <param name="device">Forwarded unchanged to the factory; null by default.</param>
/// <param name="requiresGrad">Only permitted when <typeparamref name="T"/> is float or double.</param>
/// <returns>The tensor produced by the matching factory.</returns>
/// <exception cref="ArgumentException">
/// <paramref name="requiresGrad"/> is true for a non-floating-point <typeparamref name="T"/>.
/// </exception>
/// <exception cref="NotImplementedException"><typeparamref name="T"/> is not one of the supported element types.</exception>
public static TorchTensor ToTorchTensor<T>(this T scalar, Device? device = null, bool requiresGrad = false) where T : struct
{
    var elementType = typeof(T);

    if (requiresGrad && elementType != typeof(float) && elementType != typeof(double))
    {
        // ArgumentException's constructor is (message, paramName); the arguments were previously swapped.
        throw new ArgumentException("Only floating point types support gradients.", nameof(requiresGrad));
    }

    // Box-then-unbox converts the generic T to the concrete type once typeof(T) has been matched.
    if (elementType == typeof(byte)) return ByteTensor.from((byte)(object)scalar, device, requiresGrad);
    if (elementType == typeof(sbyte)) return Int8Tensor.from((sbyte)(object)scalar, device, requiresGrad);
    if (elementType == typeof(short)) return Int16Tensor.from((short)(object)scalar, device, requiresGrad);
    if (elementType == typeof(int)) return Int32Tensor.from((int)(object)scalar, device, requiresGrad);
    if (elementType == typeof(long)) return Int64Tensor.from((long)(object)scalar, device, requiresGrad);
    if (elementType == typeof(double)) return Float64Tensor.from((double)(object)scalar, device, requiresGrad);
    if (elementType == typeof(float)) return Float32Tensor.from((float)(object)scalar, device, requiresGrad);

    throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
}