/// <summary>
/// Pads the permutation stored in <c>layer.pool</c> out to four entries, based on the rank
/// recorded in the model inputs for the layer's first input.
/// </summary>
/// <param name="layer">Transpose layer; its <c>pool</c> is replaced in place for rank 2 and rank 3 inputs.</param>
/// <param name="net">Builder whose <c>model.inputs</c> is searched by name for the layer's first input.</param>
/// <returns>Always <c>true</c> when the input rank is 2, 3 or 4.</returns>
/// <exception cref="ArgumentException">The input rank is not 2, 3 or 4.</exception>
bool TransposeToBarracuda(Layer layer, ModelBuilder net)
        {
            string      sourceName = layer.inputs[0];
            Model.Input sourceInfo = net.model.inputs.First(candidate => candidate.name == sourceName);

            var permutation = layer.pool;
            var sourceRank  = sourceInfo.rank;

            if (sourceRank == 4)
            {
                // Already four entries long: the permutation is used unchanged.
                return(true);
            }

            if (sourceRank == 3)
            {
                // onnx : 5,7,3 => 5,7,3,1 / 7,5,3,1 / 7,3,5,1 ...
                layer.pool = new[] { permutation[0], permutation[1], permutation[2], 3 };
                return(true);
            }

            if (sourceRank == 2)
            {
                // onnx : 5,7 => 5,7,1,1 / 7,5
                layer.pool = new[] { permutation[0], permutation[1], 2, 3 };
                return(true);
            }

            throw new ArgumentException($"Unsupported transpose");
        }
        /// <summary>
        /// Rewrites <c>layer.axis</c> from an axis expressed against the rank of the layer's first
        /// input into a shifted index. A negative axis is first wrapped by adding the rank.
        /// </summary>
        /// <param name="layer">Layer whose <c>axis</c> is updated in place.</param>
        /// <param name="net">Builder whose <c>model.inputs</c> is searched by name for the layer's first input.</param>
        /// <returns>Always <c>true</c>.</returns>
        bool AxisToBarracuda(Layer layer, ModelBuilder net)
        {
            string      sourceName = layer.inputs[0];
            Model.Input sourceInfo = net.model.inputs.First(candidate => candidate.name == sourceName);
            var         sourceRank = sourceInfo.rank;

            // Negative axes count back from the last dimension.
            var axis = layer.axis < 0 ? layer.axis + sourceRank : layer.axis;

            // Shift applied to the axis: rank 6 always moves by 2; otherwise axis 0 moves by 2
            // and every other axis moves by 3 (rank 5) or 4 (any other rank).
            int shift;
            if (sourceRank == 6 || axis == 0)
            {
                shift = 2;
            }
            else if (sourceRank == 5)
            {
                shift = 3;
            }
            else
            {
                shift = 4;
            }

            layer.axis = axis + shift;

            return(true);
        }
Example #3
0
        /// <summary>
        /// Parses the first <c>B.csv</c> found under <c>TestData/Input</c> (looked up three
        /// directory levels above the current working directory) and spot-checks a few values
        /// of the parsed result.
        /// </summary>
        public void GetInputFromFile()
        {
            // Walk up three levels from the working directory.
            string workspaceDir = Directory.GetCurrentDirectory();
            for (int level = 0; level < 3; level++)
            {
                workspaceDir = Directory.GetParent(workspaceDir).ToString();
            }

            string testDataPath = Directory.GetDirectories(workspaceDir, "TestData")[0];
            string inputPath    = Directory.GetDirectories(testDataPath, "Input")[0];
            string csvPath      = Directory.GetFiles(inputPath, "B.csv")[0];

            Model.Input input;

            // Dispose the reader once parsing is done so the file handle is not leaked.
            // NOTE(review): assumes Input.Parse reads the stream eagerly — confirm.
            using (StreamReader sr = new StreamReader(csvPath))
            {
                input = Input.Parse(sr);
            }

            Assert.Equal(18, input.Nodes.Count);
            Assert.Equal(17, input.ReactionInputs[3].Number);
            Assert.Equal(10, input.FrameElements[3].Ax);

            // The load must equal -1.1 within tolerance. A one-sided "value + 1.1 < 0.0001" check
            // would also accept any value far below -1.1, so compare the absolute difference.
            Assert.True(System.Math.Abs(input.LoadCases[0].UniformLoads[1].Load.Y + 1.1) < 0.0001);
        }
        /// <summary>
        /// Stores the ranks of the layer's first two inputs in <c>layer.pool</c> and then
        /// delegates the axis rewrite to <c>AxisToBarracuda</c>.
        /// </summary>
        /// <param name="layer">Layer with at least two inputs; its <c>pool</c> is overwritten.</param>
        /// <param name="net">Builder whose <c>model.inputs</c> is searched by name for both inputs.</param>
        /// <returns>Whatever <c>AxisToBarracuda</c> returns for the same arguments.</returns>
        bool GatherToBarracuda(Layer layer, ModelBuilder net)
        {
            string      firstName = layer.inputs[0];
            Model.Input firstInfo = net.model.inputs.First(candidate => candidate.name == firstName);

            string      secondName = layer.inputs[1];
            Model.Input secondInfo = net.model.inputs.First(candidate => candidate.name == secondName);

            // pool carries { rank of input 0, rank of input 1 }.
            layer.pool = new[] { firstInfo.rank, secondInfo.rank };

            return(AxisToBarracuda(layer, net));
        }
        /// <summary>
        /// Surrounds the layer with two transposes chosen from the rank of its first input:
        /// the input is transposed first, the layer itself is renamed with an "_NHWC" suffix and
        /// added to the model, and a trailing transpose takes over the layer's original name.
        /// </summary>
        /// <param name="layer">Layer to wrap; its <c>name</c> and <c>inputs[0]</c> are modified.</param>
        /// <param name="net">Builder that receives the two transposes and the renamed layer.</param>
        /// <returns>Always <c>false</c>; the layer has already been added to the model here.</returns>
        bool Transpose0UsingRank(Layer layer, ModelBuilder net)
        {
            string      sourceName = layer.inputs[0];
            Model.Input sourceInfo = net.model.inputs.First(candidate => candidate.name == sourceName);
            bool        isRank3    = sourceInfo.rank == 3;

            // Rank 3 inputs use the NCH <-> N1WC pair, everything else the NCHW <-> NHWC pair.
            Layer transposedInput = isRank3
                ? net.Transpose("Transpose_For_" + sourceName, sourceName, k_FromNCHtoN1WC)
                : net.Transpose("Transpose_For_" + sourceName, sourceName, k_ToNHWC);

            // The layer keeps everything except its name and first input.
            string publicName = layer.name;

            layer.name      = publicName + "_NHWC";
            layer.inputs[0] = transposedInput.name;
            net.model.layers.Add(layer);

            // The trailing transpose carries the original name.
            if (isRank3)
            {
                net.Transpose(publicName, layer.name, k_FromN1WCtoNCH);
            }
            else
            {
                net.Transpose(publicName, layer.name, k_ToNCHW);
            }

            return(false);
        }
