Ejemplo n.º 1
0
        /// <inheritdoc/>
        public override Tensor Concat(Tensor[] tensors, int axis)
        {
            // Only all-4D inputs are handled here; any non-4D input defers to the base implementation.
            bool hasNon4DInput = tensors.Any(t => !t.shape.Is4D());
            if (hasNon4DInput)
            {
                return base.Concat(tensors, axis);
            }

            var outputShape  = TensorExtensions.Concat(tensors, axis);
            int resolvedAxis = outputShape.Axis(axis);
            var axis4D       = TensorExtensions.Convert8DAxisTo4D(resolvedAxis);

            // NOTE(review): this material is created per call and never explicitly destroyed here - confirm ownership.
            var copyMaterial = new Material(PixelShaderSingleton.Instance.FindShader("Barracuda/Copy"));

            // Two output-sized tensors: every pass reads the previous result from one and writes into the other.
            var bufferA = NewTensor(outputShape, AllocScope.LayerOutput, "O");
            var bufferB = NewTensor(outputShape, AllocScope.LayerOutput, "O");

            Vector4 writeOffset = Vector4.zero;
            bool    readFromA   = true;
            bool    firstPass   = true;

            foreach (var input in tensors)
            {
                Assert.IsTrue(input.shape.Is4D());

                var previous = readFromA ? bufferA : bufferB;
                var target   = readFromA ? bufferB : bufferA;

                SetTensor(copyMaterial, "X", input);
                SetTensor(copyMaterial, "OPred", previous);
                copyMaterial.SetVector("_Pad", writeOffset);
                copyMaterial.SetInt("_IsFirstPass", firstPass ? 1 : 0);

                var pinnedTarget = Pin(target);
                copyMaterial.SetVector("OdeclShape", new Vector4(bufferA.batch, bufferA.height, bufferA.width, bufferA.channels));

                Graphics.Blit(null, pinnedTarget.bufferAsTexture, copyMaterial);

                // Advance the offset along the concatenation axis by the extent of the input just processed.
                writeOffset[axis4D] += input.shape[resolvedAxis];

                firstPass = false;
                readFromA = !readFromA;
            }

            // After the final swap, the buffer that would be read next is the one written last.
            return readFromA ? bufferA : bufferB;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Walks the layers of <paramref name="model"/> in order and infers the output shape of each one,
        /// starting from the shapes supplied in <paramref name="inputShapes"/>.
        /// </summary>
        /// <param name="model">Model whose <c>layers</c> are visited in order.</param>
        /// <param name="inputShapes">Known shapes keyed by tensor name; they are copied into <paramref name="shapesByName"/> first.</param>
        /// <param name="shapesByName">
        /// Receives the input shapes plus one entry per processed layer, keyed by layer name.
        /// A <c>null</c> value means the shape could not be determined.
        /// </param>
        /// <returns>
        /// The inferred shape of every processed layer in layer order (<c>null</c> where unknown).
        /// The array can be shorter than the layer list: Reshape and Gather stop the walk when a shape they need is unknown.
        /// </returns>
        /// <exception cref="NotImplementedException">
        /// A layer with a known input shape is a 3D operation, Squeeze or Unsqueeze.
        /// </exception>
        public static TensorShape?[] ListTemporaryTensorShapes(Model model, IDictionary <string, TensorShape> inputShapes,
                                                               out IDictionary <string, TensorShape?> shapesByName)
        {
            Profiler.BeginSample("Barracuda.ListTemporaryTensorShapes");
            try
            {
                var shapes = new List <TensorShape?>();

                shapesByName = new Dictionary <string, TensorShape?>();
                foreach (var entry in inputShapes)
                {
                    shapesByName.Add(entry.Key, entry.Value);
                }

                TensorShape? Xn;

                shapesByName.TryGetValue(GetDefaultInputName(model), out Xn); // default input
                TensorShape? O = Xn;

                foreach (var l in model.layers)
                {
                    TensorShape? firstInputShape;
                    if (l.inputs.Length > 0 && shapesByName.TryGetValue(l.inputs[0], out firstInputShape))
                    {
                        Xn = firstInputShape;
                    }
                    else
                    {
                        Xn = O; // previous output is used when the first input is absent or has no recorded shape
                    }

                    if (Xn == null)
                    {
                        // Unknown input shape: record the layer as unknown and move on.
                        shapes.Add(Xn);
                        shapesByName.Add(l.name, Xn);
                        continue;
                    }

                    TensorShape X = Xn.Value;

                    if (l.type == Layer.Type.Dense)
                    {
                        Assert.IsNotNull(l.datasets);
                        var W = l.datasets[0].shape;
                        O = new TensorShape(X.flatHeight, W.flatWidth);
                    }
                    else if (
                        l.type == Layer.Type.Conv2D ||
                        l.type == Layer.Type.DepthwiseConv2D)
                    {
                        var K = l.datasets[0].shape;

                        Assert.IsNotNull(l.stride);
                        Assert.IsNotNull(l.pad);
                        var pad = X.AdjustPadToKernel(K, l.stride, l.pad);

                        O = X.ApplyKernel(K, l.stride, pad);
                    }
                    else if (
                        l.type == Layer.Type.Conv2DTrans)
                    {
                        var K = l.datasets[0].shape;
                        Assert.IsNotNull(l.stride);
                        Assert.IsNotNull(l.pad);
                        // pool size is treated as output_adjustment aka output_padding here
                        var outputAdjustment = l.pool;
                        var pad = X.AdjustPadToKernel(K, l.stride, l.pad);
                        O = X.ApplyKernelInverse(K, l.stride, pad, outputAdjustment);
                    }
                    else if (
                        l.type == Layer.Type.Upsample2D)
                    {
                        if (inputShapes.Count > 1)
                        {
                            O = null;
                        }
                        else
                        {
                            // pool size is treated as upsample coefficient here
                            Assert.IsNotNull(l.pool);
                            Assert.AreEqual(l.pool.Length, 2);
                            O = new TensorShape(X.batch, X.height * l.pool[1], X.width * l.pool[0], X.channels);
                        }
                    }
                    else if (
                        l.type == Layer.Type.Resample2D)
                    {
                        if (inputShapes.Count > 1)
                        {
                            O = null;
                        }
                        else
                        {
                            // pool is treated as resample size here
                            var size = l.pool;
                            Assert.IsNotNull(size);
                            Assert.AreEqual(size.Length, 2);
                            O = new TensorShape(X.batch, size[1], size[0], X.channels);
                        }
                    }
                    else if (
                        l.type == Layer.Type.DepthToSpace)
                    {
                        // pool size is treated as blocksize here
                        Assert.IsNotNull(l.pool);
                        Assert.AreEqual(l.pool.Length, 2);
                        Assert.AreEqual(X.channels % (l.pool[0] * l.pool[1]), 0);
                        O = new TensorShape(X.batch, X.height * l.pool[1], X.width * l.pool[0], X.channels / (l.pool[0] * l.pool[1]));
                    }
                    else if (
                        l.type == Layer.Type.SpaceToDepth)
                    {
                        // pool size is treated as blocksize here
                        Assert.IsNotNull(l.pool);
                        Assert.AreEqual(l.pool.Length, 2);
                        O = new TensorShape(X.batch, X.height / l.pool[1], X.width / l.pool[0], X.channels * (l.pool[0] * l.pool[1]));
                    }
                    else if (
                        l.type == Layer.Type.MaxPool2D ||
                        l.type == Layer.Type.AvgPool2D)
                    {
                        Assert.IsNotNull(l.pool);
                        Assert.IsNotNull(l.stride);
                        Assert.IsNotNull(l.pad);
                        var pad = X.AdjustPadToPool(l.pool, l.stride, l.pad);
                        O = X.ApplyPool(l.pool, l.stride, pad);
                    }
                    else if (
                        l.type == Layer.Type.GlobalMaxPool2D ||
                        l.type == Layer.Type.GlobalAvgPool2D)
                    {
                        O = new TensorShape(X.batch, 1, 1, X.channels);
                    }
                    else if (
                        l.type == Layer.Type.Border2D ||
                        l.type == Layer.Type.Pad2DReflect ||
                        l.type == Layer.Type.Pad2DSymmetric ||
                        l.type == Layer.Type.Pad2DEdge)
                    {
                        Assert.IsNotNull(l.pad);
                        O = X.ApplyBorder(l.pad);
                    }
                    else if (
                        l.type == Layer.Type.Conv3D ||
                        l.type == Layer.Type.Conv3DTrans ||
                        l.type == Layer.Type.Upsample3D ||
                        l.type == Layer.Type.MaxPool3D ||
                        l.type == Layer.Type.AvgPool3D ||
                        l.type == Layer.Type.GlobalMaxPool3D ||
                        l.type == Layer.Type.GlobalAvgPool3D ||
                        l.type == Layer.Type.Border3D)
                    {
                        // Single 3D guard; an identical, unreachable copy further down the chain was removed.
                        throw new NotImplementedException("3D operations are not implemented yet!");
                    }
                    else if (
                        l.type == Layer.Type.RandomNormal ||
                        l.type == Layer.Type.RandomUniform)
                    {
                        Assert.IsNotNull(l.pool);
                        // pool size is treated as shape constant, if not empty
                        // otherwise shape of the previous tensor is used
                        if (l.pool.Length > 0)
                        {
                            O = new TensorShape(l.pool);
                        }
                        else
                        {
                            O = X;
                        }
                    }
                    else if (
                        l.type == Layer.Type.Multinomial)
                    {
                        Assert.IsNotNull(l.pool);
                        Assert.AreEqual(l.pool.Length, 1);
                        O = new TensorShape(X.batch, l.pool[0]);
                    }
                    else if (
                        l.type == Layer.Type.OneHot)
                    {
                        Assert.IsNotNull(l.pool);
                        Assert.AreEqual(l.pool.Length, 1);
                        int features = X.flatWidth;
                        int depth    = l.pool[0];
                        O = new TensorShape(X.batch, 1, features, depth);
                    }
                    else if (
                        l.type == Layer.Type.Add ||
                        l.type == Layer.Type.Sub ||
                        l.type == Layer.Type.Mul ||
                        l.type == Layer.Type.Div ||
                        l.type == Layer.Type.Pow ||
                        l.type == Layer.Type.Min ||
                        l.type == Layer.Type.Max ||
                        l.type == Layer.Type.Mean ||
                        l.type == Layer.Type.Greater ||
                        l.type == Layer.Type.GreaterEqual ||
                        l.type == Layer.Type.Less ||
                        l.type == Layer.Type.LessEqual ||
                        l.type == Layer.Type.Equal ||
                        l.type == Layer.Type.LogicalOr ||
                        l.type == Layer.Type.LogicalAnd ||
                        l.type == Layer.Type.LogicalXor)
                    {
                        // gather shapes by names
                        var  list           = new List <TensorShape>(l.inputs.Length);
                        bool allShapesKnown = true;
                        foreach (var inputName in l.inputs)
                        {
                            TensorShape? inputShape;
                            if (!shapesByName.TryGetValue(inputName, out inputShape))
                            {
                                continue; // inputs without any recorded entry are ignored
                            }

                            if (inputShape == null)
                            {
                                allShapesKnown = false;
                                continue;
                            }

                            list.Add(inputShape.Value);
                        }

                        O = allShapesKnown ? TensorExtensions.Max(list.ToArray()) : default(TensorShape?);
                    }
                    else if (
                        l.type == Layer.Type.ReduceL1 ||
                        l.type == Layer.Type.ReduceL2 ||
                        l.type == Layer.Type.ReduceLogSum ||
                        l.type == Layer.Type.ReduceLogSumExp ||
                        l.type == Layer.Type.ReduceMax ||
                        l.type == Layer.Type.ReduceMean ||
                        l.type == Layer.Type.ReduceMin ||
                        l.type == Layer.Type.ReduceProd ||
                        l.type == Layer.Type.ReduceSum ||
                        l.type == Layer.Type.ReduceSumSquare)
                    {
                        O = X.Reduce(l.axis);
                    }
                    else if (
                        l.type == Layer.Type.Flatten)
                    {
                        O = X.Flatten();
                    }
                    else if (
                        l.type == Layer.Type.Reshape)
                    {
                        // pool size is treated as reshape coefficient, if not empty
                        // otherwise shape of the 2nd input tensor is used
                        var size = l.pool;

                        Assert.IsNotNull(size);

                        if (size.Length == 0 && l.inputs.Length > 1)
                        {
                            // A missing entry is treated like an unknown shape instead of throwing KeyNotFoundException.
                            TensorShape? shapeInput;
                            if (!shapesByName.TryGetValue(l.inputs[1], out shapeInput) || shapeInput == null)
                            {
                                O = null;
                                // Leaves the layer loop: this layer and all later ones are not recorded.
                                break;
                            }
                            size = shapeInput.Value.ToArray();
                        }

                        Assert.AreEqual(size.Length, 4);
                        O = X.Reshape(size);
                    }
                    else if (
                        l.type == Layer.Type.Expand)
                    {
                        // pool size is treated as new shape
                        var newShape = l.pool;

                        Assert.IsNotNull(newShape);
                        Assert.AreEqual(newShape.Length, 4);

                        O = new TensorShape(newShape);
                    }
                    else if (
                        l.type == Layer.Type.Transpose)
                    {
                        O = new TensorShape(X.flatWidth, X.flatHeight);
                    }
                    else if (
                        l.type == Layer.Type.Gather)
                    {
                        // A missing entry is treated like an unknown shape instead of throwing KeyNotFoundException.
                        TensorShape? dataShape;
                        TensorShape? indexShape;
                        if (!shapesByName.TryGetValue(l.inputs[0], out dataShape) || dataShape == null ||
                            !shapesByName.TryGetValue(l.inputs[1], out indexShape) || indexShape == null)
                        {
                            O = null;
                            // Leaves the layer loop: this layer and all later ones are not recorded.
                            break;
                        }

                        int[] shape = dataShape.Value.ToArray();
                        shape[l.axis] = indexShape.Value.flatWidth;

                        O = new TensorShape(shape);
                    }
                    else if (
                        l.type == Layer.Type.Squeeze ||
                        l.type == Layer.Type.Unsqueeze)
                    {
                        throw new NotImplementedException();
                    }
                    else if (
                        l.type == Layer.Type.Concat)
                    {
                        // gather shapes by names
                        var  list           = new List <TensorShape>(l.inputs.Length);
                        bool allShapesKnown = true;
                        foreach (var inputName in l.inputs)
                        {
                            TensorShape? inputShape;
                            if (!shapesByName.TryGetValue(inputName, out inputShape))
                            {
                                continue; // inputs without any recorded entry are ignored
                            }

                            if (inputShape == null)
                            {
                                allShapesKnown = false;
                                continue;
                            }

                            list.Add(inputShape.Value);
                        }

                        O = allShapesKnown ? TensorExtensions.Concat(list.ToArray(), l.axis) : default(TensorShape?);
                    }
                    else if (
                        l.type == Layer.Type.StridedSlice)
                    {
                        Assert.IsNotNull(l.pad);
                        Assert.IsNotNull(l.pool);
                        Assert.IsNotNull(l.stride);
                        O = X.ApplyStridedSlice(l.pad, l.pool, l.stride);
                    }
                    else if (
                        l.type == Layer.Type.Tile)
                    {
                        // pool size is treated as tiling coefficient here
                        Assert.IsNotNull(l.pool);
                        Assert.AreEqual(l.pool.Length, 4);
                        var scale = l.pool;
                        O = X.Scale(scale);
                    }
                    else if (
                        l.type == Layer.Type.Load)
                    {
                        O = l.datasets[0].shape;
                    }
                    else if (// elementwise operations
                        l.type == Layer.Type.Nop ||
                        l.type == Layer.Type.Activation ||
                        l.type == Layer.Type.ScaleBias ||
                        l.type == Layer.Type.Normalization ||
                        l.type == Layer.Type.LRN ||
                        l.type == Layer.Type.Dropout ||
                        l.type == Layer.Type.LogicalNot ||
                        l.activation == Layer.Activation.PRelu)
                    {
                        // works in place, keeps the same shape size
                        O = X;
                    }
                    else
                    {
                        Assert.AreEqual(l.activation, Layer.Activation.None);
                        O = X;
                    }

                    shapes.Add(O);
                    shapesByName.Add(l.name, O);
                }

                return shapes.ToArray();
            }
            finally
            {
                // Close the profiler sample on every exit path, including exceptions thrown above.
                Profiler.EndSample();
            }
        }
Ejemplo n.º 3
0
        public static TensorShape?[] ListTemporaryTensorShapes(Model model, IDictionary <string, TensorShape> inputShapes,
                                                               out IDictionary <string, TensorShape?> shapesByName)
        {
            Profiler.BeginSample("Barracuda.ListTemporaryTensorShapes");
            var shapes = new List <TensorShape?>();

            shapesByName = new Dictionary <string, TensorShape?>();
            foreach (var entry in inputShapes)
            {
                shapesByName.Add(entry.Key, entry.Value);
            }

            TensorShape?Xn;

            shapesByName.TryGetValue(GetDefaultInputName(model), out Xn); // default input
            TensorShape?O = Xn;

            foreach (var l in model.layers)
            {
                if (l.inputs.Length > 0 && shapesByName.TryGetValue(l.inputs[0], out TensorShape? xShape))
                {
                    Xn = xShape;
                }
                else
                {
                    Xn = O; // previous output is used, if-and-only-if layer has no explicit inputs
                }
                if (Xn == null)
                {
                    shapes.Add(Xn);
                    shapesByName.Add(l.name, Xn);
                    continue;
                }

                TensorShape X = Xn.Value;

                if (l.type == Layer.Type.Dense)
                {
                    Assert.IsNotNull(l.datasets);
                    var W = l.datasets[0].shape;
                    O = new TensorShape(X.flatHeight, W.flatWidth);
                }
                else if (l.type == Layer.Type.Dense3)
                {
                    Assert.IsNotNull(l.datasets);
                    var W = l.datasets[0].shape;
                    O = new TensorShape(X.batch, 1, W.channels, X.channels);
                }
                else if (l.type == Layer.Type.MatMul)
                {
                    if (!shapesByName.ContainsKey(l.inputs[1]) || shapesByName[l.inputs[1]] == null)
                    {
                        O = null;
                        break;
                    }

                    var Y = shapesByName[l.inputs[1]].Value;

                    int        rankX;
                    int        rankY;
                    List <int> onnxXshape;
                    List <int> onnxYshape;

                    if (l.pool == null || l.pool.Length == 0)
                    {
                        LegacyGetXYRanks(X, Y, out rankX, out rankY);
                    }
                    else
                    {
                        rankX = l.pool[0];
                        rankY = l.pool[1];
                    }

                    onnxXshape = Compiler.IRShapeInferenceHelper.ShapeInference.BarracudaShapeToOnnxLayout(X, rankX);
                    onnxYshape = Compiler.IRShapeInferenceHelper.ShapeInference.BarracudaShapeToOnnxLayout(Y, rankY);

                    int rankO = Math.Max(rankX, rankY);

                    // pad 1 on front of shape to both be rankO shape
                    for (int i = 0; i < (rankX - rankY); i++)
                    {
                        onnxYshape.Insert(0, 1);
                    }

                    for (int i = 0; i < (rankY - rankX); i++)
                    {
                        onnxXshape.Insert(0, 1);
                    }

                    if (rankO == 2)
                    {
                        O = new TensorShape(onnxXshape[0], 1, 1, onnxYshape[1]);
                    }
                    else if (rankO == 3)
                    {
                        O = new TensorShape(Math.Max(onnxXshape[0], onnxYshape[0]), 1, onnxYshape[2], onnxXshape[1]);
                    }
                    else
                    {
                        O = new TensorShape(Math.Max(onnxXshape[0], onnxYshape[0]), onnxXshape[2], onnxYshape[3], Math.Max(onnxXshape[1], onnxYshape[1]));
                    }
                }
                else if (
                    l.type == Layer.Type.Conv2D ||
                    l.type == Layer.Type.Conv3D ||
                    l.type == Layer.Type.DepthwiseConv2D)
                {
                    var K = l.datasets[0].shape;

                    Assert.IsNotNull(l.stride);
                    Assert.IsNotNull(l.pad);
                    var pad = X.AdjustPadToKernel(K, l.stride, l.pad);

                    O = X.ApplyKernel(K, l.stride, pad);
                }
                else if (
                    l.type == Layer.Type.Conv2DTrans)
                {
                    var K = l.datasets[0].shape;
                    Assert.IsNotNull(l.stride);
                    Assert.IsNotNull(l.pad);
                    // pool size is treated as output_adjustment aka output_padding here
                    var outputAdjustment = l.pool;
                    var pad = X.AdjustPadToKernel(K, l.stride, l.pad);
                    O = X.ApplyKernelInverse(K, l.stride, pad, outputAdjustment);
                }
                else if (
                    l.type == Layer.Type.Upsample2D)
                {
                    if (inputShapes.Count > 1)
                    {
                        O = null;
                    }
                    else
                    {
                        // pool size is treated as upsample coefficient here
                        Assert.IsNotNull(l.pool);
                        Assert.AreEqual(l.pool.Length, 2);
                        O = new TensorShape(X.batch, X.height * l.pool[1], X.width * l.pool[0], X.channels);
                    }
                }
                else if (
                    l.type == Layer.Type.Upsample3D)
                {
                    if (inputShapes.Count > 1)
                    {
                        O = null;
                    }
                    else
                    {
                        // pool size is treated as upsample coefficient here
                        Assert.IsNotNull(l.pool);
                        Assert.AreEqual(l.pool.Length, 3);
                        O = new TensorShape(1, 1, X.batch, 1, X.depth * l.pool[2], X.height * l.pool[1], X.width * l.pool[0], X.channels);
                    }
                }
                else if (
                    l.type == Layer.Type.Resample2D)
                {
                    if (inputShapes.Count > 1)
                    {
                        O = null;
                    }
                    else
                    {
                        // pool is treated as resample size here
                        var size = l.pool;
                        Assert.IsNotNull(size);
                        Assert.AreEqual(size.Length, 2);
                        O = new TensorShape(X.batch, size[1], size[0], X.channels);
                    }
                }
                else if (
                    l.type == Layer.Type.DepthToSpace)
                {
                    // pool size is treated as blocksize here
                    Assert.IsNotNull(l.pool);
                    Assert.AreEqual(l.pool.Length, 2);
                    Assert.AreEqual(X.channels % (l.pool[0] * l.pool[1]), 0);
                    O = new TensorShape(X.batch, X.height * l.pool[1], X.width * l.pool[0], X.channels / (l.pool[0] * l.pool[1]));
                }
                else if (
                    l.type == Layer.Type.SpaceToDepth)
                {
                    // pool size is treated as blocksize here
                    Assert.IsNotNull(l.pool);
                    Assert.AreEqual(l.pool.Length, 2);
                    O = new TensorShape(X.batch, X.height / l.pool[1], X.width / l.pool[0], X.channels * (l.pool[0] * l.pool[1]));
                }
                else if (
                    l.type == Layer.Type.MaxPool2D ||
                    l.type == Layer.Type.AvgPool2D)
                {
                    Assert.IsNotNull(l.pool);
                    Assert.IsNotNull(l.stride);
                    Assert.IsNotNull(l.pad);
                    var pad = X.AdjustPadToPool(l.pool, l.stride, l.pad);
                    O = X.ApplyPool(l.pool, l.stride, pad);
                }
                else if (
                    l.type == Layer.Type.GlobalMaxPool2D ||
                    l.type == Layer.Type.GlobalAvgPool2D)
                {
                    O = new TensorShape(X.batch, 1, 1, X.channels);
                }
                else if (
                    l.type == Layer.Type.Border2D ||
                    l.type == Layer.Type.Border3D ||
                    l.type == Layer.Type.Pad2DReflect ||
                    l.type == Layer.Type.Pad2DSymmetric ||
                    l.type == Layer.Type.Pad2DEdge)
                {
                    Assert.IsNotNull(l.pad);
                    O = X.ApplyBorder(l.pad);
                }
                else if (
                    l.type == Layer.Type.Conv3D ||
                    l.type == Layer.Type.Conv3DTrans ||
                    l.type == Layer.Type.Upsample3D ||
                    l.type == Layer.Type.MaxPool3D ||
                    l.type == Layer.Type.AvgPool3D ||
                    l.type == Layer.Type.GlobalMaxPool3D ||
                    l.type == Layer.Type.GlobalAvgPool3D ||
                    l.type == Layer.Type.Border3D)
                {
                    throw new NotImplementedException();
                }
                else if (
                    l.type == Layer.Type.RandomNormal ||
                    l.type == Layer.Type.RandomUniform)
                {
                    Assert.IsNotNull(l.pool);
                    // pool size is treated as shape constant, if not empty
                    // otherwise shape of the previous tensor is used
                    if (l.pool.Length > 0)
                    {
                        O = new TensorShape(l.pool);
                    }
                    else
                    {
                        O = X;
                    }
                }
                else if (l.type == Layer.Type.ConstantOfShape)
                {
                    if (l.axis != 1)
                    {
                        O = null;
                    }
                    else
                    {
                        O = X;
                    }
                }
                else if (
                    l.type == Layer.Type.Multinomial)
                {
                    Assert.IsNotNull(l.pool);
                    Assert.AreEqual(l.pool.Length, 1);
                    O = new TensorShape(X.batch, l.pool[0]);
                }
                else if (
                    l.type == Layer.Type.OneHot)
                {
                    Assert.IsNotNull(l.pool);
                    Assert.AreEqual(l.pool.Length, 1);
                    int features = X.flatWidth;
                    int depth    = l.pool[0];

                    if (X.flatWidth == 1) // 1D input
                    {
                        O = new TensorShape(X.batch, depth);
                    }
                    else
                    {
                        O = new TensorShape(X.batch, 1, depth, features);
                    }
                }
                else if (
                    l.type == Layer.Type.Add ||
                    l.type == Layer.Type.Sub ||
                    l.type == Layer.Type.Mul ||
                    l.type == Layer.Type.Div ||
                    l.type == Layer.Type.Pow ||
                    l.type == Layer.Type.Min ||
                    l.type == Layer.Type.Max ||
                    l.type == Layer.Type.Mean ||
                    l.type == Layer.Type.Greater ||
                    l.type == Layer.Type.GreaterEqual ||
                    l.type == Layer.Type.Less ||
                    l.type == Layer.Type.LessEqual ||
                    l.type == Layer.Type.Equal ||
                    l.type == Layer.Type.LogicalOr ||
                    l.type == Layer.Type.LogicalAnd ||
                    l.type == Layer.Type.LogicalXor)
                {
                    // gather shapes by names
                    var  list           = new List <TensorShape>(l.inputs.Length);
                    bool allShapesKnown = true;
                    foreach (var i in l.inputs)
                    {
                        if (shapesByName.TryGetValue(i, out TensorShape? shape) && shape != null)
                        {
                            list.Add(shape.Value);
                        }
                        else
                        {
                            allShapesKnown = false;
                        }
                    }

                    O = allShapesKnown ? TensorExtensions.Max(list.ToArray()) : default(TensorShape?);
                }
                else if (
                    l.type == Layer.Type.ReduceL1 ||
                    l.type == Layer.Type.ReduceL2 ||
                    l.type == Layer.Type.ReduceLogSum ||
                    l.type == Layer.Type.ReduceLogSumExp ||
                    l.type == Layer.Type.ReduceMax ||
                    l.type == Layer.Type.ReduceMean ||
                    l.type == Layer.Type.ReduceMin ||
                    l.type == Layer.Type.ReduceProd ||
                    l.type == Layer.Type.ReduceSum ||
                    l.type == Layer.Type.ReduceSumSquare ||
                    l.type == Layer.Type.ArgMax ||
                    l.type == Layer.Type.ArgMin)
                {
                    O = X.Reduce(l.axis);
                }
                else if (
                    l.type == Layer.Type.Flatten)
                {
                    O = X.Flatten();
                }
                else if (
                    l.type == Layer.Type.Reshape)
                {
                    // pool size is treated as the shape, if not empty
                    var size = l.pool;

                    Assert.IsNotNull(size);

                    if (size.Length == 0 && l.inputs.Length > 1)
                    {
                        switch (l.axis)
                        {
                        // Legacy - use the shape of the input tensor as the shape
                        case -1:
                            if (shapesByName.TryGetValue(l.inputs[1], out TensorShape? shape))
                            {
                                size = shape.Value.ToArray();
                            }
                            break;

                        // Use the tensor values as the shape; Calculated at runtime
                        case 1:
                            O = null;
                            break;
                        }

                        if (O == null)
                        {
                            break;
                        }
                    }

                    Assert.IsTrue((size.Length == 4) || (size.Length == 8));
                    O = X.Reshape(size);
                }
                else if (
                    l.type == Layer.Type.Expand)
                {
                    // pool size is treated as new shape
                    var newShape = l.pool;

                    Assert.IsNotNull(newShape);
                    Assert.IsTrue(newShape.Length == 8 || newShape.Length == 4);

                    O = new TensorShape(newShape);
                }
                else if (
                    l.type == Layer.Type.Transpose)
                {
                    var permutations = l.pool;
                    if (permutations == null)
                    {
                        O = new TensorShape(X.flatWidth, X.flatHeight);
                    }
                    else
                    {
                        Assert.IsTrue(permutations.Length == 8 || permutations.Length == 4);
                        O = X.Permute(permutations);
                    }
                }
                else if (
                    l.type == Layer.Type.Gather)
                {
                    if (!shapesByName.TryGetValue(l.inputs[0], out TensorShape? input0Shape) || input0Shape == null ||
                        !shapesByName.TryGetValue(l.inputs[1], out TensorShape? input1Shape) || input1Shape == null)
                    {
                        O = null;
                        break;
                    }

                    int[] shape = input0Shape.Value.ToArray();
                    shape[l.axis] = input1Shape.Value.length;

                    O = new TensorShape(shape);
                }
                else if (
                    l.type == Layer.Type.Squeeze ||
                    l.type == Layer.Type.Unsqueeze)
                {
                    O = X;
                }
                else if (
                    l.type == Layer.Type.Concat)
                {
                    // gather shapes by names
                    var  list           = new List <TensorShape>(l.inputs.Length);
                    bool allShapesKnown = true;
                    foreach (var i in l.inputs)
                    {
                        if (!shapesByName.TryGetValue(i, out var shape) || shape == null)
                        {
                            allShapesKnown = false;
                            continue;
                        }
                        list.Add(shape.Value);
                    }

                    O = allShapesKnown ? TensorExtensions.Concat(list.ToArray(), l.axis) : default(TensorShape?);
                }
                else if (
                    l.type == Layer.Type.StridedSlice)
                {
                    Assert.IsNotNull(l.pad);
                    Assert.IsNotNull(l.pool);
                    Assert.IsNotNull(l.stride);
                    O = X.ApplyStridedSlice(l.pad, l.pool, l.stride);
                }
                else if (
                    l.type == Layer.Type.Tile)
                {
                    // pool size is treated as tiling coefficient here
                    Assert.IsNotNull(l.pool);
                    var scale = l.pool;
                    O = X.Scale(scale);
                }
                else if (
                    l.type == Layer.Type.Load)
                {
                    O = l.datasets[0].shape;
                }
                else if (// elementwise operations
                    l.type == Layer.Type.Nop ||
                    l.type == Layer.Type.Activation ||
                    l.type == Layer.Type.ScaleBias ||
                    l.type == Layer.Type.Normalization ||
                    l.type == Layer.Type.LRN ||
                    l.type == Layer.Type.Dropout ||
                    l.type == Layer.Type.LogicalNot ||
                    l.type == Layer.Type.Sign ||
                    l.type == Layer.Type.Where)
                {
                    // works in place, keeps the same shape size
                    O = X;
                }
                else if (
                    l.type == Layer.Type.TopKIndices ||
                    l.type == Layer.Type.TopKValues ||
                    l.type == Layer.Type.NonMaxSuppression ||
                    l.type == Layer.Type.LSTM ||
                    l.type == Layer.Type.NonZero)
                {
                    // Calculated at runtime
                    O = null;
                }
                else if (l.type == Layer.Type.Shape)
                {
                    int shapeRank = l.axis > 0 ? 1 : X.length;
                    O = new TensorShape(shapeRank, 1, 1, 1);
                }
                else if (
                    l.type == Layer.Type.Conv3D ||
                    l.type == Layer.Type.Conv3DTrans ||
                    l.type == Layer.Type.Upsample3D ||
                    l.type == Layer.Type.MaxPool3D ||
                    l.type == Layer.Type.AvgPool3D ||
                    l.type == Layer.Type.GlobalMaxPool3D ||
                    l.type == Layer.Type.GlobalAvgPool3D ||
                    l.type == Layer.Type.Border3D)
                {
                    throw new NotImplementedException("3D operations are not implemented yet!");
                }
                else
                {
                    throw new NotImplementedException($"Layer type {l.type} needs to be explicitly handled");
                }

                shapes.Add(O);
                shapesByName.Add(l.name, O);
            }

            Profiler.EndSample();
            return(shapes.ToArray());
        }