/// <summary>
/// Entry point of the NCHW-to-NHWC conversion for models whose <c>layout</c>
/// string contains "NCHW". Models without that marker are returned untouched.
/// </summary>
/// <param name="model">
/// Model to transform. Passed by reference so every sub-pass invoked here may
/// rewrite or replace it.
/// </param>
public void Run(ref Model model)
{
    // Only models that declare an NCHW layout are handled by this pass.
    if (!model.layout.Contains("NCHW"))
    {
        return;
    }

    // Models that carry the layout inside the graph itself (e.g. SSD) need this:
    // contracting the shape computation into a single layer keeps the Gather
    // pass from being converted.
    new ShapeContractionPass().Run(ref model);

    // Strip shape -> gather -> reshape chains that only map a transpose to NHWC.
    new ShapeGatherReshapeToNHWCRemovePass().Run(ref model);

    // This pass's own rewrite step.
    Rewrite(ref model);

    // Keep any newly introduced layers that must survive (e.g. new LSTM outputs).
    // TODO: outputs are preserved, adjust optimization passes to properly merge outputs by renaming layers
    new PreserveLayersPass().Run(ref model);

    // Cleanup: drop unused layers first, then no-op layers.
    new Cleanup.RemoveUnusedLayersPass().Run(ref model);
    new Cleanup.RemoveNoOpsPass().Run(ref model);
}
/// <summary>
/// Runs the conversion pipeline on <paramref name="model"/> in this order:
/// shape inference with constant fusing, optional NCHW-side optimizations and
/// cleanup, NCHW validation, conversion to NHWC, optional NHWC-side
/// optimizations, and finally NHWC validation.
/// </summary>
/// <remarks>
/// The optimization stages run only when <c>Optimize</c> is true. Warnings
/// reported by shape inference and both validation passes are gathered in a
/// local list and appended to <c>model.Warnings</c> at the end.
/// </remarks>
/// <param name="model">Model to convert; passed by reference so passes may rewrite it.</param>
public void Run(ref Model model)
{
    List<Model.ImporterWarning> collected = new List<Model.ImporterWarning>();

    // Stage 1: shape inference + constant fusing (may report warnings).
    new IRShapeInferenceAndConstantFusing().Run(ref model, collected);

    // Stage 2 (optional): fuse layers, then clean up what the fusing left behind.
    if (Optimize)
    {
        new Optimization.FuseLinearLayersPass().Run(ref model);
        new Optimization.FuseActivationPass().Run(ref model);

        new Cleanup.RemoveUnusedLayersPass().Run(ref model);
        new Cleanup.RemoveNoOpsPass().Run(ref model);
    }

    // Stage 3: validate the NCHW form.
    // TODO, put asserts in ImporterWarning?
    new ValidateNCHWPass().Run(model, ref collected);

    // Stage 4: convert to runnable NHWC.
    new NCHWToNHWCPass().Run(ref model);

    // Stage 5 (optional): NHWC-side simplifications.
    if (Optimize)
    {
        new Optimization.ContractToSimplerLayerPass().Run(ref model);
        new Optimization.ConcatenateTransposesPass().Run(ref model);
        new Optimization.FuseDense3Pass().Run(ref model);
    }

    // Stage 6: validate the NHWC result, then publish every collected warning.
    new ValidateNHWCPass().Run(model, ref collected);
    model.Warnings.AddRange(collected);
}