Example #6
0
        /// <summary>
        /// Replaces <paramref name="model"/> with a shallow copy whose layout is "NHWC": rank and
        /// shape tables are collected, input shapes are permuted, every layer is passed through the
        /// matching rewriter, and constants, dynamic inputs and (for NHWC-style importers) outputs
        /// are corrected afterwards.
        /// </summary>
        /// <param name="model">Model to convert; reassigned to the rewritten copy before returning.</param>
        void Rewrite(ref Model model)
        {
            // Collect per-tensor ranks, then shapes for every input whose shape is known well enough.
            IRShapeInferenceHelper.RankInference.ListTemporaryTensorRanks(model, out m_RanksByName);
            var inputShapes = new Dictionary <string, TensorShape>();

            foreach (var i in model.inputs)
            {
                if (!ModelAnalyzer.IsInputShapeAcceptablyKnowForShapeInference(i))
                {
                    continue;
                }
                inputShapes.Add(i.name, new TensorShape(i.shape));
            }

            IRShapeInferenceHelper.ShapeInference.ListTemporaryTensorShapesNCHW(model, inputShapes, ref m_RanksByName, out m_ShapesByName);

            // Start from a shallow copy with an empty layer list; layers are re-added below.
            var nhwc = model.ShallowCopy();

            nhwc.layers.Clear();
            nhwc.layout = "NHWC";

            // TF2ONNX transpose pattern -> parts of the model are in NHWC and not NCHW
            // * identify those
            // * transpose inputs to NCHW
            // * remove layout transposes
            // * convert axis/constants accordingly
            LayoutTransposeRemovalHelper transposeRemoval = new LayoutTransposeRemovalHelper();

            m_isModelExportedFromNHWC = transposeRemoval.InferAllLayersChannelOrder(model, out m_layersChannelOrder);

            // NOTE(review): the flag tested here says the model was exported from NHWC, yet the
            // warning text reads "detected as NCHW" — confirm which wording is intended.
            if (m_isModelExportedFromNHWC && !transposeRemoval.IsImporterLikelyNHWCLayout(model.ProducerName))
            {
                nhwc.Warnings.Add(new Model.ImporterWarning("model", "model detected as NCHW, but not natively in this layout, behavior might be erroneous"));
            }

            // remove layout change transposes
            if (m_isModelExportedFromNHWC)
            {
                transposeRemoval.RemoveAllChannelLayoutTransposes(ref model, m_layersChannelOrder);
            }

            var modelBuilder = new ModelBuilder(nhwc);

            // Rewrite the shape of every input of the copy.
            for (int i = 0; i < nhwc.inputs.Count; i++)
            {
                Model.Input input = nhwc.inputs[i];

                int[] shape            = input.shape;
                var   tensorShape      = new TensorShape(shape);
                int[] rankPermutations = GetChannelsLastPermutationsFromRank(input.rank);
                int[] permutations     = tensorShape.Get8DPermutationsForNCHWPermutationsAndShape(rankPermutations);

                // Preserve symbolic shape by operating on int array instead of TensorShape, which would resolve unknown dimensions
                if (m_isModelExportedFromNHWC) // transpose input shape if importer preserved NHWC layout
                {
                    // NOTE(review): the indexer throws if the input name is missing from
                    // m_layersChannelOrder — presumably every input is always present; confirm.
                    if (m_layersChannelOrder[input.name] == LayoutTransposeRemovalHelper.ChannelsOrder.NCHW)
                    {
                        input.shape = TensorExtensions.Permute(shape, permutations);
                    }
                    else
                    {
                        // Take entries 2, 5, 6 and 7 of the shape array and keep only the first
                        // input.rank of them.
                        // NOTE(review): RemoveRange(rank, 4 - rank) throws for rank > 4 —
                        // presumably inputs reaching this branch never exceed rank 4; confirm.
                        var onnxShape = new List <int> {
                            shape[2], shape[5], shape[6], shape[7]
                        };
                        onnxShape.RemoveRange(input.rank, 4 - input.rank);
                        input.shape = IRShapeInferenceHelper.ShapeInference.BarracudaLayoutToTensorShapeLayout(onnxShape.ToArray());
                    }
                }
                else
                {
                    input.shape = TensorExtensions.Permute(shape, permutations);
                }
                // Write the modified copy back into the list.
                nhwc.inputs[i] = input;
            }

            // NCHW -> Barracuda NHWC rewriter (some layers need to insert additional layers to be Barracuda compatible)
            var rewriters = InstantiateRewriterNCHWToNHWC();
            // NHWC -> Barracuda NHWC rewriter (axis and constant padding)
            var rewritersNHWC = InstantiateRewriterNHWCToNHWC();


            foreach (var l in model.layers)
            {
                // Some nodes output multiple layers (e.g. LSTM), so don't process or include those layers
                if (nhwc.layers.Exists(alreadyOutputLayer => alreadyOutputLayer.name == l.name))
                {
                    continue;
                }

                // Layers flagged as NHWC in an NHWC-exported model go through the NHWC rewriter
                // table only (note that the local named rwNCHW is looked up in rewritersNHWC).
                if (m_layersChannelOrder.TryGetValue(l.name, out LayoutTransposeRemovalHelper.ChannelsOrder layerChannelOrder))
                {
                    if (m_isModelExportedFromNHWC && (layerChannelOrder == LayoutTransposeRemovalHelper.ChannelsOrder.NHWC))
                    {
                        // Add the layer when there is no rewriter for its type or the rewriter returns true.
                        if (!rewritersNHWC.TryGetValue(l.type, out Func <Layer, ModelBuilder, bool> rwNCHW) || rwNCHW(l, modelBuilder))
                        {
                            nhwc.layers.Add(l);
                        }
                        continue;
                    }
                }

                if (!rewriters.TryGetValue(l.type, out Func <Layer, ModelBuilder, bool> rw) || rw(l, modelBuilder))
                {
                    // Either no re-write was needed or the layer was not replaced
                    nhwc.layers.Add(l);
                }
            }

            // We need to correct constants to have broadcast work correctly
            // ONNX: 1,64,32 + c:32
            // Barracuda: 1,_32,64 + c:_,_,32,64 and not c:32,_,_,_
            // X:5,7 + c: 6,9,5,7 = 6,9,5,7
            // X: 5,_,_,7 + c: 6,5,7,9 = ???
            CorrectConstantsForBroadCast(ref nhwc);
            CorrectDynamicInputsForBroadCast(ref nhwc);

            // for NHWC importers, perform slightly more aggressive output shape check
            // => add transposes to match onnx layout
            if (transposeRemoval.IsImporterLikelyNHWCLayout(model.ProducerName))
            {
                CorrectOutputLayoutToMatchNHWCLayout(ref nhwc);
            }

            // Hand the rewritten copy back to the caller.
            model = nhwc;
        }
