Exemplo n.º 1
0
        /// <summary>
        /// Verifies that <c>ConvUtils.ReshapeRowMajorToConvolutionLayout</c> turns a
        /// (batchSize x k*gridHeight*gridWidth) row-major matrix into the
        /// (k x batchSize*gridHeight*gridWidth) convolution-layout matrix.
        /// </summary>
        public void ConvUtils_ReshapeRowMajorToConvolutionLayout()
        {
            var batchSize = 5;

            var filterHeight = 2;
            var filterWidth  = 2;
            var filterDepth  = 2;

            var stride  = 1;
            var padding = 0;

            var inputWidth  = 3;
            var inputHeight = 3;
            var inputDepth  = 3;

            // With a 3x3 input, a 2x2 filter, stride 1 and no padding (Valid), the grid is
            // expected to be 2x2. The 40-element fixtures below rely on that
            // (5 x 2*2*2 and 2 x 5*2*2).
            var filterGridWidth  = ConvUtils.GetFilterGridLength(inputWidth, filterWidth, stride, padding, BorderMode.Valid);
            var filterGridHeight = ConvUtils.GetFilterGridLength(inputHeight, filterHeight, stride, padding, BorderMode.Valid);

            // k = number of output channels (rows of the convolution layout).
            var k   = filterDepth;
            // npq = one column per (sample, grid position) pair.
            var npq = batchSize * filterGridWidth * filterGridHeight;

            // The storage array is column-major (MathNet Build.Dense(rows, cols, data) convention).
            // Each line below is one column, so every sample row holds the same 8 values:
            // channel 0 over the 4 grid positions, then channel 1 over the 4 grid positions.
            var rowMajor = Matrix<float>.Build.Dense(batchSize, k * filterGridWidth * filterGridHeight, new float[]
            {
                -6.260461f, -6.260461f, -6.260461f, -6.260461f, -6.260461f,
                -7.173417f, -7.173417f, -7.173417f, -7.173417f, -7.173417f,
                -8.999331f, -8.999331f, -8.999331f, -8.999331f, -8.999331f,
                -9.912288f, -9.912288f, -9.912288f, -9.912288f, -9.912288f,
                87.38299f, 87.38299f, 87.38299f, 87.38299f, 87.38299f,
                94.47046f, 94.47046f, 94.47046f, 94.47046f, 94.47046f,
                108.6454f, 108.6454f, 108.6454f, 108.6454f, 108.6454f,
                115.7329f, 115.7329f, 115.7329f, 115.7329f, 115.7329f,
            });

            var actual = Matrix<float>.Build.Dense(k, npq);

            ConvUtils.ReshapeRowMajorToConvolutionLayout(rowMajor, inputDepth, inputHeight, inputWidth, filterHeight, filterWidth,
                                                         padding, padding, stride, stride, BorderMode.Valid, actual);

            // Also column-major: each pair is one column (channel 0 value, channel 1 value).
            // Each line covers the 4 grid positions of one sample, repeated for all 5 samples.
            var expected = Matrix<float>.Build.Dense(k, npq, new float[]
            {
                -6.260461f, 87.38299f, -7.173417f, 94.47046f, -8.999331f, 108.6454f, -9.912288f, 115.7329f,
                -6.260461f, 87.38299f, -7.173417f, 94.47046f, -8.999331f, 108.6454f, -9.912288f, 115.7329f,
                -6.260461f, 87.38299f, -7.173417f, 94.47046f, -8.999331f, 108.6454f, -9.912288f, 115.7329f,
                -6.260461f, 87.38299f, -7.173417f, 94.47046f, -8.999331f, 108.6454f, -9.912288f, 115.7329f,
                -6.260461f, 87.38299f, -7.173417f, 94.47046f, -8.999331f, 108.6454f, -9.912288f, 115.7329f,
            });

            MatrixAsserts.AreEqual(expected, actual);
        }