예제 #1
0
        /// <summary>
        /// Pads (or, for negative entries, crops) <paramref name="X"/> by dispatching the
        /// padding compute kernel selected for <paramref name="kernelName"/>.
        /// </summary>
        /// <param name="X">Input tensor.</param>
        /// <param name="pad">
        /// Exactly four border amounts; a negative entry crops instead of pads.
        /// NOTE(review): presumably ordered [left, top, right, bottom]. Entries 2 and 3 are
        /// applied to width and height below; confirm against <c>TensorShape.ApplyBorder</c>.
        /// </param>
        /// <param name="kernelName">
        /// Name of the padding kernel. Only <c>"Border2D"</c> receives the extra
        /// <c>_Pool</c> / <c>_Beta</c> uniforms.
        /// </param>
        /// <param name="constant">
        /// Fill value. It is only forwarded to the shader (as <c>_Beta</c>) when
        /// <paramref name="kernelName"/> is <c>"Border2D"</c>.
        /// </param>
        /// <returns>A new tensor whose shape is <c>X.shape.ApplyBorder(pad)</c>.</returns>
        protected override Tensor ApplyPadding(Tensor X, int[] pad, string kernelName, float constant = 0.0f)
        {
            // Expected value goes first so an assertion failure reports the values the right way round.
            Assert.AreEqual(4, pad.Length, "ApplyPadding expects exactly 4 pad values");

            var O  = NewTensor(X.shape.ApplyBorder(pad));
            var fn = BestKernel(ComputeKernelLibrary.Padding(X.shape, O.shape, kernelName));

            fn.SetTensor("X", X.shape, Pin(X).buffer);
            fn.SetTensor("O", O.shape, Pin(O).buffer);

            fn.shader.SetInts("_Pad", pad);

            if (kernelName == "Border2D")
            {
                // NOTE: negative "pad" variable will crop X tensor. The width/height handed to
                // the shader as _Pool are X's extents reduced by the magnitude of any negative
                // pad[2] / pad[3]; positive values leave them unchanged.
                int croppedWidth  = X.width  - Math.Max(0, -pad[2]);
                int croppedHeight = X.height - Math.Max(0, -pad[3]);

                fn.shader.SetInts("_Pool", new[] { croppedWidth, croppedHeight });
                fn.shader.SetFloat("_Beta", constant);
            }

            fn.Dispatch();
            return O;
        }