Example #7
0
 /// <summary>
 /// Initializes the collections and default field values of the view model and wires up the
 /// Add, Save, Delete and TK commands. Each command is built from a can-execute predicate
 /// followed by the action to run.
 /// </summary>
 public InputViewModel()
 {
     NhapKhoList = new ObservableCollection <NhapKho>();
     Input       = new ObservableCollection <Model.Input>(DataProvider.Ins.DB.Inputs);
     Object      = new ObservableCollection <Model.Object>(DataProvider.Ins.DB.Objects);
     _DateInput  = DateTime.Now;
     _Count      = 0;
     _InputPrice = 0;
     #region Add
     AddCommand = new RelayCommand <object>((p) =>
     {
         // Validate the selection first: the duplicate check below dereferences SelectedObject,
         // which used to throw when nothing was selected and the list was not empty.
         if (SelectedObject == null || DateInput == null || Count < 0 || InputPrice < 0)
         {
             return(false);
         }

         // The same object may be queued only once.
         return(!NhapKhoList.Any(i => i.Object.Id == SelectedObject.Id));
     }, (p) =>
     {
         var inputInfo = new Model.NhapKho()
         {
             Object = SelectedObject, Count = Count, InputPrice = InputPrice, Status = Status, Id = Guid.NewGuid().ToString()
         };
         NhapKhoList.Add(inputInfo);
     });
     #endregion
     #region Save
     SaveCommand = new RelayCommand <object>((p) =>
     {
         // Nothing to save, or at least one queued row has a non-positive count or price.
         if (NhapKhoList.Count == 0)
         {
             return(false);
         }
         return(!NhapKhoList.Any(i => i.Count <= 0 || i.InputPrice <= 0));
     }, (p) =>
     {
         if (MessageBox.Show("Bạn có thật sự muốn lưu?", "Thông báo", MessageBoxButton.OKCancel, MessageBoxImage.Question) != MessageBoxResult.OK)
         {
             return;
         }

         // One header record, then one detail record per queued row that points back at it.
         var inputRecord = new Model.Input()
         {
             Id = Guid.NewGuid().ToString(), DateInput = DateInput
         };
         DataProvider.Ins.DB.Inputs.Add(inputRecord);

         foreach (var item in NhapKhoList)
         {
             var detail = new Model.InputInfo()
             {
                 IdObject = item.Object.Id, Count = item.Count, IdInput = inputRecord.Id, InputPrice = item.InputPrice, Status = item.Status, Id = item.Id
             };
             DataProvider.Ins.DB.InputInfoes.Add(detail);
         }

         // Persist everything with a single call instead of one call per row, so a failure
         // cannot leave only part of the rows saved.
         DataProvider.Ins.DB.SaveChanges();

         MessageBox.Show("Lưu thành công!");
         NhapKhoList.Clear();
         Reset();
     });
     #endregion
     #region Delete
     DeleteCommand = new RelayCommand <object>((p) =>
     {
         return(true);
     }, (p) =>
     {
         if (MessageBox.Show("Bạn có thật sự muốn tạo lại?", "Thông báo", MessageBoxButton.OKCancel, MessageBoxImage.Warning) != MessageBoxResult.OK)
         {
             return;
         }

         // Confirmed: drop every queued row and reset the form.
         NhapKhoList.Clear();
         Reset();
     });
     #endregion
     // Opens the TK input window as a modal dialog with a fresh view model.
     TKCommand = new RelayCommand <object>((p) =>
     {
         return(true);
     }, (p) =>
     {
         TKInputWindow wd    = new TKInputWindow();
         TKInputViewModel vm = new TKInputViewModel();
         wd.DataContext = vm;
         wd.ShowDialog();
     });
 }
        /// <summary>
        /// Returns the permutation used to replace a single-axis Squeeze by a Transpose: the
        /// squeezed axis is moved behind the remaining axes of the tensor.
        /// </summary>
        /// <param name="rank">Rank of the tensor being squeezed; 1 through 5 are supported.</param>
        /// <param name="axis">Axis being squeezed. Any axis without an explicit entry (including the
        /// last axis of each rank) yields the identity permutation.</param>
        /// <returns>A new array of 8 entries for rank 5 and of 4 entries for ranks 1 to 4.</returns>
        /// <exception cref="InvalidOperationException"><paramref name="rank"/> is not in 1..5.</exception>
        static int[] SqueezeAxisPermutation(int rank, int axis)
        {
            switch (rank)
            {
            case 5:
                //            axis:    0        1        2        3        4
                // ONNX:      NCDHW    CDHW     NDHW     NCHW     NCDW     NCDH
                // { 0,1,2,3,4,5,6,7}
                //   _,_,N,_,C,D,H,W
                switch (axis)
                {
                case 0:
                    return(new[] { 0, 1, 4, 3, 5, 6, 7, 2 });

                case 1:
                    return(new[] { 0, 1, 2, 3, 5, 6, 7, 4 });

                case 2:
                    return(new[] { 0, 1, 2, 3, 4, 6, 7, 5 });

                case 3:
                    return(new[] { 0, 1, 2, 3, 4, 5, 7, 6 });

                default:
                    return(new[] { 0, 1, 2, 3, 4, 5, 6, 7 });
                }

            case 4:
                //            axis:   0       1      2      3
                // ONNX:      NCHW    CHW     NHW    NCW    NCH
                switch (axis)
                {
                case 0:
                    return(new[] { 1, 2, 3, 0 });

                case 1:
                    return(new[] { 0, 2, 3, 1 });

                case 2:
                    return(new[] { 0, 1, 3, 2 });

                default:
                    return(new[] { 0, 1, 2, 3 });
                }

            case 3:
                //            axis:   0       1      2
                // ONNX:      NCH     CH      NH     NC
                switch (axis)
                {
                case 0:
                    return(new[] { 1, 2, 0, 3 });

                case 1:
                    return(new[] { 0, 2, 1, 3 });

                default:
                    return(new[] { 0, 1, 2, 3 });
                }

            case 2:
                //            axis:   0       1
                // ONNX:      NC      C       N
                if (axis == 0)
                {
                    return(new[] { 1, 0, 2, 3 });
                }

                return(new[] { 0, 1, 2, 3 });

            case 1:
                // Nothing to move for a rank 1 tensor.
                return(new[] { 0, 1, 2, 3 });

            default:
                throw new InvalidOperationException($"Not supported Squeeze operation with rank {rank}");
            }
        }

        /// <summary>
        /// Expands the per-axis slice parameters of a StridedSlice layer (<c>pad</c> = starts,
        /// <c>pool</c> = ends, <c>stride</c> = steps, addressed through <c>axes</c>) into
        /// 8-entry arrays, placing input dimensions 0..3 at positions 2, 5, 6 and 7.
        /// </summary>
        /// <param name="layer">Slice layer; its <c>pad</c>, <c>pool</c> and <c>stride</c> are replaced.</param>
        /// <param name="net">Builder whose <c>model.inputs</c> is searched by name for the layer's first input.</param>
        /// <returns>Always <c>true</c>.</returns>
        /// <exception cref="ArgumentException">The rank of the first input is not in 1..4.</exception>
        bool SliceToBarracuda(Layer layer, ModelBuilder net)
        {
            Model.Input input0Info = net.model.inputs.First(i => i.name == layer.inputs[0]);
            int         rank       = input0Info.rank;

            // Reject unsupported ranks before touching any per-axis array.
            if (rank < 1 || rank > 4)
            {
                throw new ArgumentException($"Unsupported tensor rank {rank} for StridedSlice");
            }

            var starts = layer.pad;
            var ends   = layer.pool;
            var steps  = layer.stride;
            var axes   = layer.axes;

            var onnxStarts = Enumerable.Repeat(0, rank).ToArray();
            var onnxEnds   = Enumerable.Repeat(int.MaxValue, rank).ToArray(); // by default copy the whole axis till the end
            var onnxSteps  = Enumerable.Repeat(1, rank).ToArray();

            // NOTE: begin=0, end=0, stride=1  <=  full range from existing axis
            //       begin=0, end=inf,stride=1 <=  full range from existing axis
            //       begin=0, end=X, stride=1  <=  full range from existing axis, if X==last element on this axis
            //       begin=0, end=0, stride=0  <=  new axis OR shrink axis to single 1st element
            //       begin=N, end=N, stride=0  <=              shrink axis to single Nth element
            // These notes are copied from TensorExtensions.ApplyStridedSlice(...)

            for (int i = 0; i < axes.Length; ++i)
            {
                var axis = axes[i];
                if (axis < 0)
                {
                    axis += rank;
                }

                // Clamp to the last valid index. The upper bound must be rank - 1: clamping to
                // rank would index one past the end of the per-axis arrays.
                axis = Math.Min(Math.Max(axis, 0), rank - 1);

                onnxStarts[axis] = starts[i];
                onnxEnds[axis]   = ends[i];
                onnxSteps[axis]  = steps[i];
            }

            // Position of input dimension d inside the 8-entry arrays.
            var targetIndex = new[] { 2, 5, 6, 7 };

            // Positions not covered by the input keep the "full range" defaults.
            var pad    = Enumerable.Repeat(0, 8).ToArray();
            var pool   = Enumerable.Repeat(int.MaxValue, 8).ToArray();
            var stride = Enumerable.Repeat(1, 8).ToArray();

            for (int d = 0; d < rank; ++d)
            {
                pad[targetIndex[d]]    = onnxStarts[d];
                pool[targetIndex[d]]   = onnxEnds[d];
                stride[targetIndex[d]] = onnxSteps[d];
            }

            layer.pad    = pad;
            layer.pool   = pool;
            layer.stride = stride;

            return(true);
        }

        /// <summary>
        /// Surrounds the layer with two transposes chosen from the rank of its first input:
        /// the input is transposed first, the layer itself is renamed with an "_NHWC" suffix and
        /// added to the model, and a trailing transpose takes over the layer's original name.
        /// </summary>
        /// <param name="layer">Layer to wrap; its <c>name</c> and <c>inputs[0]</c> are modified.</param>
        /// <param name="net">Builder that receives the two transposes and the renamed layer.</param>
        /// <returns>Always <c>false</c>; the layer has already been added to the model here.</returns>
        bool Transpose0UsingRank(Layer layer, ModelBuilder net)
        {
            string      sourceName = layer.inputs[0];
            Model.Input sourceInfo = net.model.inputs.First(candidate => candidate.name == sourceName);
            bool        isRank3    = sourceInfo.rank == 3;

            // Rank 3 inputs use the NCH <-> N1WC pair, everything else the NCHW <-> NHWC pair.
            Layer transposedInput = isRank3
                ? net.Transpose("Transpose_For_" + sourceName, sourceName, k_FromNCHtoN1WC)
                : net.Transpose("Transpose_For_" + sourceName, sourceName, k_ToNHWC);

            // The layer keeps everything except its name and first input.
            string publicName = layer.name;

            layer.name      = publicName + "_NHWC";
            layer.inputs[0] = transposedInput.name;
            net.model.layers.Add(layer);

            // The trailing transpose carries the original name.
            if (isRank3)
            {
                net.Transpose(publicName, layer.name, k_FromN1WCtoNCH);
            }
            else
            {
                net.Transpose(publicName, layer.name, k_ToNCHW);
            }

            return(false);
        }

        /// <summary>
        /// Two-input variant of the rank-based wrapping: each of the first two inputs is
        /// transposed according to its own rank, the layer is renamed with an "_NHWC" suffix and
        /// added to the model, and a trailing transpose takes over the original name.
        /// </summary>
        /// <param name="layer">Layer to wrap; its <c>name</c>, <c>inputs[0]</c> and <c>inputs[1]</c> are modified.</param>
        /// <param name="net">Builder that receives the three transposes and the renamed layer.</param>
        /// <returns>Always <c>false</c>; the layer has already been added to the model here.</returns>
        bool TransposeInput01UsingRank(Layer layer, ModelBuilder net)
        {
            string      firstName = layer.inputs[0];
            Model.Input firstInfo = net.model.inputs.First(candidate => candidate.name == firstName);

            string      secondName = layer.inputs[1];
            Model.Input secondInfo = net.model.inputs.First(candidate => candidate.name == secondName);

            bool firstIsRank3  = firstInfo.rank == 3;
            bool secondIsRank3 = secondInfo.rank == 3;

            Layer firstTransposed = firstIsRank3
                ? net.Transpose("Transpose_For_" + firstName, firstName, k_FromNCHtoN1WC)
                : net.Transpose("Transpose_For_" + firstName, firstName, k_ToNHWC);
            Layer secondTransposed = secondIsRank3
                ? net.Transpose("Transpose_For_" + secondName, secondName, k_FromNCHtoN1WC)
                : net.Transpose("Transpose_For_" + secondName, secondName, k_ToNHWC);

            string publicName = layer.name;

            layer.name      = publicName + "_NHWC";
            layer.inputs[0] = firstTransposed.name;
            layer.inputs[1] = secondTransposed.name;
            net.model.layers.Add(layer);

            // Only the rank of the first input selects the trailing transpose.
            if (firstIsRank3)
            {
                net.Transpose(publicName, layer.name, k_FromN1WCtoNCH);
            }
            else
            {
                net.Transpose(publicName, layer.name, k_ToNCHW);
            }

            return(false);
        }

        /// <summary>
        /// Inserts one transpose in front of every input of the layer, chosen by
        /// <c>GetTransposeForBroadCast</c> from that input's rank and the highest rank among all
        /// inputs. The layer is renamed with an "_NHWC" suffix, rewired to the transposes and
        /// added to the model; a trailing identity transpose takes over the original name.
        /// </summary>
        /// <param name="layer">Layer to wrap; its <c>name</c> and every entry of <c>inputs</c> are modified.</param>
        /// <param name="net">Builder that receives the transposes and the renamed layer.</param>
        /// <returns>Always <c>false</c>; the layer has already been added to the model here.</returns>
        bool TransposeForBroadcast(Layer layer, ModelBuilder net)
        {
            // Pass 1: highest rank among the inputs (stays 0 when there are none).
            int highestRank = 0;

            foreach (string name in layer.inputs)
            {
                Model.Input info = net.model.inputs.First(x => x.name == name);
                if (info.rank > highestRank)
                {
                    highestRank = info.rank;
                }
            }

            // Pass 2: one transpose per input, in input order.
            int     inputCount = layer.inputs.Length;
            Layer[] transposes = new Layer[inputCount];

            for (int index = 0; index < inputCount; index++)
            {
                string      name = layer.inputs[index];
                Model.Input info = net.model.inputs.First(x => x.name == layer.inputs[index]);

                var permutation = GetTransposeForBroadCast(info.rank, highestRank);
                transposes[index] = net.Transpose("Transpose_For_" + name, name, permutation);
            }

            string publicName = layer.name;

            layer.name = publicName + "_NHWC";

            // Pass 3: rewire the layer to the freshly inserted transposes.
            for (int index = 0; index < layer.inputs.Length; index++)
            {
                layer.inputs[index] = transposes[index].name;
            }
            net.model.layers.Add(layer);

            // Identity transpose that carries the original layer name.
            net.Transpose(publicName, layer.name, new [] { 0, 1, 2, 3 });

            return(false);
        }

        /// <summary>
        /// Returns the four-entry permutation applied to an input of rank <paramref name="rank0"/>
        /// so that it lines up for broadcasting against an input of rank <paramref name="rank1"/>.
        /// </summary>
        /// <param name="rank0">Rank of the input being transposed.</param>
        /// <param name="rank1">Rank to broadcast against (the highest rank among the inputs at the call site).</param>
        /// <returns>
        /// The identity permutation when both ranks are equal or <paramref name="rank1"/> is 0 or 1;
        /// otherwise the permutation listed for the rank pair.
        /// </returns>
        /// <exception cref="ArgumentException">No permutation is defined for the rank pair.</exception>
        int[] GetTransposeForBroadCast(int rank0, int rank1)
        {
            if (rank0 == rank1 || rank1 == 0 || rank1 == 1)
            {
                return(new[] { 0, 1, 2, 3 });
            }

            // Ranks 0 and 1 are treated alike for every target rank below.
            bool rank0IsScalarOrVector = rank0 == 0 || rank0 == 1;

            switch (rank1)
            {
            case 2:
                // 3 + 53 => 1,3
                if (rank0IsScalarOrVector)
                {
                    return(new[] { 1, 0, 2, 3 });
                }
                break;

            case 3:
                // 3 + 753 => 1,1,3
                if (rank0IsScalarOrVector)
                {
                    return(new[] { 1, 2, 0, 3 });
                }
                // 53 + 753 => 1,5,3
                if (rank0 == 2)
                {
                    return(new[] { 2, 0, 1, 3 });
                }
                break;

            case 4:
                // 3 + 9753 => 1,1,1,3
                if (rank0IsScalarOrVector)
                {
                    return(new[] { 1, 2, 3, 0 });
                }
                // 53 + 9753 => 1,1,5,3
                if (rank0 == 2)
                {
                    return(new[] { 2, 3, 0, 1 });
                }
                // 753 + 9753 => 1,7,5,3
                if (rank0 == 3)
                {
                    return(new[] { 3, 0, 1, 2 });
                }
                break;
            }

            // Every rank pair that did not return above is unsupported.
            throw new ArgumentException($"Unsupported rank permutation change {rank0} to {rank1}");
        }

        /// <summary>
        /// Transposes the first two inputs of the layer with <c>k_ToNHWC</c>, renames the layer
        /// with an "_NHWC" suffix, adds it to the model and appends a <c>k_ToNCHW</c> transpose
        /// that takes over the original name.
        /// </summary>
        /// <param name="layer">Layer to wrap; its <c>name</c>, <c>inputs[0]</c> and <c>inputs[1]</c> are modified.</param>
        /// <param name="net">Builder that receives the three transposes and the renamed layer.</param>
        /// <returns>Always <c>false</c>; the layer has already been added to the model here.</returns>
        bool TransposeInput01(Layer layer, ModelBuilder net)
        {
            string firstName  = layer.inputs[0];
            string secondName = layer.inputs[1];

            Layer firstTransposed  = net.Transpose("Transpose_For_" + firstName, firstName, k_ToNHWC);
            Layer secondTransposed = net.Transpose("Transpose_For_" + secondName, secondName, k_ToNHWC);

            string publicName = layer.name;

            layer.name      = publicName + "_NHWC";
            layer.inputs[0] = firstTransposed.name;
            layer.inputs[1] = secondTransposed.name;
            net.model.layers.Add(layer);

            // The trailing transpose carries the original name.
            net.Transpose(publicName, layer.name, k_ToNCHW);

            return(false);
        }

        /// <summary>
        /// Transposes the first input of the layer with <c>k_ToNHWC</c>, renames the layer with an
        /// "_NHWC" suffix, adds it to the model and appends a <c>k_ToNCHW</c> transpose that takes
        /// over the original name.
        /// </summary>
        /// <param name="layer">Layer to wrap; its <c>name</c> and <c>inputs[0]</c> are modified.</param>
        /// <param name="net">Builder that receives the two transposes and the renamed layer.</param>
        /// <returns>Always <c>false</c>; the layer has already been added to the model here.</returns>
        bool TransposeInput0(Layer layer, ModelBuilder net)
        {
            string sourceName = layer.inputs[0];
            Layer  transposed = net.Transpose("Transpose_For_" + sourceName, sourceName, k_ToNHWC);

            string publicName = layer.name;

            layer.name      = publicName + "_NHWC";
            layer.inputs[0] = transposed.name;
            net.model.layers.Add(layer);

            // The trailing transpose carries the original name.
            net.Transpose(publicName, layer.name, k_ToNCHW);

            return(false);
        }

        private static int[] RankChangePermutationBarracuda(int rank0, int rank1)
        public void Run(ref Model model)
        {
            if (model.layout != "iNCHW")
            {
                return;
            }

            IDictionary <string, int?>         ranksByName;
            IDictionary <string, TensorShape?> shapesByName;

            IRShapeInferenceHelper.RankInference.ListTemporaryTensorRanks(model, out ranksByName);
            var inputShapes = new Dictionary <string, TensorShape>();

            foreach (var i in model.inputs)
            {
                if (!ModelAnalyzer.IsInputShapeAcceptablyKnowForShapeInference(i))
                {
                    continue;
                }
                inputShapes[i.name] = new TensorShape(i.shape);
            }

            IRShapeInferenceHelper.ShapeInference.ListTemporaryTensorShapesNCHW(model, inputShapes, ranksByName, out shapesByName);

            var nchw = model.ShallowCopy();

            nchw.layers.Clear();
            nchw.layout = "NCHW";

            var modelBuilder = new ModelBuilder(nchw);

            var rewriters    = new Dictionary <Layer.Type, Func <Layer, ModelBuilder, bool> >();
            var layerRenames = new Dictionary <string, string>();
            var inputRemaps  = new Dictionary <string, string>();

            // return true if layer should be included in rewritten model, false if it was replaced
            rewriters.Add(Layer.Type.Unsqueeze, (layer, net) =>
            {
                if (layer.pool.Length > 1)
                {
                    // Multiple axes unsupported; leave layer as-is
                    return(true);
                }

                string input0 = layer.inputs[0];

                if (!shapesByName.TryGetValue(input0, out TensorShape? input0Shape) || !input0Shape.HasValue)
                {
                    throw new Exception($"Must have input shape for {input0} for Unsqueeze");
                }

                if (!ranksByName.TryGetValue(input0, out int?input0Rank) || !input0Rank.HasValue)
                {
                    throw new Exception($"Must have input rank for {input0} for Unsqueeze");
                }

                int rank = input0Rank.Value;

                if (rank >= 4)
                {
                    // Only 4D unsqueezes of rank 3 or less are supported
                    return(true);
                }

                int axis = layer.pool[0];
                if (axis < 0)
                {
                    axis = rank + axis;
                }

                int[] shape8D    = input0Shape.Value.ToArray(); // 8D
                List <int> shape = new List <int>();
                shape.Add(shape8D[TensorShape.DataBatch]);
                if (rank > 1)
                {
                    shape.Add(shape8D[TensorShape.H]); // C in NCHW
                }
                if (rank > 2)
                {
                    shape.Add(shape8D[TensorShape.W]); // H in NCHW
                }
                shape.Insert(axis, 1);
                shape.AddRange(Enumerable.Repeat(1, 4 - shape.Count));

                net.Reshape(layer.name, input0, shape.ToArray());

                return(false);
            });
            rewriters.Add(Layer.Type.Squeeze, (layer, net) =>
            {
                if (layer.pool.Length > 1)
                {
                    // Multiple axes unsupported; leave layer as-is
                    return(true);
                }

                string input0 = layer.inputs[0];

                // Replace w/ a Transpose since Barracuda tensors are full rank
                if (!ranksByName.TryGetValue(input0, out int?input0Rank) || !input0Rank.HasValue)
                {
                    throw new Exception($"Must have input rank for {input0} for Squeeze");
                }

                int rank = input0Rank.Value;
                int axis = layer.pool[0];
                if (axis < 0)
                {
                    axis = rank + axis;
                }

                var transpose = SqueezeAxisPermutation(rank, axis);
                net.Transpose(layer.name, input0, transpose);

                return(false);
            });
            rewriters.Add(Layer.Type.NonMaxSuppression, (layer, net) =>
            {
                string boxes  = layer.inputs[0];
                string scores = layer.inputs[1];

                Layer boxesTransposed  = net.Transpose($"Transpose_For_{boxes}", boxes, k_FromNCHtoN1WC);
                Layer scoresTransposed = net.Transpose($"Transpose_For_{scores}", scores, k_FromNCHtoN1WC);

                // Most of the layer stays intact
                string originalLayerName = layer.name;
                layer.name      = $"{layer.name}_NHWC";
                layer.inputs[0] = boxesTransposed.name;
                layer.inputs[1] = scoresTransposed.name;
                net.model.layers.Add(layer);

                net.Transpose(originalLayerName, layer.name, k_ToNCHW);

                return(false);
            });
            rewriters.Add(Layer.Type.Activation, (layer, net) =>
            {
                return(true);
            });
            // Pad
            rewriters.Add(Layer.Type.Border2D, TransposeInput0);
            rewriters.Add(Layer.Type.Pad2DReflect, TransposeInput0);
            rewriters.Add(Layer.Type.Pad2DEdge, TransposeInput0);

            rewriters.Add(Layer.Type.GlobalAvgPool2D, TransposeInput0);
            rewriters.Add(Layer.Type.GlobalMaxPool2D, TransposeInput0);

            // Upsample
            rewriters.Add(Layer.Type.Upsample2D, (layer, net) =>
            {
                if (layer.inputs.Length > 1)
                {
                    return(TransposeInput01(layer, net)); // Upsample usage
                }
                else
                {
                    return(TransposeInput0(layer, net)); // Resize usage
                }
            });
            rewriters.Add(Layer.Type.Upsample3D, TransposeInput01); // Upsample usage
            rewriters.Add(Layer.Type.AvgPool2D, TransposeInput0);   // ModelBuilder: Resize2D

            // Resize: could be Resample2D, AvgPool2D, or Upsample2D
            rewriters.Add(Layer.Type.Resample2D, TransposeInput0);

            // Gemm
            rewriters.Add(Layer.Type.Dense, TransposeInput0);
            rewriters.Add(Layer.Type.MatMul, TransposeInput01UsingRank);

            // Conv
            rewriters.Add(Layer.Type.DepthwiseConv2D, Transpose0UsingRank);
            rewriters.Add(Layer.Type.Conv2D, Transpose0UsingRank);
            rewriters.Add(Layer.Type.Conv3D, Transpose0UsingRank);
            rewriters.Add(Layer.Type.Conv2DTrans, Transpose0UsingRank);

            // BatchNormalization
            rewriters.Add(Layer.Type.ScaleBias, Transpose0UsingRank);

            // InstanceNormalization
            rewriters.Add(Layer.Type.Normalization, Transpose0UsingRank);

            // broadcastable ops
            rewriters.Add(Layer.Type.Add, TransposeForBroadcast);
            rewriters.Add(Layer.Type.Mul, TransposeForBroadcast);
            rewriters.Add(Layer.Type.Sub, TransposeForBroadcast);
            rewriters.Add(Layer.Type.Div, TransposeForBroadcast);


            rewriters.Add(Layer.Type.StridedSlice, SliceToBarracuda);
            rewriters.Add(Layer.Type.Gather, GatherToBarracuda);
            rewriters.Add(Layer.Type.Concat, AxisToBarracuda);
            rewriters.Add(Layer.Type.Tile, ShapeToBarracuda);
            rewriters.Add(Layer.Type.Reshape, ShapeToBarracuda);
            rewriters.Add(Layer.Type.Transpose, TransposeToBarracuda);
            rewriters.Add(Layer.Type.Expand, (layer, net) =>
            {
                string input0          = layer.inputs[0];
                Model.Input input0Info = net.model.inputs.First(i => i.name == layer.inputs[0]);

                var rank0 = input0Info.rank;
                var size  = layer.pool.ToList();

                if (rank0 >= size.Count)
                {
                    for (int i = 0; i < rank0 - size.Count; i++)
                    {
                        size.Insert(0, 1);
                    }
                    layer.pool = size.ToArray();
                    return(ShapeToBarracuda(layer, net));
                }

                // inputShape needs to be unsqueezed
                var transpose       = RankChangePermutationBarracuda(rank0, size.Count);
                Layer nchwTranspose = net.Transpose($"Transpose_{input0}_For_{layer.name}", input0, transpose);

                ShapeToBarracuda(layer, net);

                net.Expand(layer.name, nchwTranspose, layer.pool);

                return(false);
            });
            rewriters.Add(Layer.Type.OneHot, (layer, net) =>
            {
                string input0          = layer.inputs[0];
                Model.Input input0Info = net.model.inputs.First(i => i.name == layer.inputs[0]);

                Layer input0Transposed = net.Transpose($"Transpose_For_{input0}", input0, k_ToNHWC);

                // Most of the layer stays intact
                string originalLayerName = layer.name;
                layer.name      = $"{layer.name}_NHWC";
                layer.inputs[0] = input0Transposed.name;
                net.model.layers.Add(layer);

                // OneHot outputRank = inputRank + 1
                net.Transpose(originalLayerName, layer.name, input0Info.rank == 2 ? k_FromN1WCtoNCH : k_ToNCHW);

                return(false);
            });

            // Reduce
            rewriters.Add(Layer.Type.ReduceL1, AxisToBarracuda);
            rewriters.Add(Layer.Type.ReduceL2, AxisToBarracuda);
            rewriters.Add(Layer.Type.ReduceMax, AxisToBarracuda);
            rewriters.Add(Layer.Type.ReduceMean, AxisToBarracuda);
            rewriters.Add(Layer.Type.ReduceMin, AxisToBarracuda);
            rewriters.Add(Layer.Type.ReduceProd, AxisToBarracuda);
            rewriters.Add(Layer.Type.ReduceSum, AxisToBarracuda);
            rewriters.Add(Layer.Type.ReduceLogSum, AxisToBarracuda);
            rewriters.Add(Layer.Type.ReduceSumSquare, AxisToBarracuda);
            rewriters.Add(Layer.Type.ReduceLogSumExp, AxisToBarracuda);

            foreach (var l in model.layers)
            {
                if (!rewriters.TryGetValue(l.type, out Func <Layer, ModelBuilder, bool> rw) || rw(l, modelBuilder))
                {
                    nchw.layers.Add(l);
                }
            }

            model = nchw;
        }
