Example #1
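                    // 2D convolution over a 4D input tensor (N, C, H, W), mirroring torch.nn.functional.conv2d.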
                    public static Tensor conv2d(Tensor x, Tensor weight, Tensor bias = null, Union <int, Tuple <int, int> > stride = null, Union <int, Tuple <int, int> > padding = null, Union <int, Tuple <int, int> > dilation = null, int groups = 1)
                    {
                        if (x.dtype == torch.@bool)
                        {
                            throw new TorchException("TorchException: nn.functional.conv2d is not implemented for bool tensors.");
                        }
                        if (x.__shape.Length != 4)
                        {
                            throw new TorchException("TorchException: nn.functional.conv2d requires 4D input, but " + x.__shape.Length.ToString() + "D given.");
                        }
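                        // Normalize the Union-typed arguments to (height, width) pairs: a single int is
                        // broadcast to both dimensions, and PyTorch's defaults are used when null is passed
                        // (stride = 1, padding = 0, dilation = 1).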
                        Tuple <int, int> stride__;

                        if (stride == null)
                        {
                            stride__ = new Tuple <int, int>(1, 1);
                        }
                        else
                        {
                            if ((Tuple <int, int>)stride != null)
                            {
                                stride__ = (Tuple <int, int>)stride;
                            }
                            else
                            {
                                stride__ = new Tuple <int, int>((int)stride, (int)stride);
                            }
                        }
                        Tuple <int, int> padding__;

                        if (padding == null)
                        {
                            padding__ = new Tuple <int, int>(0, 0);
                        }
                        else
                        {
                            if ((Tuple <int, int>)padding != null)
                            {
                                padding__ = (Tuple <int, int>)padding;
                            }
                            else
                            {
                                padding__ = new Tuple <int, int>((int)padding, (int)padding);
                            }
                        }
                        Tuple <int, int> dilation__;

                        if (dilation == null)
                        {
                            dilation__ = new Tuple <int, int>(1, 1);
                        }
                        else
                        {
                            if ((Tuple <int, int>)dilation != null)
                            {
                                dilation__ = (Tuple <int, int>)dilation;
                            }
                            else
                            {
                                dilation__ = new Tuple <int, int>((int)dilation, (int)dilation);
                            }
                        }
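                        // Validate the normalized hyperparameters.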
                        if ((stride__.Item1 < 1) || (stride__.Item2 < 1))
                        {
                            throw new TorchException("TorchException: stride should be >= 1.");
                        }
                        if ((padding__.Item1 < 0) || (padding__.Item2 < 0))
                        {
                            throw new TorchException("TorchException: padding should be >= 0.");
                        }
                        if ((dilation__.Item1 < 1) || (dilation__.Item2 < 1))
                        {
                            throw new TorchException("TorchException: dilation should be >= 1.");
                        }
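                        // Unpack the input shape, padding, dilation, stride and kernel size, then compute the
                        // output spatial size with the standard convolution formula:
                        // dst = (src + padBeg + padEnd - (dilation * (kernel - 1) + 1)) / stride + 1.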
                        int srcB      = x.__shape[0];
                        int srcC      = x.__shape[1];
                        int srcH      = x.__shape[2];
                        int srcW      = x.__shape[3];
                        int padX      = padding__.Item2;
                        int padY      = padding__.Item1;
                        int padH      = padding__.Item1;
                        int padW      = padding__.Item2;
                        int dilationX = dilation__.Item2;
                        int dilationY = dilation__.Item1;
                        int strideX   = stride__.Item2;
                        int strideY   = stride__.Item1;
                        int kernelX   = weight.__shape[3];
                        int kernelY   = weight.__shape[2];
                        int dstC      = weight.__shape[0];
                        int dstH      = (srcH + padY + padH - (dilationY * (kernelY - 1) + 1)) / strideY + 1;
                        int dstW      = (srcW + padX + padW - (dilationX * (kernelX - 1) + 1)) / strideX + 1;

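                        // Dispatch on dtype; only the float32 input / float32 weight path is shown in this listing.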
                        switch (x.dtype)
                        {
                        case torch.float32:
                        {
                            switch (weight.dtype)
                            {
                            case torch.float32:
                            {
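                                // Allocate the output tensor (tracking gradients unless no_grad is active)
                                // and run the MKL forward convolution kernel.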
                                var y = new Tensor(new int[] { srcB, dstC, dstH, dstW }, torch.float32, (!torch.autograd.grad_mode.no_grad.prev) && (x.requires_grad || weight.requires_grad));
                                MKL.Conv2d(x.__float, srcB, srcC, srcH, srcW, kernelY, kernelX, dilationY, dilationX, strideY, strideX, padY, padX, padH, padW, groups, weight.__float, y.__float, dstC, dstH, dstW);
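                                // Broadcast the bias over batch and spatial dimensions by reshaping it to (1, outC, 1, 1).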
                                if ((object)bias != null)
                                {
                                    y = y + torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(bias, 0), 2), 3);
                                }
                                if (y.requires_grad)
                                {
                                    y.__backward_fn = () =>
                                    {
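                                        // Backward pass: MKL.dConv2d accumulates the gradients w.r.t. the input and/or
                                        // the weight (null is passed for a gradient buffer that is not needed), after
                                        // which the upstream backward functions are chained.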
                                        if (x.requires_grad && (!weight.requires_grad))
                                        {
                                            MKL.dConv2d(x.__float,
                                                        x.grad.__float,
                                                        srcB,
                                                        srcC,
                                                        srcH,
                                                        srcW,
                                                        kernelY,
                                                        kernelX,
                                                        dilationY,
                                                        dilationX,
                                                        strideY,
                                                        strideX,
                                                        padY,
                                                        padX,
                                                        padH,
                                                        padW,
                                                        groups,
                                                        weight.__float,
                                                        null,
                                                        y.grad.__float,
                                                        dstC,
                                                        dstH,
                                                        dstW);
                                            if (x.__backward_fn != null)
                                            {
                                                x.__backward_fn();
                                            }
                                            return;
                                        }
                                        if (x.requires_grad && weight.requires_grad)
                                        {
                                            MKL.dConv2d(x.__float,
                                                        x.grad.__float,
                                                        srcB,
                                                        srcC,
                                                        srcH,
                                                        srcW,
                                                        kernelY,
                                                        kernelX,
                                                        dilationY,
                                                        dilationX,
                                                        strideY,
                                                        strideX,
                                                        padY,
                                                        padX,
                                                        padH,
                                                        padW,
                                                        groups,
                                                        weight.__float,
                                                        weight.grad.__float,
                                                        y.grad.__float,
                                                        dstC,
                                                        dstH,
                                                        dstW);
                                            if (x.__backward_fn != null)
                                            {
                                                x.__backward_fn();
                                            }
                                            if (weight.__backward_fn != null)
                                            {
                                                weight.__backward_fn();
                                            }
                                            return;
                                        }
                                        MKL.dConv2d(x.__float,
                                                    null,
                                                    srcB,
                                                    srcC,
                                                    srcH,
                                                    srcW,
                                                    kernelY,
                                                    kernelX,
                                                    dilationY,
                                                    dilationX,
                                                    strideY,
                                                    strideX,
                                                    padY,
                                                    padX,
                                                    padH,
                                                    padW,
                                                    groups,
                                                    weight.__float,
                                                    weight.grad.__float,
                                                    y.grad.__float,
                                                    dstC,
                                                    dstH,
                                                    dstW);
                                        if (x.__backward_fn != null)
                                        {
                                            x.__backward_fn();
                                        }
                                        if (weight.__backward_fn != null)
                                        {
                                            weight.__backward_fn();
                                        }
                                    };
                                }
                                return(y);
                            }
                            }
                            break;
                        }
                        }
                        return(null);
                    }
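
A minimal call sketch for reference, written as if it sits inside the same class that declares conv2d so the method can be called unqualified. The Tensor constructor is the one used above to allocate the output; passing plain ints for stride and padding assumes that Union<int, Tuple<int, int>> provides an implicit conversion from int, which is not shown in this listing, and the tensors are only allocated here, not filled with data.

                    Tensor x = new Tensor(new int[] { 1, 3, 32, 32 }, torch.float32, true);  // input  (N, C, H, W)
                    Tensor w = new Tensor(new int[] { 8, 3, 3, 3 }, torch.float32, true);    // weight (outC, inC / groups, kH, kW)
                    Tensor b = new Tensor(new int[] { 8 }, torch.float32, false);            // bias   (outC)
                    Tensor y = conv2d(x, w, b, stride: 1, padding: 1);                       // output (1, 8, 32, 32) for a 3x3 kernel

With a 3x3 kernel, stride 1, padding 1 and dilation 1, the output formula above gives dstH = (32 + 1 + 1 - 3) / 1 + 1 = 32, so the spatial size is preserved.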