/// <summary>
/// Concatenates the tensors in <paramref name="x"/> along dimension <paramref name="dim"/>
/// (the analogue of <c>torch.cat</c>).
/// </summary>
/// <param name="x">
/// Tensors to concatenate. Must be non-empty; every tensor must have the same number of
/// dimensions and the same size in every dimension except <paramref name="dim"/>.
/// </param>
/// <param name="dim">
/// Dimension to concatenate along. Negative values count from the end (<c>-1</c> is the last
/// dimension).
/// </param>
/// <returns>
/// A new tensor whose size along <paramref name="dim"/> is the sum of the inputs' sizes along
/// that dimension. Its dtype is the promotion of all input dtypes (see the comments in the
/// dtype switch). It is created with requires_grad when any input requires grad and
/// <c>torch.autograd.grad_mode.no_grad.prev</c> is false.
/// </returns>
/// <exception cref="TorchException">
/// Thrown when <paramref name="x"/> is empty, <paramref name="dim"/> is out of range, the tensor
/// shapes are incompatible, bool and non-bool tensors are mixed, or a tensor has an unsupported
/// dtype.
/// </exception>
public static Tensor cat(Tensor[] x, int dim) {
    if (x == null || x.Length == 0) {
        throw new TorchException("TorchException: torch.cat expects a non-empty list of tensors.");
    }

    var dtype = x[0].dtype;
    var data = new object[x.Length];
    var shapes = new int[x.Length][];
    var strides = new int[x.Length][];

    var ndim = x[0].__shape.Length;
    // Accept negative dims (-1 == last dimension) the same way torch.cat does.
    var axis = dim < 0 ? dim + ndim : dim;
    if (axis < 0 || axis >= ndim) {
        throw new TorchException(string.Format("TorchException: Dimension out of range for torch.cat (expected to be in range of [{0}, {1}], but got {2}).", -ndim, ndim - 1, dim));
    }

    // Output shape starts as tensor 0's shape; only the `axis` entry accumulates below.
    var y_shape = (int[])x[0].__shape.Clone();

    var requires_grad = false;
    for (int i = 0; i < x.Length; i++) {
        var t = x[i];
        if (t.requires_grad) {
            requires_grad = true;
        }
        shapes[i] = t.__shape;
        strides[i] = t.__strides;

        if (i > 0) {
            if (t.__shape.Length != ndim) {
                throw new TorchException("TorchException: Invalid tensor sizes for torch.cat: the number of dimensions must be the same.");
            }
            for (int j = 0; j < ndim; j++) {
                if (j == axis) {
                    y_shape[j] += t.__shape[j];
                } else if (t.__shape[j] != y_shape[j]) {
                    // For j != axis, y_shape[j] is still tensor 0's size, so the mismatch is
                    // between tensor 0 and tensor i.
                    throw new TorchException(string.Format("TorchException: Invalid tensor sizes for torch.cat: impossible to match shapes {0} and {1} of dimension {2} of tensors 0 and {3}.", y_shape[j], t.__shape[j], j, i));
                }
            }
        }

        // Promotion below never produces bool, so `dtype` is bool iff every tensor seen so far
        // was bool: bool tensors may only be concatenated with other bool tensors.
        if ((t.dtype == torch.@bool) != (dtype == torch.@bool)) {
            throw new TorchException("TorchException: unable to cat bool tensor and non-bool tensor.");
        }

        // Promotion rules (kept order-independent):
        //  * int + int   -> the wider int (uint8 + int8 -> int8, as before).
        //  * int + float -> float wide enough for both operand sizes
        //                   (e.g. int32 + float16 -> float32, int64 + float32 -> float64).
        //  * float + float -> the wider float.
        switch (t.dtype) {
            case torch.uint8: {
                // uint8 is the narrowest type: never widens the result.
                data[i] = t.__uint8;
                break;
            }
            case torch.int8: {
                data[i] = t.__int8;
                // NOTE(review): uint8 + int8 -> int8 can wrap values > 127; torch itself
                // promotes this pair to int16. Kept for compatibility — confirm intent.
                if (dtype == torch.uint8) {
                    dtype = torch.int8;
                }
                break;
            }
            case torch.int16: {
                data[i] = t.__int16;
                // Only 1-byte ints are narrower; float16 (2 bytes) already covers int16.
                if (@sizeof(dtype) == 1) {
                    dtype = torch.int16;
                }
                break;
            }
            case torch.int32: {
                data[i] = t.__int32;
                if (dtype.is_floating_point()) {
                    // Matches the float16 case: float16 + int32 -> float32.
                    if (dtype == torch.float16) {
                        dtype = torch.float32;
                    }
                } else if (@sizeof(dtype) < 4) {
                    dtype = torch.int32;
                }
                break;
            }
            case torch.int64: {
                data[i] = t.__int64;
                // Matches the float cases: any float + int64 -> float64; any int + int64 -> int64.
                if (dtype.is_floating_point()) {
                    dtype = torch.float64;
                } else {
                    dtype = torch.int64;
                }
                break;
            }
            case torch.half: {
                data[i] = t.__half;
                if (!dtype.is_floating_point()) {
                    switch (@sizeof(dtype)) {
                        case 1:
                        case 2: { dtype = torch.float16; break; }
                        case 4: { dtype = torch.float32; break; }
                        case 8: { dtype = torch.float64; break; }
                    }
                }
                break;
            }
            case torch.float32: {
                data[i] = t.__float;
                if (!dtype.is_floating_point()) {
                    switch (@sizeof(dtype)) {
                        case 1:
                        case 2:
                        case 4: { dtype = torch.float32; break; }
                        case 8: { dtype = torch.float64; break; }
                    }
                } else if (dtype == torch.float16) {
                    dtype = torch.float32;
                }
                break;
            }
            case torch.float64: {
                data[i] = t.__double;
                dtype = torch.float64;
                break;
            }
            case torch.@bool: {
                data[i] = t.__bool;
                break;
            }
            default: {
                // Previously an unknown dtype silently left data[i] == null.
                throw new TorchException(string.Format("TorchException: unsupported dtype {0} for torch.cat.", t.dtype));
            }
        }
    }

    var y = new Tensor(y_shape, dtype, (!torch.autograd.grad_mode.no_grad.prev) && requires_grad);
    // MKL.Cat receives the source buffers in their original dtypes and writes into y's buffer
    // of the promoted dtype.
    switch (dtype) {
        case torch.float16: { MKL.Cat(data, shapes, strides, axis, y.__half, y.__shape, y.__strides); break; }
        case torch.float32: { MKL.Cat(data, shapes, strides, axis, y.__float, y.__shape, y.__strides); break; }
        case torch.float64: { MKL.Cat(data, shapes, strides, axis, y.__double, y.__shape, y.__strides); break; }
        case torch.int8: { MKL.Cat(data, shapes, strides, axis, y.__int8, y.__shape, y.__strides); break; }
        case torch.uint8: { MKL.Cat(data, shapes, strides, axis, y.__uint8, y.__shape, y.__strides); break; }
        case torch.int16: { MKL.Cat(data, shapes, strides, axis, y.__int16, y.__shape, y.__strides); break; }
        case torch.int32: { MKL.Cat(data, shapes, strides, axis, y.__int32, y.__shape, y.__strides); break; }
        case torch.int64: { MKL.Cat(data, shapes, strides, axis, y.__int64, y.__shape, y.__strides); break; }
        case torch.@bool: { MKL.Cat(data, shapes, strides, axis, y.__bool, y.__shape, y.__strides); break; }
    }
    return y;
}