Exemplo n.º 1
0
        /// <summary>
        /// Verifies that <c>ConvUtils.Batch_Im2Col</c> unrolls a seeded random batch of
        /// 3x3x3 inputs (width x height x depth) into the expected column matrix for a
        /// 2x2 filter with stride 1, zero padding and <c>BorderMode.Valid</c>.
        /// </summary>
        public void ConvUtils_Batch_Im2Cols()
        {
            // Arrange
            var batchSize = 5;

            var filterHeight = 2;
            var filterWidth  = 2;

            var stride  = 1;
            var padding = 0;

            var inputWidth  = 3;
            var inputHeight = 3;
            var inputDepth  = 3;

            // One flattened input item per row. The fixed seed (last argument) keeps the
            // input reproducible so the hard-coded expected values below stay valid.
            var input = Matrix<float>.Build.Random(batchSize, inputWidth * inputHeight * inputDepth, 42);

            var filterGridWidth  = ConvUtils.GetFilterGridLength(inputWidth, filterWidth, stride, padding, BorderMode.Valid);
            var filterGridHeight = ConvUtils.GetFilterGridLength(inputHeight, filterHeight, stride, padding, BorderMode.Valid);

            // Rows: one per value in a filter window across all input channels.
            // Columns: one per filter position, for every item in the batch.
            var k = filterWidth * filterHeight * inputDepth;
            var n = batchSize * filterGridWidth * filterGridHeight;

            var actual = Matrix<float>.Build.Dense(k, n);

            // Act
            ConvUtils.Batch_Im2Col(input, inputDepth, inputHeight, inputWidth, filterHeight, filterWidth,
                                   padding, padding, stride, stride, BorderMode.Valid, actual);

            // Diagnostic output; the column-major dump has the same layout as the expected literal.
            Trace.WriteLine(actual.ToString());
            Trace.WriteLine(string.Join(",", actual.ToColumnMajorArray()));

            // Assert
            // Expected values are listed column by column, one column of k = 12 values per line.
            // They were presumably captured from the ToColumnMajorArray trace above.
            var expected = Matrix<float>.Build.Dense(k, n, new float[]
            {
                0.408388f, -0.5256838f, -1.416015f, -0.3205518f, 0.8964508f, -0.7706847f, 0.1228476f, 1.401819f, 0.02538049f, 0.4443011f, 0.3597376f, -0.8992839f,
                -0.5256838f, -0.8472909f, -0.3205518f, 0.168334f, -0.7706847f, -0.2688324f, 1.401819f, 0.5753565f, 0.4443011f, -0.8027026f, -0.8992839f, -0.6576554f,
                -1.416015f, -0.3205518f, 0.1622419f, -0.8718526f, 0.1228476f, 1.401819f, -0.8105127f, -1.366049f, 0.3597376f, -0.8992839f, -0.09693441f, 0.1117831f,
                -0.3205518f, 0.168334f, -0.8718526f, 2.464335f, 1.401819f, 0.5753565f, -1.366049f, 0.7328596f, -0.8992839f, -0.6576554f, 0.1117831f, -2.00572f,
                -0.8723587f, 1.785321f, 0.02021696f, -1.087396f, -0.7902505f, -0.06449615f, -0.4799407f, 0.7755837f, -0.08005979f, -0.163763f, 1.463557f, -0.5891034f,
                1.785321f, -0.7747191f, -1.087396f, 1.942754f, -0.06449615f, 0.08791012f, 0.7755837f, 1.559499f, -0.163763f, 1.144407f, -0.5891034f, 1.486937f,
                0.02021696f, -1.087396f, 1.386084f, -0.742821f, -0.4799407f, 0.7755837f, -0.93938f, 0.4403726f, 1.463557f, -0.5891034f, 0.2961742f, -1.676224f,
                -1.087396f, 1.942754f, -0.742821f, 0.3750592f, 0.7755837f, 1.559499f, 0.4403726f, 1.018316f, -0.5891034f, 1.486937f, -1.676224f, 0.5095494f,
                -1.069885f, 0.1028096f, -0.5383296f, -0.5273784f, -1.362978f, -2.817736f, -0.3506753f, -2.379571f, -0.205604f, -0.8553149f, 1.364009f, 1.960906f,
                0.1028096f, 0.06300805f, -0.5273784f, 0.1655738f, -2.817736f, -0.2654593f, -2.379571f, 0.3019102f, -0.8553149f, 0.380102f, 1.960906f, -1.644088f,
                -0.5383296f, -0.5273784f, 1.407161f, 0.8093351f, -0.3506753f, -2.379571f, -0.1132597f, 0.00849107f, 1.364009f, 1.960906f, -1.907569f, 1.585406f,
                -0.5273784f, 0.1655738f, 0.8093351f, -0.5961999f, -2.379571f, 0.3019102f, 0.00849107f, -0.9973568f, 1.960906f, -1.644088f, 1.585406f, 0.1513373f,
                0.06503697f, -0.6606446f, 1.281655f, 0.2639574f, -0.3281617f, 0.6252633f, -0.9870397f, -0.2739736f, 0.5706424f, -0.6933832f, -0.9226705f, 1.837471f,
                -0.6606446f, -2.021355f, 0.2639574f, -1.713513f, 0.6252633f, -0.6887951f, -0.2739736f, -0.1102718f, -0.6933832f, -0.2514778f, 1.837471f, 1.012506f,
                1.281655f, 0.2639574f, -0.6539868f, -1.332823f, -0.9870397f, -0.2739736f, -0.6845301f, 0.3220822f, -0.9226705f, 1.837471f, 2.257283f, -0.2592173f,
                0.2639574f, -1.713513f, -1.332823f, -0.1056926f, -0.2739736f, -0.1102718f, 0.3220822f, 0.02583288f, 1.837471f, 1.012506f, -0.2592173f, 0.5775524f,
                -0.734176f, 0.5288628f, 0.314957f, 1.331584f, 0.1659867f, -0.0002207408f, -0.3023876f, 0.5506561f, -1.365916f, -0.314546f, -0.6079422f, 0.3696074f,
                0.5288628f, -0.7030032f, 1.331584f, 0.7429405f, -0.0002207408f, -2.21279f, 0.5506561f, 0.5057944f, -0.314546f, -1.749763f, 0.3696074f, -0.1464183f,
                0.314957f, 1.331584f, 0.2864983f, 0.9384909f, -0.3023876f, 0.5506561f, 1.133461f, 1.134041f, -0.6079422f, 0.3696074f, 0.2236174f, -0.9724815f,
                1.331584f, 0.7429405f, 0.9384909f, 1.441582f, 0.5506561f, 0.5057944f, 1.134041f, 0.2430595f, 0.3696074f, -0.1464183f, -0.9724815f, 0.7229092f,
            });

            MatrixAsserts.AreEqual(expected, actual);
        }
