Esempio n. 1
0
        /// <summary>
        /// Builds the <see cref="Conv2DInfo"/> for a 2D pooling operation by expressing the
        /// pooling window as a convolution filter and delegating to computeConv2DInfo.
        /// </summary>
        /// <param name="inShape">Input shape. Index 3 is read as the channel count for channelsLast, index 1 for channelsFirst.</param>
        /// <param name="filterSize">Pooling window size: element 0 is the height, element 1 the width.</param>
        /// <param name="strides">Strides, forwarded unchanged to computeConv2DInfo.</param>
        /// <param name="pad">Padding type, forwarded unchanged to computeConv2DInfo.</param>
        /// <param name="roundingMode">Rounding mode, forwarded unchanged to computeConv2DInfo.</param>
        /// <param name="dataFormat">Layout of <paramref name="inShape"/>.</param>
        /// <param name="padValue">Optional pad value, forwarded unchanged to computeConv2DInfo.</param>
        /// <returns>The result of computeConv2DInfo called with a single dilation of 1 and depthwise set to false.</returns>
        /// <exception cref="Exception">Thrown when <paramref name="dataFormat"/> is neither channelsLast nor channelsFirst.</exception>
        public static Conv2DInfo computePool2DInfo(int[] inShape, int[] filterSize,
                                                   int[] strides, PadType pad, roundingMode roundingMode,
                                                   ConvDataFormat dataFormat = ConvDataFormat.channelsLast, Nullable <int> padValue = null)
        {
            int windowHeight = filterSize[0];
            int windowWidth  = filterSize[1];

            // The channel axis depends on the layout of the input.
            int channels;
            switch (dataFormat)
            {
                case ConvDataFormat.channelsLast:
                    channels = inShape[3];
                    break;

                case ConvDataFormat.channelsFirst:
                    channels = inShape[1];
                    break;

                default:
                    throw new Exception("Unknown dataFormat");
            }

            // The channel count is used for both trailing filter dimensions.
            int[] poolFilterShape = { windowHeight, windowWidth, channels, channels };
            int[] unitDilation    = { 1 };

            return computeConv2DInfo(
                inShape, poolFilterShape, strides, unitDilation, pad, roundingMode, false,
                dataFormat, padValue);
        }
Esempio n. 2
0
        /// <summary>
        /// Collects the sizes, strides, dilations, padding and output shape of a 2D
        /// convolution into a <see cref="Conv2DInfo"/>.
        /// </summary>
        /// <param name="inShape">
        /// Input shape with four entries, read as [batch, height, width, channels] for
        /// channelsLast and [batch, channels, height, width] for channelsFirst.
        /// </param>
        /// <param name="filterShape">
        /// Filter shape: element 0 is the filter height, element 1 the filter width and
        /// element 3 the filter channel count. Element 2 is not read here.
        /// </param>
        /// <param name="strides">Element 0 is the height stride, element 1 the width stride.</param>
        /// <param name="dilations">
        /// Element 0 is the height dilation. Element 1, when present, is the width
        /// dilation; otherwise element 0 is used for the width as well.
        /// </param>
        /// <param name="pad">Padding type, forwarded to getPadAndOutInfo.</param>
        /// <param name="roundingMode">Rounding mode, forwarded to getPadAndOutInfo.</param>
        /// <param name="depthwise">
        /// When true the output channel count is filter channels multiplied by input
        /// channels; otherwise it is the filter channel count.
        /// </param>
        /// <param name="dataFormat">Layout of <paramref name="inShape"/> and of the resulting output shape.</param>
        /// <param name="padValue">Optional pad value, forwarded to getPadAndOutInfo.</param>
        /// <returns>A populated <see cref="Conv2DInfo"/>; its outShape uses the same layout as the input.</returns>
        /// <exception cref="ArgumentException">
        /// Thrown when <paramref name="dataFormat"/> is neither channelsLast nor channelsFirst.
        /// </exception>
        public static Conv2DInfo computeConv2DInfo(int[] inShape, int[] filterShape,
                                                   int[] strides, int[] dilations, PadType pad,
                                                   roundingMode roundingMode = roundingMode.none, bool depthwise = false,
                                                   ConvDataFormat dataFormat = ConvDataFormat.channelsLast, Nullable <int> padValue = null)
        {
            bool isChannelsLast = dataFormat == ConvDataFormat.channelsLast;

            // Reject unrecognised layouts up front: without this check the input would be
            // read as channelsFirst and the result would carry a null outShape.
            if (!isChannelsLast && dataFormat != ConvDataFormat.channelsFirst)
            {
                throw new ArgumentException("Unknown dataFormat: " + dataFormat, "dataFormat");
            }

            int batchSize  = inShape[0];
            int inHeight   = isChannelsLast ? inShape[1] : inShape[2];
            int inWidth    = isChannelsLast ? inShape[2] : inShape[3];
            int inChannels = isChannelsLast ? inShape[3] : inShape[1];

            var filterHeight   = filterShape[0];
            var filterWidth    = filterShape[1];
            var filterChannels = filterShape[3];

            var strideHeight = strides[0];
            var strideWidth  = strides[1];

            // A single dilation value applies to both spatial dimensions.
            var dilationHeight = dilations[0];
            int dilationWidth  = dilations.Length > 1 ? dilations[1] : dilations[0];

            // Padding and output size are computed from the effective filter extents
            // returned by getEffectiveFilterSize, not from the raw filter size.
            var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
            var effectiveFilterWidth  = getEffectiveFilterSize(filterWidth, dilationWidth);

            var padAndOut = getPadAndOutInfo(
                pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight,
                effectiveFilterWidth, roundingMode, padValue);
            var padInfo     = padAndOut.Item1;
            var outHeight   = padAndOut.Item2;
            var outWidth    = padAndOut.Item3;
            var outChannels = depthwise ? filterChannels * inChannels : filterChannels;

            // The output keeps the layout of the input.
            int[] outShape = isChannelsLast
                ? new int[] { batchSize, outHeight, outWidth, outChannels }
                : new int[] { batchSize, outChannels, outHeight, outWidth };

            return new Conv2DInfo()
            {
                batchSize = batchSize,
                dataFormat = dataFormat,
                inHeight = inHeight,
                inWidth = inWidth,
                inChannels = inChannels,
                outHeight = outHeight,
                outWidth = outWidth,
                outChannels = outChannels,
                padInfo = padInfo,
                strideHeight = strideHeight,
                strideWidth = strideWidth,
                filterHeight = filterHeight,
                filterWidth = filterWidth,
                dilationHeight = dilationHeight,
                dilationWidth = dilationWidth,
                inShape = inShape,
                outShape = outShape,
                filterShape = filterShape
            };
        }
