Esempio n. 1
0
        /// <summary>
        /// Checks that <c>BoolTensor.Ones</c> yields a tensor that reports the requested
        /// shape and whose sampled elements read back as <c>true</c>.
        /// </summary>
        public void CreateBoolTensorOnes()
        {
            // Request a 2x2 tensor from the factory under test.
            long[] expectedShape = { 2, 2 };
            TorchTensor ones = BoolTensor.Ones(expectedShape);

            // The tensor must report exactly the shape that was asked for.
            Assert.Equal(expectedShape, ones.Shape);

            // Spot-check both ends of the main diagonal. The expected value is boxed so the
            // object-based Assert.Equal overload is selected, as in the original test.
            bool topLeft = ones[0, 0].DataItem<bool>();
            Assert.Equal((object)true, topLeft);

            bool bottomRight = ones[1, 1].DataItem<bool>();
            Assert.Equal((object)true, bottomRight);
        }
Esempio n. 2
0
        /// <summary>
        /// Creates a <see cref="TorchTensor"/> from a managed array by dispatching on the element
        /// type <typeparamref name="T"/> to the matching typed tensor factory
        /// (<c>ByteTensor.from</c>, <c>Int8Tensor.from</c>, ... <c>BoolTensor.from</c>).
        /// </summary>
        /// <typeparam name="T">
        /// Element type of the array. Supported: <see cref="byte"/>, <see cref="sbyte"/>,
        /// <see cref="short"/>, <see cref="int"/>, <see cref="long"/>, <see cref="double"/>,
        /// <see cref="float"/> and <see cref="bool"/>.
        /// </typeparam>
        /// <param name="rawArray">Source data; must not be <c>null</c>.</param>
        /// <param name="dimensions">Dimensions forwarded unchanged to the tensor factory.</param>
        /// <param name="doCopy">
        /// When <c>true</c>, a clone of <paramref name="rawArray"/> is handed to the factory
        /// instead of the array itself.
        /// </param>
        /// <param name="requiresGrad">Forwarded unchanged to the tensor factory.</param>
        /// <returns>The tensor produced by the selected factory.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="rawArray"/> is <c>null</c>.</exception>
        /// <exception cref="NotImplementedException"><typeparamref name="T"/> is not a supported element type.</exception>
        public static TorchTensor ToTorchTensor<T>(this T[] rawArray, long[] dimensions, bool doCopy = false, bool requiresGrad = false)
        {
            // Fail fast with a descriptive exception instead of a NullReferenceException from
            // Clone() or a null array being forwarded to the tensor factory.
            if (rawArray is null)
            {
                throw new ArgumentNullException(nameof(rawArray));
            }

            var array = doCopy ? (T[])rawArray.Clone() : rawArray;

            // Dispatch on the static element type rather than on the runtime type of the array:
            // CLR array variance lets e.g. an sbyte[] pass a byte[] type test, which would
            // select the wrong factory. The casts go through object because T[] has no direct
            // conversion to a concrete array type; they cannot fail once typeof(T) has matched.
            Type elementType = typeof(T);

            if (elementType == typeof(byte))
            {
                return ByteTensor.from((byte[])(object)array, dimensions, requiresGrad);
            }

            if (elementType == typeof(sbyte))
            {
                return Int8Tensor.from((sbyte[])(object)array, dimensions, requiresGrad);
            }

            if (elementType == typeof(short))
            {
                return Int16Tensor.from((short[])(object)array, dimensions, requiresGrad);
            }

            if (elementType == typeof(int))
            {
                return Int32Tensor.from((int[])(object)array, dimensions, requiresGrad);
            }

            if (elementType == typeof(long))
            {
                return Int64Tensor.from((long[])(object)array, dimensions, requiresGrad);
            }

            if (elementType == typeof(double))
            {
                return Float64Tensor.from((double[])(object)array, dimensions, requiresGrad);
            }

            if (elementType == typeof(float))
            {
                return Float32Tensor.from((float[])(object)array, dimensions, requiresGrad);
            }

            if (elementType == typeof(bool))
            {
                return BoolTensor.from((bool[])(object)array, dimensions, requiresGrad);
            }

            // NOTE: System.Numerics.Complex is intentionally not dispatched; the original
            // branch targeting ComplexFloat64Tensor.from was disabled:
            //if (elementType == typeof(System.Numerics.Complex))
            //{
            //    return ComplexFloat64Tensor.from((System.Numerics.Complex[])(object)array, dimensions, requiresGrad);
            //}

            throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
        }