Exemplo n.º 2
0
        /// <summary>
        /// Forward pass of the convolution, computed as a GEMM. The input batch is unrolled
        /// with im2col into <c>Im2Cols</c> and multiplied by <c>Weights</c> into <c>Conv</c>.
        /// <c>Bias</c> is then added, and the result is reshaped into <c>OutputActivations</c>.
        /// </summary>
        /// <param name="input">
        /// Input activations for the batch. The reference is kept in <c>m_inputActivations</c>,
        /// presumably for use by the backward pass.
        /// </param>
        /// <returns>
        /// The layer's <c>OutputActivations</c> matrix. It is filled in place and returned,
        /// not newly allocated.
        /// </returns>
        public Matrix<float> Forward(Matrix<float> input)
        {
            m_inputActivations = input;

            // Arrange the input items for the GEMM version of convolution.
            // Batch_Im2Col expects the filter size as (height, width), the same order as the
            // input (InputHeight, InputWidth) and padding (m_padHeight, m_padWidth) arguments.
            // Passing (FilterWidth, FilterHeight) gave wrong results for non-square filters.
            ConvUtils.Batch_Im2Col(m_inputActivations, InputDepth, InputHeight, InputWidth,
                                   FilterHeight, FilterWidth, m_padHeight, m_padWidth, m_stride, m_stride, BorderMode, Im2Cols);

            // Convolution as one matrix multiplication: Conv = Weights * Im2Cols, written into the
            // preallocated Conv. The bias is then added column-wise, in place.
            Weights.Multiply(Im2Cols, Conv);
            Conv.AddColumnWise(Bias, Conv);

            // Return the convolved data to row-major layout and copy it to the output.
            // NOTE(review): assumes ReshapeConvolutionsToRowMajor uses the same
            // (filter height, filter width) order as Batch_Im2Col; confirm against its signature.
            ConvUtils.ReshapeConvolutionsToRowMajor(Conv, InputDepth, InputHeight, InputWidth,
                                                    FilterHeight, FilterWidth, m_padHeight, m_padWidth, m_stride, m_stride, BorderMode, OutputActivations);

            return OutputActivations;
        }