Exemplo n.º 1
0
            /// <summary>
            /// Computes the arithmetic mean of <paramref name="x"/> over one or more dimensions.
            /// </summary>
            /// <param name="x">Input tensor. Only float16, float32 and float64 tensors are supported.</param>
            /// <param name="dim">
            /// A single dimension or an array of dimensions to reduce. Negative values are
            /// counted from the last dimension. The array passed by the caller is not modified.
            /// </param>
            /// <param name="keepdim">
            /// When true, every reduced dimension is kept in the result with size 1;
            /// when false, reduced dimensions are removed from the result shape.
            /// </param>
            /// <returns>A newly allocated tensor with the same dtype as <paramref name="x"/>.</returns>
            /// <exception cref="TorchException">
            /// Thrown for bool or integer tensors, for a dimension outside the tensor's rank,
            /// or when the same dimension is listed more than once.
            /// </exception>
            public static Tensor mean(this Tensor x, Union <int, int[]> dim, bool keepdim = false)
            {
                if (x.dtype == torch.@bool)
                {
                    throw new TorchException("TorchException: torch.mean is not implemented for bool tensors.");
                }
                // Reject integer dtypes before doing any work, so that no permutation or
                // output tensor is created just to be thrown away.
                switch (x.dtype)
                {
                    case torch.int8:
                    case torch.uint8:
                    case torch.int16:
                    case torch.int32:
                    case torch.int64:
                    {
                        throw new TorchException("TorchException: torch.mean is not implemented for integer tensors.");
                    }
                }

                // The union holds either an int[] or a single int; the array cast yields null
                // in the single-int case (this is how the original code discriminated, too).
                var dim_array = (int[])dim;
                int[] requested = (dim_array != null) ? dim_array : new int[] { (int)dim };

                // Normalise and validate the requested dimensions into a per-axis mask.
                // Working on a mask (instead of rewriting the array) leaves the caller's
                // array untouched and makes duplicates trivial to detect.
                int ndim          = x.__shape.Length;
                var is_reduced    = new bool[ndim];
                int reduced_count = 0;
                foreach (int d in requested)
                {
                    if (ndim == 0 && (d == 0 || d == -1))
                    {
                        // A 0-d tensor has no axis to reduce over; dim 0 / -1 is accepted
                        // (as PyTorch does) and simply contributes nothing.
                        continue;
                    }
                    int axis = (d >= 0) ? d : (ndim + d);
                    if (axis < 0 || axis >= ndim)
                    {
                        throw new TorchException("TorchException: torch.mean: dimension " + d + " is out of range for a tensor with " + ndim + " dimension(s).");
                    }
                    if (is_reduced[axis])
                    {
                        throw new TorchException("TorchException: torch.mean: dimension " + d + " appears more than once in dim.");
                    }
                    is_reduced[axis] = true;
                    reduced_count++;
                }

                // Build a permutation that moves the kept axes to the front and the reduced
                // axes to the back (both in their original relative order), and the output shape.
                int kept_count        = ndim - reduced_count;
                var shape_permutation = new int[ndim];
                var y_shape           = new int[keepdim ? ndim : kept_count];
                int reduce_size       = 1;
                int kept_pos          = 0;
                int reduced_pos       = kept_count;
                int y_pos             = 0;
                for (int i = 0; i < ndim; i++)
                {
                    if (is_reduced[i])
                    {
                        reduce_size *= x.__shape[i];
                        if (keepdim)
                        {
                            y_shape[y_pos++] = 1;
                        }
                        shape_permutation[reduced_pos++] = i;
                    }
                    else
                    {
                        shape_permutation[kept_pos++] = i;
                        y_shape[y_pos++]              = x.__shape[i];
                    }
                }

                var xp = x.permute(shape_permutation);
                // The result tracks gradients only when grad mode is enabled and x requires grad.
                var y  = new Tensor(y_shape, x.dtype, (!torch.autograd.grad_mode.no_grad.prev) && x.requires_grad);

                // NOTE(review): the forward pass reads the permuted tensor xp, while the backward
                // pass writes into x.grad directly. Whether that is correct when the reduced
                // dimensions are not the trailing ones depends on MKL.dMean — confirm.
                switch (x.dtype)
                {
                    case torch.float16:
                    {
                        MKL.Mean(xp.__half, reduce_size, y.__half);
                        if (y.requires_grad)
                        {
                            y.__backward_fn = () =>
                            {
                                MKL.dMean(x.grad.__half, reduce_size, y.grad.__half);
                                // Continue the backward chain through x, if it has one.
                                if (x.__backward_fn != null)
                                {
                                    x.__backward_fn();
                                }
                            };
                        }
                        break;
                    }

                    case torch.float32:
                    {
                        MKL.Mean(xp.__float, reduce_size, y.__float);
                        if (y.requires_grad)
                        {
                            y.__backward_fn = () =>
                            {
                                MKL.dMean(x.grad.__float, reduce_size, y.grad.__float);
                                if (x.__backward_fn != null)
                                {
                                    x.__backward_fn();
                                }
                            };
                        }
                        break;
                    }

                    case torch.float64:
                    {
                        MKL.Mean(xp.__double, reduce_size, y.__double);
                        if (y.requires_grad)
                        {
                            y.__backward_fn = () =>
                            {
                                MKL.dMean(x.grad.__double, reduce_size, y.grad.__double);
                                if (x.__backward_fn != null)
                                {
                                    x.__backward_fn();
                                }
                            };
                        }
                        break;
                    }
                }
                return(y);
            }