/// <summary>
        /// Forward pass of the 3D pooling layer: derives the pooled output extents from the
        /// input shape, reduces <c>xCols</c> with <c>Ops.ArgMax</c> and stores the reshaped,
        /// transposed result in <c>Output</c>.
        /// </summary>
        /// <param name="x">Input array; its shape is read via <c>GetConv3DShape()</c> as (n, c, d, h, w).</param>
        public override void Forward(SuperArray x)
        {
            base.Forward(x);
            var (n, c, d, h, w) = x.GetConv3DShape();

            // Padding amount per mode: Same -> 1, Full -> 2, anything else -> 0.
            // NOTE(review): pad is only consumed by the disabled Im2Col call below and is
            // not part of the output-size formula; revisit both when Im2Col is re-enabled.
            int pad = 0;

            if (Padding == PaddingType.Same)
            {
                pad = 1;
            }
            else if (Padding == PaddingType.Full)
            {
                pad = 2;
            }

            // Number of window positions along each spatial axis (integer division).
            var d_out = (d - PoolSize.Item1) / Strides + 1;
            var h_out = (h - PoolSize.Item2) / Strides + 1;
            var w_out = (w - PoolSize.Item3) / Strides + 1;

            // Merge the batch and channel axes so every channel is treated as its own
            // single-channel volume. Currently only referenced by the disabled Im2Col call.
            var x_reshaped = x.Reshape(n * c, 1, d, h, w);

            // NOTE(review): the Im2Col step is disabled, so xCols is not assigned in this
            // method; whatever value the member currently holds is what gets reduced below.
            //xCols = ImUtil.Im2Col(x_reshaped, PoolSize, pad, Strides);

            // NOTE(review): ArgMax presumably yields indices rather than the pooled maxima —
            // confirm this is the intended reduction.
            Output = Ops.ArgMax(xCols);

            // Bring the (d_out, h_out, w_out, n, c) layout to (n, c, d_out, h_out, w_out).
            // The previous permutation (2, 3, 4, 0, 1) would produce (w_out, n, c, d_out, h_out).
            // NOTE(review): assumes Transpose(axes) puts source axis axes[i] at output
            // position i (NumPy-style) — confirm against the array library.
            Output = Output.Reshape(d_out, h_out, w_out, n, c).Transpose(3, 4, 0, 1, 2);
        }
        // Example #2
        // 0
        /// <summary>
        /// Forward pass of the 3D convolution layer: obtains the kernel (and, when
        /// <c>UseBias</c> is set, the bias) through <c>BuildParam</c>, multiplies the
        /// flattened kernel with <c>xCols</c>, adds the bias and stores the reshaped,
        /// transposed result in <c>Output</c>.
        /// </summary>
        /// <param name="x">Input array; its shape is read via <c>GetConv3DShape()</c> as (n, c, d, h, w).</param>
        public override void Forward(SuperArray x)
        {
            base.Forward(x);
            var (n, c, d, h, w) = x.GetConv3DShape();

            // Kernel shape is (Filters, c, k1, k2, k3). The third spatial extent must be
            // Item3 — the width computation below uses Item3 as well; using Item2 twice
            // gave the parameter the wrong shape for non-cubic kernels.
            Parameter weight = BuildParam("w", new Shape(Filters, c, KernalSize.Item1, KernalSize.Item2, KernalSize.Item3), KernalInitializer, KernalConstraint, KernalRegularizer);
            Parameter bias   = null;

            if (UseBias)
            {
                bias = BuildParam("b", new Shape(Filters, 1), BiasInitializer, BiasConstraint, BiasRegularizer);
            }

            // Padding amount per mode: Same -> 1, Full -> 2, anything else -> 0.
            // NOTE(review): with this formula a pad of 1 keeps the spatial size only for
            // kernel extent 3 at stride 1 — confirm that fixed amounts are intended.
            int pad = 0;

            if (Padding == PaddingType.Same)
            {
                pad = 1;
            }
            else if (Padding == PaddingType.Full)
            {
                pad = 2;
            }

            // Number of kernel positions along each spatial axis (integer division).
            var d_out = (d - KernalSize.Item1 + 2 * pad) / Strides + 1;
            var h_out = (h - KernalSize.Item2 + 2 * pad) / Strides + 1;
            var w_out = (w - KernalSize.Item3 + 2 * pad) / Strides + 1;

            // NOTE(review): the Im2Col step is disabled, so xCols is not assigned in this
            // method; whatever value the member currently holds is what gets multiplied below.
            //xCols = ImUtil.Im2Col(x, KernalSize, pad, Strides);

            // Flatten each filter into one row so the convolution becomes a matrix product.
            var wRows = weight.Data.Reshape(Filters, -1);

            Output = Ops.Dot(wRows, xCols);
            if (UseBias)
            {
                Output = Output + bias.Data;
            }

            // Reshape to (Filters, d_out, h_out, w_out, n), then move the batch axis to the
            // front — presumably yielding (n, Filters, d_out, h_out, w_out); confirm the
            // axis convention of Transpose.
            Output = Output.Reshape(Filters, d_out, h_out, w_out, n).Transpose(4, 0, 1, 2, 3);
        }