Example #10
0
        /// <summary>
        /// Rewrites the loaded model's layer list for run-time use, creates the Barracuda worker and
        /// executes it once with an input tensor built from the (patched) second model input.
        /// </summary>
        /// <remarks>
        /// Steps, in order: drop layers named "Style_Prediction_Network/...", patch Upsample2D parameters,
        /// remove "reflect_padding" layers (forwarding their inputs and pad to the following layer), give
        /// every Normalization layer scale/bias weights, drop StridedSlice layers, fold ReLU activations
        /// into the preceding Normalization layer, rewire the first convolution to the second model input,
        /// strip model-specific normalisation layers and the five trailing post-processing layers, then
        /// create and prime the worker.
        /// Returns immediately when <c>_styleTexture</c> or <c>_predictionAlphasBetasData</c> is null.
        /// </remarks>
        private void CreateBarracudaWorker()
        {
            if (_styleTexture == null || _predictionAlphasBetasData == null)
            {
                return;
            }

            // Index of the next alpha entry in _predictionAlphasBetasData; the matching beta is at index + 1.
            int savedAlphaBetasIndex = 0;

            _layerNameToPatch = new List <string>();
            List <Layer> layerList = new List <Layer>(_model.layers);

            // Pre-process Network for run-time use
            Layer lastConv = null;

            for (int i = 0; i < layerList.Count; i++)
            {
                Layer layer = layerList[i];

                // Remove Style_Prediction_Network: constant with style, executed once in Setup()
                if (layer.name.Contains("Style_Prediction_Network/"))
                {
                    layerList.RemoveAt(i);
                    i--; // revisit the same index, which now holds the next layer
                    continue;
                }

                // Fix Upsample2D size parameters
                if (layer.type == Layer.Type.Upsample2D)
                {
                    layer.pool = new[] { 2, 2 };
                    // The reference model is supposed to use nearest sampling, but bilinear scales better
                    // when the network is applied at lower resolutions.
                    bool useBilinearUpsampling = forceBilinearUpsample2DInModel.value || (modelType.value != ModelType.Reference);
                    layer.axis = useBilinearUpsampling ? 1 : -1;
                }

                // Remove Mirror padding layers (not supported, TODO)
                if (layer.name.Contains("reflect_padding"))
                {
                    // The following layer takes over the padding layer's inputs and pad values.
                    layerList[i + 1].inputs = layer.inputs;
                    layerList[i + 1].pad    = layer.pad.ToArray();
                    layerList.RemoveAt(i);
                    i--;
                }
                else if (layer.type == Layer.Type.Conv2D || layer.type == Layer.Type.Conv2DTrans)
                {
                    lastConv = layer;
                }
                else if (layer.type == Layer.Type.Normalization)
                {
                    // Manually set alpha/betas from Style_Prediction_Network as scale/bias tensors for InstanceNormalization
                    if (layerList[i - 1].type == Layer.Type.StridedSlice)
                    {
                        int channels = _predictionAlphasBetasData[savedAlphaBetasIndex].Length;
                        layer.datasets = CreateScaleBiasDataSets(channels);

                        _layerNameToPatch.Add(layer.name);

                        // First half: entry [index] (scale); second half: entry [index + 1] (bias).
                        float[] data = new float[channels * 2];
                        for (int j = 0; j < channels; j++)
                        {
                            data[j] = _predictionAlphasBetasData[savedAlphaBetasIndex][j];
                        }
                        for (int j = 0; j < channels; j++)
                        {
                            data[channels + j] = _predictionAlphasBetasData[savedAlphaBetasIndex + 1][j];
                        }

                        layer.weights = new BarracudaArrayFromManagedArray(data);

                        savedAlphaBetasIndex += 2;
                    }
                    // Else initialize scale/bias tensors of InstanceNormalization to default 1/0
                    else
                    {
                        // Channel count is taken from the second dataset of the most recent convolution.
                        int channels = lastConv.datasets[1].shape.channels;
                        layer.datasets = CreateScaleBiasDataSets(channels);

                        float[] data = new float[channels * 2];
                        for (int j = 0; j < channels; j++)
                        {
                            data[j] = 1.0f;
                        }
                        for (int j = channels; j < data.Length; j++)
                        {
                            data[j] = 0.0f;
                        }
                        layer.weights = new BarracudaArrayFromManagedArray(data);
                    }
                }
            }

            // Remove Slice layers originally used to get alpha/beta tensors into Style_Network
            for (int i = 0; i < layerList.Count; i++)
            {
                Layer layer = layerList[i];
                if (layer.type == Layer.Type.StridedSlice)
                {
                    layerList.RemoveAt(i);
                    i--;
                }
            }

            // Fold Relu into instance normalisation
            // Maps the name of each removed ReLU layer to the Normalization layer that absorbed it.
            Dictionary <string, string> reluToInstNorm = new Dictionary <string, string>();

            for (int i = 0; i < layerList.Count; i++)
            {
                Layer layer = layerList[i];
                if (layer.type == Layer.Type.Activation && layer.activation == Layer.Activation.Relu)
                {
                    // A ReLU at index 0 has no predecessor to fold into; leave it in place.
                    if (i > 0 && layerList[i - 1].type == Layer.Type.Normalization)
                    {
                        layerList[i - 1].activation = layer.activation;
                        reluToInstNorm[layer.name]  = layerList[i - 1].name;
                        layerList.RemoveAt(i);
                        i--;
                    }
                }
            }

            // Re-point every consumer of a removed ReLU at the Normalization layer that replaced it.
            for (int i = 0; i < layerList.Count; i++)
            {
                Layer layer = layerList[i];
                for (int j = 0; j < layer.inputs.Length; j++)
                {
                    if (reluToInstNorm.ContainsKey(layer.inputs[j]))
                    {
                        layer.inputs[j] = reluToInstNorm[layer.inputs[j]];
                    }
                }
            }

            // Feed first convolution directly with input (no need for normalisation from the model)
            string firstConvName = "StyleNetwork/conv1/convolution_conv1/convolution";
            int    firstConv     = FindLayerIndexByName(layerList, firstConvName);

            layerList[firstConv].inputs = new[] { _model.inputs[1].name };

            if (modelType.value == ModelType.Reference)
            {
                layerList.RemoveAt(FindLayerIndexByName(layerList, "StyleNetwork/normalisation/add"));
                layerList.RemoveAt(FindLayerIndexByName(layerList, "StyleNetwork/normalisation/add/y"));
                layerList.RemoveAt(FindLayerIndexByName(layerList, "StyleNetwork/normalisation/normalized_contentFrames"));
                layerList.RemoveAt(FindLayerIndexByName(layerList, "StyleNetwork/normalisation/normalized_contentFrames/y"));
                layerList.RemoveAt(FindLayerIndexByName(layerList, "StyleNetwork/normalisation/sub"));
                layerList.RemoveAt(FindLayerIndexByName(layerList, "StyleNetwork/normalisation/sub/y"));
            }
            if (modelType.value == ModelType.RefBut32Channels)
            {
                layerList.RemoveAt(FindLayerIndexByName(layerList, "StyleNetwork/normalized_contentFrames"));
                layerList.RemoveAt(FindLayerIndexByName(layerList, "StyleNetwork/normalized_contentFrames/y"));
            }

            // Remove final model post processing, post process happen in tensor to texture instead
            int postAdd = FindLayerIndexByName(layerList, "StyleNetwork/clamp_0_255/add");

            layerList.RemoveRange(postAdd, 5);

            // Correct wrong output layer list: the layer just before the removed range becomes the only output.
            _model.outputs = new List <string>()
            {
                layerList[postAdd - 1].name
            };

            _model.layers = layerList;

            // Patch the second model input's shape with the render target size and keep it as the only input.
            Model.Input input = _model.inputs[1];
            input.shape[0] = 0;
            input.shape[1] = _rtHandle.rt.height;
            input.shape[2] = _rtHandle.rt.width;
            input.shape[3] = 3;
            _model.inputs  = new List <Model.Input> {
                _model.inputs[1]
            };

            //Create worker and execute it once at target resolution to prime all memory allocation (however in editor resolution can still change at runtime)
            _worker = WorkerFactory.CreateWorker(WorkerFactory.ValidateType(workerType.value), _model, debugModelLoading.value);
            Dictionary <string, Tensor> temp = new Dictionary <string, Tensor>();
            var inputTensor = new Tensor(input.shape, input.name);

            try
            {
                temp.Add("frame", inputTensor);
                _worker.Execute(temp);
            }
            finally
            {
                // Release the priming tensor even when execution throws.
                inputTensor.Dispose();
            }

            Debug.Log("Style Transfer Model: \n" + _model.ToString());
        }

        // Builds the two dataset descriptors (scale at offset 0, bias at offset 'channels'), each shaped
        // (1, 1, 1, channels), used for the weights of a Normalization layer.
        private static Layer.DataSet[] CreateScaleBiasDataSets(int channels)
        {
            var dataSets = new Layer.DataSet[2];

            dataSets[0].shape  = new TensorShape(1, 1, 1, channels);
            dataSets[0].offset = 0;
            dataSets[0].length = channels;

            dataSets[1].shape  = new TensorShape(1, 1, 1, channels);
            dataSets[1].offset = channels;
            dataSets[1].length = channels;

            return dataSets;
        }
        // works on IRModel
        // Walks the model looking for transpose patterns around convolutions and records a channel order
        // for every model input and for every layer that has at least one input with a known order.
        // Returns true when either of the two searched patterns was found.
        public bool InferAllLayersChannelOrder(Model model, out Dictionary <string, ChannelsOrder> layerChannelOrder)
        {
            // TF2Onnx : pattern T (.* Conv .*) T-1
            // * being transpose commutative layer
            layerChannelOrder = new Dictionary <string, ChannelsOrder>();

            IDictionary <string, TensorShape?> shapesByName = new Dictionary <string, TensorShape?>();
            IDictionary <string, int?>         ranksByName  = new Dictionary <string, int?>();

            // Seed rank (always) and shape (only when acceptably known) for every model input.
            foreach (var modelInput in model.inputs)
            {
                ranksByName[modelInput.name] = modelInput.rank;
                if (ModelAnalyzer.IsInputShapeAcceptablyKnowForShapeInference(modelInput))
                {
                    shapesByName[modelInput.name] = new TensorShape(modelInput.shape);
                }
            }

            var shapeInference = new IRShapeInferenceAndConstantFusing();
            shapeInference.InferAllShapes(model, ref shapesByName, ref ranksByName);

            // Permutations compared against a transpose layer's 'pool'.
            int[] permutationToNCHW = { 0, 3, 1, 2 };
            int[] permutationToNHWC = { 0, 2, 3, 1 };

            // Pattern 1: T ... Conv ... T-1
            bool sawOpeningTranspose  = false;
            bool sawConvAfterOpening  = false;
            bool foundNHWCPattern     = false;

            // Pattern 2 (tf to onnx does not swizzle axis): Conv ... T-1, with no opening transpose before the conv
            bool sawConvBeforeOpening = false;
            bool foundNCHWExportPattern = false;

            foreach (var layer in model.layers)
            {
                // NOTE: '!sawOpeningTranspose' only guards the transpose alternative; the reshape
                // alternative is evaluated regardless (this is the grouping C# precedence gives).
                bool opensPattern =
                    (!sawOpeningTranspose &&
                     IsLayerTranpose(layer) &&
                     Enumerable.SequenceEqual(layer.pool, permutationToNCHW))
                    ||
                    (IsLayerReshape(layer) &&
                     (shapesByName[layer.inputs[0]] != null) &&
                     IsReshapeTransposeToNCHW(layer, shapesByName[layer.inputs[0]].Value));

                if (opensPattern)
                {
                    sawOpeningTranspose = true;
                }
                else if (sawOpeningTranspose)
                {
                    if (sawConvAfterOpening &&
                        ((IsLayerTranpose(layer) && Enumerable.SequenceEqual(layer.pool, permutationToNHWC)) ||
                         (IsLayerReshape(layer) && (shapesByName[layer.inputs[0]] != null) && IsReshapeTransposeToNHWC(layer, shapesByName[layer.inputs[0]].Value)) ||
                         (IsLayerSqueeze(layer) && (ranksByName[layer.inputs[0]] != null) && IsSqueezeTransposeToNHWC(layer, ranksByName[layer.inputs[0]].Value)) ||
                         (IsLayerFlatten(layer) && (ranksByName[layer.inputs[0]] != null) && IsFlattenTransposeToNHWC(layer, ranksByName[layer.inputs[0]].Value))))
                    {
                        foundNHWCPattern = true;
                    }
                    else if (IsLayerConv(layer))
                    {
                        sawConvAfterOpening = true;
                    }
                }

                // Second pattern is tracked independently, using the flags as updated above.
                if (!foundNCHWExportPattern && sawConvBeforeOpening &&
                    ((IsLayerTranpose(layer) && Enumerable.SequenceEqual(layer.pool, permutationToNHWC)) ||
                     (IsLayerReshape(layer) && (shapesByName[layer.inputs[0]] != null) && IsReshapeTransposeToNHWC(layer, shapesByName[layer.inputs[0]].Value))))
                {
                    foundNCHWExportPattern = true;
                }
                else if (!sawConvBeforeOpening && !sawOpeningTranspose && IsLayerConv(layer))
                {
                    sawConvBeforeOpening = true;
                }
            }

            // flag each layer as being NHWC or NCHW
            // Inputs are NHWC only when pattern 1 matched and pattern 2 did not.
            ChannelsOrder modelInputOrder = (!foundNCHWExportPattern && foundNHWCPattern)
                ? ChannelsOrder.NHWC
                : ChannelsOrder.NCHW;

            foreach (var modelInput in model.inputs)
            {
                layerChannelOrder[modelInput.name] = modelInputOrder;
            }

            foreach (var layer in model.layers)
            {
                // NOTE: the channel-order comparison only guards the reshape alternative; a transpose
                // with the matching permutation is tagged whatever its input's order is.
                bool transposesToNCHW =
                    (IsLayerTranpose(layer) && Enumerable.SequenceEqual(layer.pool, permutationToNCHW))
                    ||
                    (IsLayerReshape(layer) &&
                     (shapesByName[layer.inputs[0]] != null) &&
                     IsReshapeTransposeToNCHW(layer, shapesByName[layer.inputs[0]].Value) &&
                     layerChannelOrder[layer.inputs[0]] == ChannelsOrder.NHWC);

                if (transposesToNCHW)
                {
                    layerChannelOrder[layer.name] = ChannelsOrder.TransposeToNCHW;
                    continue;
                }

                bool transposesToNHWC =
                    (IsLayerTranpose(layer) && Enumerable.SequenceEqual(layer.pool, permutationToNHWC))
                    ||
                    (IsLayerReshape(layer) &&
                     (shapesByName[layer.inputs[0]] != null) &&
                     IsReshapeTransposeToNHWC(layer, shapesByName[layer.inputs[0]].Value) &&
                     layerChannelOrder[layer.inputs[0]] == ChannelsOrder.NCHW);

                if (transposesToNHWC)
                {
                    layerChannelOrder[layer.name] = ChannelsOrder.TransposeToNHWC;
                    continue;
                }

                // Take the order of the first input that already has one; skip the layer if none does.
                bool hasKnownInput = false;
                ChannelsOrder inheritedOrder = default(ChannelsOrder);
                foreach (var inputName in layer.inputs)
                {
                    if (layerChannelOrder.TryGetValue(inputName, out inheritedOrder))
                    {
                        hasKnownInput = true;
                        break;
                    }
                }

                if (!hasKnownInput)
                {
                    continue;
                }

                // A transposing producer hands over data in its target order.
                if (inheritedOrder == ChannelsOrder.TransposeToNCHW)
                {
                    inheritedOrder = ChannelsOrder.NCHW;
                }
                else if (inheritedOrder == ChannelsOrder.TransposeToNHWC)
                {
                    inheritedOrder = ChannelsOrder.NHWC;
                }

                // all layers with unknown layout are const
                foreach (var inputName in layer.inputs)
                {
                    if (!layerChannelOrder.ContainsKey(inputName))
                    {
                        layerChannelOrder[inputName] = inheritedOrder;
                    }
                }

                layerChannelOrder[layer.name] = inheritedOrder;
            }

            // TODO Assert that all layers have a channel order
            // Assert that all layers are NHWC if inputsNHWC
            return foundNHWCPattern || foundNCHWExportPattern;
        }
        // Rewrites a slice layer's per-axis parameters (pad = starts, pool = ends, stride = steps,
        // axes = which input axes they apply to) into fixed 8-element pad/pool/stride arrays.
        // The rank of the sliced tensor is read from the model input whose name equals layer.inputs[0].
        // Always returns true.
        // NOTE(review): ranks above 6 fall into the default mapping and would index past the
        // 8-element arrays; presumably such ranks never reach this method - confirm with callers.
        bool StridedSlice(Layer layer, ModelBuilder net)
        {
            string input0 = layer.inputs[0];

            Model.Input input0Info = net.model.inputs.First(i => i.name == input0);

            var starts = layer.pad;
            var ends   = layer.pool;
            var steps  = layer.stride;
            var axes   = layer.axes;

            // Per-axis defaults select the full range of every axis.
            var onnxRank   = input0Info.rank;
            var onnxStarts = Enumerable.Repeat(0, onnxRank).ToArray();
            var onnxEnds   = Enumerable.Repeat(int.MaxValue, onnxRank).ToArray(); // by default copy the whole axis till the end
            var onnxSteps  = Enumerable.Repeat(1, onnxRank).ToArray();

            // NOTE: begin=0, end=0, stride=1  <=  full range from existing axis
            //       begin=0, end=inf,stride=1 <=  full range from existing axis
            //       begin=0, end=X, stride=1  <=  full range from existing axis, if X==last element on this axis
            //       begin=0, end=0, stride=0  <=  new axis OR shrink axis to single 1st element
            //       begin=N, end=N, stride=0  <=              shrink axis to single Nth element
            // These notes are copied from TensorExtensions.ApplyStridedSlice(...)

            for (int i = 0; i < axes.Length; ++i)
            {
                var axis = axes[i];
                if (axis < 0)
                {
                    axis += onnxRank; // negative axes count from the end
                }

                // Clamp to a valid index: the per-axis arrays hold onnxRank entries, so the largest
                // usable index is onnxRank - 1 (clamping to onnxRank would index out of range).
                axis = Math.Min(Math.Max(axis, 0), onnxRank - 1);

                onnxStarts[axis] = starts[i];
                onnxEnds[axis]   = ends[i];
                onnxSteps[axis]  = steps[i];
            }

            // 8-element defaults: start 0, end int.MaxValue, step 1 on every slot.
            layer.pad    = new[] { 0, 0, 0, 0, 0, 0, 0, 0 };
            layer.pool   = new[] { int.MaxValue, int.MaxValue, int.MaxValue, int.MaxValue, int.MaxValue, int.MaxValue, int.MaxValue, int.MaxValue };
            layer.stride = new[] { 1, 1, 1, 1, 1, 1, 1, 1 };

            // Axis 0 always lands in slot 2; the remaining axes are shifted by a rank-dependent
            // offset (2 for rank 6, 3 for rank 5, 4 otherwise).
            int trailingOffset;
            switch (onnxRank)
            {
            case 6:
                trailingOffset = 2;
                break;

            case 5:
                trailingOffset = 3;
                break;

            default:
                trailingOffset = 4;
                break;
            }

            for (int i = 0; i < onnxRank; i++)
            {
                int slot = i + (i == 0 ? 2 : trailingOffset);

                layer.pad[slot]    = onnxStarts[i];
                layer.pool[slot]   = onnxEnds[i];
                layer.stride[slot] = onnxSteps[i];
            }

            return(true);
        }