コード例 #1
0
        /// <summary>
        /// Reads an <see cref="AveragePool2dLayerArgument"/> from the binary image starting at <paramref name="offset"/>.
        /// </summary>
        /// <param name="offset">Position passed to <c>context.GetReaderAt</c> where the serialized argument begins.</param>
        /// <param name="context">Deserialization context that supplies the reader.</param>
        /// <returns>A new argument instance populated from the stream.</returns>
        /// <remarks>
        /// Fields are read one after another in the order below. That order mirrors the serialized
        /// layout, so it must not be rearranged.
        /// </remarks>
        public AveragePool2dLayerArgument DeserializeBin(int offset, K210BinDeserializeContext context)
        {
            var reader = context.GetReaderAt(offset);
            var result = new AveragePool2dLayerArgument();

            // Flags, then the main-memory input/output addresses.
            result.Flags                   = reader.Read<K210LayerFlags>();
            result.MainMemoryInputAddress  = reader.Read<uint>();
            result.MainMemoryOutputAddress = reader.Read<uint>();

            // Input and output tensor extents.
            result.InputWidth     = reader.Read<uint>();
            result.InputHeight    = reader.Read<uint>();
            result.InputChannels  = reader.Read<uint>();
            result.OutputWidth    = reader.Read<uint>();
            result.OutputHeight   = reader.Read<uint>();
            result.OutputChannels = reader.Read<uint>();

            // Pooling window geometry: kernel size, stride, padding.
            result.KernelWidth   = reader.Read<uint>();
            result.KernelHeight  = reader.Read<uint>();
            result.StrideWidth   = reader.Read<uint>();
            result.StrideHeight  = reader.Read<uint>();
            result.PaddingWidth  = reader.Read<uint>();
            result.PaddingHeight = reader.Read<uint>();

            // Trailing activation selector.
            result.Activation = reader.Read<ActivationFunctionType>();

            return result;
        }
コード例 #2
0
        /// <summary>
        /// Fills in the memory-placement fields of <paramref name="argument"/> for <paramref name="layer"/>.
        /// </summary>
        /// <param name="layer">The average-pool layer whose input/output allocations are looked up.</param>
        /// <param name="argument">Layer argument updated in place (flags and main-memory addresses).</param>
        /// <param name="context">Inference context holding <c>MainMemoryMap</c>.</param>
        /// <remarks>
        /// Both allocations are looked up (input first, then output) before any field of
        /// <paramref name="argument"/> is written.
        /// </remarks>
        public void Infer(AveragePool2d layer, AveragePool2dLayerArgument argument, InferenceContext context)
        {
            // Allocation backing the tensor that feeds this layer's input connector.
            var sourceAllocation = context.MainMemoryMap[layer.Input.Connection.From];
            // Allocation backing this layer's own output tensor.
            var targetAllocation = context.MainMemoryMap[layer.Output];

            argument.Flags                   = K210LayerFlags.MainMemoryOutput;
            argument.MainMemoryInputAddress  = sourceAllocation.GetAddress();
            argument.MainMemoryOutputAddress = targetAllocation.GetAddress();
        }
コード例 #3
0
        /// <summary>
        /// CPU reference implementation of 2-D average pooling over float data held in main RAM.
        /// </summary>
        /// <param name="argument">Layer parameters: main-memory addresses, shapes, kernel, stride and padding.</param>
        /// <param name="context">Forward context giving access to main RAM.</param>
        /// <remarks>
        /// <para>
        /// The input is indexed as consecutive per-channel planes of <c>InputHeight * InputWidth</c> floats,
        /// row-major within a plane. The output is written densely in (channel, y, x) order.
        /// </para>
        /// <para>
        /// Only kernel taps that land inside the input are summed, and the sum is divided by the number of
        /// such taps (padded positions are not counted). A window with no taps inside the input is written
        /// as 0 instead of the NaN that 0f / 0f would produce.
        /// </para>
        /// <para>
        /// NOTE(review): <c>argument.Activation</c> is deserialized but not applied here; confirm whether a
        /// fused activation is expected for this layer.
        /// </para>
        /// </remarks>
        public void Forward(AveragePool2dLayerArgument argument, ForwardContext context)
        {
            var src  = MemoryMarshal.Cast<byte, float>(context.GetMainRamAt((int)argument.MainMemoryInputAddress));
            var dest = MemoryMarshal.Cast<byte, float>(context.GetMainRamAt((int)argument.MainMemoryOutputAddress));

            // Loop-invariant unsigned -> int conversions, hoisted out of the pooling loops.
            int inWidth      = (int)argument.InputWidth;
            int inHeight     = (int)argument.InputHeight;
            int outWidth     = (int)argument.OutputWidth;
            int outHeight    = (int)argument.OutputHeight;
            int outChannels  = (int)argument.OutputChannels;
            int kernelWidth  = (int)argument.KernelWidth;
            int kernelHeight = (int)argument.KernelHeight;
            int strideWidth  = (int)argument.StrideWidth;
            int strideHeight = (int)argument.StrideHeight;
            int padWidth     = (int)argument.PaddingWidth;
            int padHeight    = (int)argument.PaddingHeight;
            int planeSize    = inWidth * inHeight;

            int outIdx = 0;

            for (int oc = 0; oc < outChannels; oc++)
            {
                // Output channel oc reads input plane oc (pooling does not mix channels).
                var channelSrc = src.Slice(planeSize * oc);

                for (int oy = 0; oy < outHeight; oy++)
                {
                    // The vertical window bounds depend only on oy, so compute them once per row.
                    int inYOrigin    = oy * strideHeight - padHeight;
                    int kernelYStart = Math.Max(0, -inYOrigin);
                    int kernelYEnd   = Math.Min(kernelHeight, inHeight - inYOrigin);

                    for (int ox = 0; ox < outWidth; ox++)
                    {
                        // Clip the kernel window to the valid input region; padded taps are skipped.
                        int   inXOrigin    = ox * strideWidth - padWidth;
                        int   kernelXStart = Math.Max(0, -inXOrigin);
                        int   kernelXEnd   = Math.Min(kernelWidth, inWidth - inXOrigin);
                        float value        = 0;
                        float kernelCount  = 0;

                        for (int ky = kernelYStart; ky < kernelYEnd; ky++)
                        {
                            int rowOffset = (inYOrigin + ky) * inWidth;
                            for (int kx = kernelXStart; kx < kernelXEnd; kx++)
                            {
                                value += channelSrc[rowOffset + inXOrigin + kx];
                                kernelCount++;
                            }
                        }

                        // A window lying entirely in padding has no taps; avoid emitting 0f / 0f = NaN.
                        dest[outIdx++] = kernelCount > 0 ? value / kernelCount : 0f;
                    }
                }
            }
        }