Example #1
0
        /// <summary>
        /// Applies the NCHW-specific rewrite pipeline to <paramref name="model"/>:
        /// shape contraction, shape-gather-reshape removal, <c>Rewrite</c>,
        /// layer preservation, and finally cleanup of unused and no-op layers.
        /// Does nothing unless the model's layout string contains "NCHW".
        /// </summary>
        /// <param name="model">Model to transform; it is passed by reference to every pass.</param>
        public void Run(ref Model model)
        {
            bool hasNchwLayout = model.layout.Contains("NCHW");
            if (!hasNchwLayout)
            {
                return;
            }

            // Required for NCHW models that bake the layout into the graph itself (e.g. SSD):
            // the shape computation is contracted into a single layer so that the Gather pass
            // does not get converted.
            new ShapeContractionPass().Run(ref model);

            // Drop shape -> gather -> reshape chains that only encode a transpose to NHWC.
            new ShapeGatherReshapeToNHWCRemovePass().Run(ref model);

            Rewrite(ref model);

            // Keep any newly introduced layers that must survive (e.g. new LSTM outputs).
            // TODO: outputs are preserved; adjust optimization passes to properly merge outputs by renaming layers
            new PreserveLayersPass().Run(ref model);

            // Cleanup: remove unused layers first, then no-op layers.
            new Cleanup.RemoveUnusedLayersPass().Run(ref model);
            new Cleanup.RemoveNoOpsPass().Run(ref model);
        }
        /// <summary>
        /// Runs the full conversion pipeline on <paramref name="model"/>: shape inference with
        /// constant fusing, optional NCHW optimizations plus cleanup, NCHW validation, conversion
        /// to NHWC, optional NHWC optimizations, and NHWC validation. Every warning collected
        /// along the way is appended to <c>model.Warnings</c> at the end.
        /// </summary>
        /// <param name="model">Model to transform; it is passed by reference to every mutating pass.</param>
        public void Run(ref Model model)
        {
            var warnings = new List<Model.ImporterWarning>();

            // Shape inference and constant fusing feed the passes below.
            new IRShapeInferenceAndConstantFusing().Run(ref model, warnings);

            if (Optimize)
            {
                // NCHW-side fusing.
                new Optimization.FuseLinearLayersPass().Run(ref model);
                new Optimization.FuseActivationPass().Run(ref model);

                // Remove layers left unused or turned into no-ops by fusing.
                new Cleanup.RemoveUnusedLayersPass().Run(ref model);
                new Cleanup.RemoveNoOpsPass().Run(ref model);
            }

            // TODO: put asserts in ImporterWarning?
            new ValidateNCHWPass().Run(model, ref warnings);

            // Convert to a runnable NHWC layout.
            new NCHWToNHWCPass().Run(ref model);

            if (Optimize)
            {
                // NHWC-side optimizations.
                new Optimization.ContractToSimplerLayerPass().Run(ref model);
                new Optimization.ConcatenateTransposesPass().Run(ref model);
                new Optimization.FuseDense3Pass().Run(ref model);
            }

            new ValidateNHWCPass().Run(model, ref warnings);

            model.Warnings.AddRange(warnings);
        }