/// <summary>
/// Runs the import pipeline over <paramref name="model"/>. The stages are:
/// IR shape inference and constant fusing; optional fusing and cleanup passes; NCHW validation;
/// conversion to NHWC; optional NHWC-side optimizations; and NHWC validation.
/// Every warning collected along the way is appended to <c>model.Warnings</c> at the end.
/// </summary>
/// <param name="model">Model to process. Each pass receives it by <c>ref</c> and may replace it.</param>
public void Run(ref Model model)
        {
            var collectedWarnings = new List<Model.ImporterWarning>();

            new IRShapeInferenceAndConstantFusing().Run(ref model, collectedWarnings);

            if (Optimize)
            {
                // Optimization: fuse linear layers, then fuse activations.
                new Optimization.FuseLinearLayersPass().Run(ref model);
                new Optimization.FuseActivationPass().Run(ref model);

                // Cleanup: remove unused layers, then no-op layers.
                new Cleanup.RemoveUnusedLayersPass().Run(ref model);
                new Cleanup.RemoveNoOpsPass().Run(ref model);
            }

            // TODO, put asserts in ImporterWarning?
            new ValidateNCHWPass().Run(model, ref collectedWarnings);

            // Convert the model to a runnable NHWC layout.
            new NCHWToNHWCPass().Run(ref model);

            if (Optimize)
            {
                // Optimizations that run after the NHWC conversion.
                new Optimization.ContractToSimplerLayerPass().Run(ref model);
                new Optimization.ConcatenateTransposesPass().Run(ref model);
                new Optimization.FuseDense3Pass().Run(ref model);
            }

            new ValidateNHWCPass().Run(model, ref collectedWarnings);

            // 'model' may have been replaced by a pass, so this targets the final model instance.
            model.Warnings.AddRange(collectedWarnings);
        }
Example #2
0
        /// <summary>
        /// Validates an NCHW model by running static rank and shape inference on a shallow copy of it.
        /// Warnings are added to <paramref name="warnings"/> for every input or layer whose rank or
        /// shape cannot be resolved at compile time.
        /// </summary>
        /// <param name="model">
        /// Model to validate. Inference runs on <c>model.ShallowCopy()</c>. Whether that fully isolates
        /// <paramref name="model"/> from the edits below depends on ShallowCopy (TODO confirm).
        /// </param>
        /// <param name="warnings">Warning list that validation results are appended to.</param>
        public void Run(Model model, ref List<Model.ImporterWarning> warnings)
        {
            var modelTemp = model.ShallowCopy();
            IDictionary<string, TensorShape> inputShapes = new Dictionary<string, TensorShape>();

            // Force an unset (<= 0) batch dimension to 1 so it does not block shape inference.
            for (int i = 0; i < modelTemp.inputs.Count; i++)
            {
                var input = modelTemp.inputs[i];

                // ToArray() yields a new array, so the shape array read from the input is not mutated.
                var shape = input.shape.ToArray();
                if (shape[TensorShape.DataBatch] <= 0)
                {
                    shape[TensorShape.DataBatch] = 1;
                }
                input.shape         = shape;
                modelTemp.inputs[i] = input;

                // Inputs the analyzer rejects are not seeded into shape inference.
                // The warning below reports that case as an unknown non-batch dimension.
                if (!ModelAnalyzer.IsInputShapeAcceptablyKnowForShapeInference(input))
                {
                    continue;
                }

                inputShapes[input.name] = new TensorShape(input.shape);
            }

            ValidationHelper.AppendWarning(inputShapes.Count == modelTemp.inputs.Count, "model", "Input Shape: unknown non batch dimension", ref warnings);

            IRShapeInferenceAndConstantFusing shapeInferencePass = new IRShapeInferenceAndConstantFusing();
            shapeInferencePass.Run(ref modelTemp);

            IDictionary<string, int?> ranksByName;
            IRShapeInferenceHelper.RankInference.ListTemporaryTensorRanks(modelTemp, out ranksByName);

            IDictionary<string, TensorShape?> shapesByName;
            IRShapeInferenceHelper.ShapeInference.ListTemporaryTensorShapesNCHW(modelTemp, inputShapes, ranksByName, out shapesByName);

            // Null (unknown) ranks are not counted here: a lifted 'null < 0' comparison is false.
            int negativeRanks = ranksByName.Values.Count(x => x < 0);
            ValidationHelper.AppendWarning(negativeRanks == 0, "model", $"StaticRankInference: {negativeRanks} negative rank(s) found!", ref warnings, MessageType.Warning);

            int knownRanks  = ranksByName.Count(x => x.Value != null);
            int knownShapes = shapesByName.Count(x => x.Value != null);
            ValidationHelper.AppendWarning(knownRanks == knownShapes, "model", "StaticShape/RankInference: known ranks # != known shape #", ref warnings);

            foreach (var modelInput in modelTemp.inputs)
            {
                AppendRankAndShapeWarnings(modelInput.name, "input", ranksByName, shapesByName, ref warnings);
            }
            foreach (var layer in modelTemp.layers)
            {
                AppendRankAndShapeWarnings(layer.name, "layer", ranksByName, shapesByName, ref warnings);
            }
        }

        /// <summary>
        /// Checks one tensor against the inferred rank and shape maps. For each map, it checks that
        /// <paramref name="name"/> is present and, if so, that its value is non-null (known at compile time).
        /// ValidationHelper.AppendWarning is assumed to record the message when its condition is false.
        /// </summary>
        /// <param name="name">Tensor (input or layer) name to look up.</param>
        /// <param name="kind">"input" or "layer"; inserted into the warning messages.</param>
        /// <param name="ranksByName">Inferred ranks, keyed by tensor name.</param>
        /// <param name="shapesByName">Inferred shapes, keyed by tensor name.</param>
        /// <param name="warnings">Warning list that results are appended to.</param>
        private static void AppendRankAndShapeWarnings(
            string name,
            string kind,
            IDictionary<string, int?> ranksByName,
            IDictionary<string, TensorShape?> shapesByName,
            ref List<Model.ImporterWarning> warnings)
        {
            int? rank;
            bool hasRank = ranksByName.TryGetValue(name, out rank);
            ValidationHelper.AppendWarning(hasRank, name, $"StaticRankInference: did not find {kind}", ref warnings);
            if (hasRank)
            {
                ValidationHelper.AppendWarning(rank != null, name, $"StaticRankInference: unknown {kind} rank at compile time", ref warnings);
            }

            TensorShape? shape;
            bool hasShape = shapesByName.TryGetValue(name, out shape);
            ValidationHelper.AppendWarning(hasShape, name, $"StaticShapeInference: did not find {kind}", ref warnings);
            if (hasShape)
            {
                ValidationHelper.AppendWarning(shape != null, name, $"StaticShapeInference: unknown {kind} shape at compile time", ref warnings);
            }
        }