Esempio n. 3
0
        /// <summary>
        /// Add a bias to a tensor.
        /// </summary>
        /// <param name="x">The tensor to add the bias to.</param>
        /// <param name="bias">The bias to add to `x`. Must be 1D or the same rank as `x`.</param>
        /// <param name="dataFormat">
        /// Layout of `x`. A 1D bias is aligned with axis 1 for channelsFirst and with the
        /// last axis for channelsLast. Only consulted when `x` has rank 3, 4 or 5.
        /// </param>
        /// <returns> Result of the bias adding.</returns>
        /// <exception cref="Exception">
        /// Thrown when the rank of `bias` is neither 1 nor the rank of `x`, or when `x`
        /// has a rank above 5.
        /// </exception>
        /// <exception cref="ArgumentException">
        /// Thrown when `x` has rank 3, 4 or 5 and <paramref name="dataFormat"/> is neither
        /// channelsFirst nor channelsLast.
        /// </exception>
        public static Tensor biasAdd(this Tensor x, Tensor bias, ConvDataFormat dataFormat = ConvDataFormat.channelsLast)
        {
            var rank = x.Rank;

            if (bias.Rank != 1 && bias.Rank != rank)
            {
                throw new Exception(
                          "Unexpected bias dimensions: " + bias.Rank +
                          "; expected it to be 1 or " + rank);
            }

            // Low-rank inputs are added directly, without reshaping the bias.
            if (rank < 3)
            {
                return x.add(bias);
            }

            if (rank > 5)
            {
                throw new Exception("Unsupported input Rank by biasAdd: " + rank.ToString());
            }

            var   biasShape = bias.Shape;
            int[] target;

            // NOTE(review): the rank check above only admits a non-1D bias whose rank
            // equals the rank of x. For such a bias the channelsLast branch produces a
            // shape with one more dimension than x, and the channelsFirst branch never
            // reads the last bias dimension. Confirm this matches the intended contract.
            if (dataFormat == ConvDataFormat.channelsFirst)
            {
                target    = new int[rank];
                target[0] = 1;

                if (biasShape.Length == 1)
                {
                    // [1, C, 1, ..., 1]
                    target[1] = biasShape[0];
                    for (var axis = 2; axis < rank; axis++)
                    {
                        target[axis] = 1;
                    }
                }
                else
                {
                    // [1, b[rank - 2], b[0], ..., b[rank - 3]]
                    target[1] = biasShape[rank - 2];
                    for (var axis = 2; axis < rank; axis++)
                    {
                        target[axis] = biasShape[axis - 2];
                    }
                }
            }
            else if (dataFormat == ConvDataFormat.channelsLast)
            {
                var dims = new List <int>();

                if (biasShape.Length == 1)
                {
                    // [1, ..., 1, C]
                    for (var axis = 0; axis < rank - 1; axis++)
                    {
                        dims.Add(1);
                    }
                    dims.Add(biasShape[0]);
                }
                else
                {
                    // [1, b[0], ..., b[n - 1]]
                    dims.Add(1);
                    dims.AddRange(biasShape);
                }

                target = dims.ToArray();
            }
            else
            {
                // No branch handles this layout; fail instead of returning null.
                throw new ArgumentException("Unknown dataFormat: " + dataFormat, "dataFormat");
            }

            return x.add(bias.reshape(target));
        }