Пример #1
0
            /// <summary>
            /// Concatenates tensors along an existing dimension (torch.cat).
            /// </summary>
            /// <param name="x">
            /// Tensors to concatenate. All must have the same number of dimensions and identical
            /// sizes in every dimension except <paramref name="dim"/>.
            /// </param>
            /// <param name="dim">
            /// Dimension to concatenate along. Negative values count from the last dimension
            /// (-1 is the last one).
            /// </param>
            /// <returns>
            /// A new tensor whose size along <paramref name="dim"/> is the sum of the inputs' sizes.
            /// Its dtype is promoted across all inputs: wider integers win, and floating types win
            /// over integers (a float is widened to at least the integer's byte size). The result
            /// requires grad when any input does and the no_grad flag read here is not set.
            /// </returns>
            /// <exception cref="TorchException">
            /// <paramref name="x"/> is null or empty, <paramref name="dim"/> is out of range,
            /// the shapes are incompatible, bool and non-bool tensors are mixed, or a tensor has
            /// a dtype this method does not handle.
            /// </exception>
            public static Tensor cat(Tensor[] x, int dim)
            {
                if (x == null || x.Length == 0)
                {
                    throw new TorchException("TorchException: torch.cat expects a non-empty sequence of tensors.");
                }

                var ndim = x[0].__shape.Length;
                var requested_dim = dim;
                // Negative dims count from the end, as in torch.cat.
                if (dim < 0)
                {
                    dim += ndim;
                }
                if (dim < 0 || dim >= ndim)
                {
                    throw new TorchException(string.Format("TorchException: Dimension out of range for torch.cat: got {0} for tensors with {1} dimensions.", requested_dim, ndim));
                }

                var dtype         = x[0].dtype;
                var data          = new object[x.Length];
                var shapes        = new int[x.Length][];
                var strides       = new int[x.Length][];
                // Start from the first tensor's shape; sizes along dim are accumulated below.
                var y_shape       = (int[])x[0].__shape.Clone();
                var requires_grad = false;

                for (int i = 0; i < x.Length; i++)
                {
                    var t = x[i];
                    requires_grad |= t.requires_grad;
                    shapes[i]  = t.__shape;
                    strides[i] = t.__strides;

                    if (i > 0)
                    {
                        if (t.__shape.Length != y_shape.Length)
                        {
                            throw new TorchException("TorchException: Invalid tensor sizes for torch.cat: the number of dimensions must be the same.");
                        }
                        for (int j = 0; j < t.__shape.Length; j++)
                        {
                            if (j == dim)
                            {
                                y_shape[j] += t.__shape[j];
                                continue;
                            }
                            if (t.__shape[j] != y_shape[j])
                            {
                                throw new TorchException(string.Format("TorchException: Invalid tensor sizes for torch.cat: impossible to match shapes {0} and {1} of dimention {2} of tensors {3} and {4}.", t.__shape[j], y_shape[j], j, i - 1, i));
                            }
                        }
                    }

                    // bool tensors may only be concatenated with other bool tensors. dtype never
                    // becomes bool through promotion, so this compares against x[0]'s kind.
                    if ((t.dtype == torch.@bool) != (dtype == torch.@bool))
                    {
                        throw new TorchException("TorchException: unable to cat bool tensor and non-bool tensor.");
                    }

                    switch (t.dtype)
                    {
                    case torch.uint8:
                    {
                        data[i] = t.__uint8;
                        break;
                    }

                    case torch.int8:
                    {
                        data[i] = t.__int8;
                        if (dtype == torch.uint8)
                        {
                            dtype = torch.int8;
                        }
                        break;
                    }

                    case torch.int16:
                    {
                        data[i] = t.__int16;
                        // Only 1-byte integers are narrower; float16 already holds int16-sized data.
                        if (@sizeof(dtype) == 1)
                        {
                            dtype = torch.int16;
                        }
                        break;
                    }

                    case torch.int32:
                    {
                        data[i] = t.__int32;
                        if (dtype.is_floating_point())
                        {
                            // Mirror the half case: a 4-byte integer widens float16 to float32.
                            if (dtype == torch.float16)
                            {
                                dtype = torch.float32;
                            }
                        }
                        else if (@sizeof(dtype) < 4)
                        {
                            dtype = torch.int32;
                        }
                        break;
                    }

                    case torch.int64:
                    {
                        data[i] = t.__int64;
                        if (dtype.is_floating_point())
                        {
                            // Mirror the half/float32 cases: an 8-byte integer widens any float to float64.
                            dtype = torch.float64;
                        }
                        else if (@sizeof(dtype) < 8)
                        {
                            dtype = torch.int64;
                        }
                        break;
                    }

                    case torch.half:
                    {
                        data[i] = t.__half;
                        if (!dtype.is_floating_point())
                        {
                            // Pick the smallest float at least as wide as the integer seen so far.
                            switch (@sizeof(dtype))
                            {
                            case 1:
                            case 2:
                            {
                                dtype = torch.float16;
                                break;
                            }

                            case 4:
                            {
                                dtype = torch.float32;
                                break;
                            }

                            case 8:
                            {
                                dtype = torch.float64;
                                break;
                            }
                            }
                        }
                        break;
                    }

                    case torch.float32:
                    {
                        data[i] = t.__float;
                        if (!dtype.is_floating_point())
                        {
                            dtype = @sizeof(dtype) == 8 ? torch.float64 : torch.float32;
                        }
                        else if (dtype == torch.float16)
                        {
                            dtype = torch.float32;
                        }
                        break;
                    }

                    case torch.float64:
                    {
                        data[i] = t.__double;
                        dtype   = torch.float64;
                        break;
                    }

                    case torch.@bool:
                    {
                        data[i] = t.__bool;
                        break;
                    }

                    default:
                    {
                        // Without this, data[i] would stay null and fail later inside MKL.Cat.
                        throw new TorchException(string.Format("TorchException: unsupported dtype {0} for torch.cat.", t.dtype));
                    }
                    }
                }

                var y = new Tensor(y_shape, dtype, (!torch.autograd.grad_mode.no_grad.prev) && requires_grad);

                switch (dtype)
                {
                case torch.float16:
                {
                    MKL.Cat(data, shapes, strides, dim, y.__half, y.__shape, y.__strides);
                    break;
                }

                case torch.float32:
                {
                    MKL.Cat(data, shapes, strides, dim, y.__float, y.__shape, y.__strides);
                    break;
                }

                case torch.float64:
                {
                    MKL.Cat(data, shapes, strides, dim, y.__double, y.__shape, y.__strides);
                    break;
                }

                case torch.int8:
                {
                    MKL.Cat(data, shapes, strides, dim, y.__int8, y.__shape, y.__strides);
                    break;
                }

                case torch.uint8:
                {
                    MKL.Cat(data, shapes, strides, dim, y.__uint8, y.__shape, y.__strides);
                    break;
                }

                case torch.int16:
                {
                    MKL.Cat(data, shapes, strides, dim, y.__int16, y.__shape, y.__strides);
                    break;
                }

                case torch.int32:
                {
                    MKL.Cat(data, shapes, strides, dim, y.__int32, y.__shape, y.__strides);
                    break;
                }

                case torch.int64:
                {
                    MKL.Cat(data, shapes, strides, dim, y.__int64, y.__shape, y.__strides);
                    break;
                }

                case torch.@bool:
                {
                    MKL.Cat(data, shapes, strides, dim, y.__bool, y.__shape, y.__strides);
                    break;
                }
                }
                return y;
            }