/// <summary>
/// Returns a new tensor equal to <paramref name="x"/> with dimensions
/// <paramref name="dim1"/> and <paramref name="dim2"/> swapped.
/// </summary>
/// <remarks>
/// The output shape is a copy of <c>x.__shape</c> with the two entries exchanged.
/// The data is written into the new tensor by <c>MKL.Transpose</c>.
///
/// For float16/float32/float64 inputs, if the output requires grad, a backward
/// closure is attached. That closure passes <c>y.grad</c> and <c>x.grad</c> to
/// <c>MKL.dTranspose</c>, then invokes <c>x.__backward_fn</c> (if any) so the
/// chain continues. Integer and bool dtypes never get a backward closure.
///
/// Negative dims count from the end (e.g. -1 is the last dimension).
/// </remarks>
/// <param name="x">Input tensor.</param>
/// <param name="dim1">First dimension to swap; may be negative.</param>
/// <param name="dim2">Second dimension to swap; may be negative.</param>
/// <returns>A newly allocated tensor with the transposed shape and data.</returns>
/// <exception cref="System.IndexOutOfRangeException">A dim is out of range after normalization.</exception>
/// <exception cref="System.NotImplementedException"><c>x.dtype</c> is not one of the supported dtypes.</exception>
public static Tensor transpose(this Tensor x, int dim1, int dim2) {
    int ndim = x.__shape.Length;

    // Allow PyTorch-style negative indices. Values still out of range after this
    // fall through to the shape indexing below, which throws as before.
    if (dim1 < 0) { dim1 += ndim; }
    if (dim2 < 0) { dim2 += ndim; }

    var shape = new int[ndim];
    for (int i = 0; i < ndim; i++) {
        shape[i] = x.__shape[i];
    }
    var t = shape[dim2];
    shape[dim2] = shape[dim1];
    shape[dim1] = t;

    // Gradient tracking is only enabled when not inside a no_grad scope and the input tracks grad.
    var y = new Tensor(shape, x.dtype, (!torch.autograd.grad_mode.no_grad.prev) && x.requires_grad);

    switch (x.dtype) {
        case torch.float16: {
            MKL.Transpose(x.__half, x.__shape, x.__strides, dim1, dim2, y.__half, y.__shape, y.__strides);
            if (y.requires_grad) {
                y.__backward_fn = () => {
                    MKL.dTranspose(x.grad.__half, x.__shape, x.__strides, dim1, dim2, y.grad.__half, y.__shape, y.__strides);
                    if (x.__backward_fn != null) {
                        x.__backward_fn();
                    }
                };
            }
            break;
        }
        case torch.float32: {
            MKL.Transpose(x.__float, x.__shape, x.__strides, dim1, dim2, y.__float, y.__shape, y.__strides);
            if (y.requires_grad) {
                y.__backward_fn = () => {
                    MKL.dTranspose(x.grad.__float, x.__shape, x.__strides, dim1, dim2, y.grad.__float, y.__shape, y.__strides);
                    if (x.__backward_fn != null) {
                        x.__backward_fn();
                    }
                };
            }
            break;
        }
        case torch.float64: {
            MKL.Transpose(x.__double, x.__shape, x.__strides, dim1, dim2, y.__double, y.__shape, y.__strides);
            if (y.requires_grad) {
                y.__backward_fn = () => {
                    MKL.dTranspose(x.grad.__double, x.__shape, x.__strides, dim1, dim2, y.grad.__double, y.__shape, y.__strides);
                    if (x.__backward_fn != null) {
                        x.__backward_fn();
                    }
                };
            }
            break;
        }
        // Non-floating dtypes: forward copy only, no autograd.
        case torch.int8: {
            MKL.Transpose(x.__int8, x.__shape, x.__strides, dim1, dim2, y.__int8, y.__shape, y.__strides);
            break;
        }
        case torch.uint8: {
            MKL.Transpose(x.__uint8, x.__shape, x.__strides, dim1, dim2, y.__uint8, y.__shape, y.__strides);
            break;
        }
        case torch.int16: {
            MKL.Transpose(x.__int16, x.__shape, x.__strides, dim1, dim2, y.__int16, y.__shape, y.__strides);
            break;
        }
        case torch.int32: {
            MKL.Transpose(x.__int32, x.__shape, x.__strides, dim1, dim2, y.__int32, y.__shape, y.__strides);
            break;
        }
        case torch.int64: {
            MKL.Transpose(x.__int64, x.__shape, x.__strides, dim1, dim2, y.__int64, y.__shape, y.__strides);
            break;
        }
        case torch.@bool: {
            MKL.Transpose(x.__bool, x.__shape, x.__strides, dim1, dim2, y.__bool, y.__shape, y.__strides);
            break;
        }
        default: {
            // Previously an unlisted dtype fell through and returned an unfilled tensor.
            throw new global::System.NotImplementedException($"transpose: unsupported dtype {x.dtype}");
        }
    }
    return y;
}