/// <summary>
/// Runs the import pipeline that turns the intermediate model into its runnable NHWC form.
/// The order is:
/// shape inference and constant fusing, then optional fusing/cleanup passes, then NCHW validation,
/// then the NCHW-to-NHWC conversion, then optional NHWC-side optimizations, then NHWC validation.
/// Every warning produced along the way is appended to <c>model.Warnings</c> at the end.
/// </summary>
/// <param name="model">Model to transform in place; passes may replace the instance.</param>
public void Run(ref Model model)
{
    var collectedWarnings = new List<Model.ImporterWarning>();

    // Shape inference + constant fusing reports into the shared warning list.
    new IRShapeInferenceAndConstantFusing().Run(ref model, collectedWarnings);

    if (Optimize)
    {
        // Optimization: fuse linear layers, then activations.
        new Optimization.FuseLinearLayersPass().Run(ref model);
        new Optimization.FuseActivationPass().Run(ref model);

        // Cleanup: drop unused layers, then no-op layers.
        new Cleanup.RemoveUnusedLayersPass().Run(ref model);
        new Cleanup.RemoveNoOpsPass().Run(ref model);
    }

    // TODO, put asserts in ImporterWarning?
    new ValidateNCHWPass().Run(model, ref collectedWarnings);

    // to runnable NHWC
    new NCHWToNHWCPass().Run(ref model);

    if (Optimize)
    {
        // Optimizations applied after the layout conversion.
        new Optimization.ContractToSimplerLayerPass().Run(ref model);
        new Optimization.ConcatenateTransposesPass().Run(ref model);
        new Optimization.FuseDense3Pass().Run(ref model);
    }

    new ValidateNHWCPass().Run(model, ref collectedWarnings);

    model.Warnings.AddRange(collectedWarnings);
}
/// <summary>
/// Checks that static rank and shape inference can resolve every input and layer of
/// <paramref name="model"/>, appending a warning for each tensor they cannot resolve.
/// Inference runs on <c>model.ShallowCopy()</c>; that copy's inputs get a non-positive batch
/// dimension forced to 1 before inference.
/// </summary>
/// <param name="model">Model to validate.</param>
/// <param name="warnings">List that validation warnings are appended to via <c>ValidationHelper.AppendWarning</c>.</param>
public void Run(Model model, ref List<Model.ImporterWarning> warnings)
{
    var modelTemp = model.ShallowCopy();
    IDictionary<string, TensorShape> inputShapes = new Dictionary<string, TensorShape>();

    // force batch to 1
    for (int i = 0; i < modelTemp.inputs.Count; i++)
    {
        var input = modelTemp.inputs[i];
        var shape = input.shape.ToArray();
        if (shape[TensorShape.DataBatch] <= 0)
        {
            shape[TensorShape.DataBatch] = 1;
        }
        input.shape = shape;
        modelTemp.inputs[i] = input;

        // Only inputs accepted by ModelAnalyzer seed the shape inference below.
        if (!ModelAnalyzer.IsInputShapeAcceptablyKnowForShapeInference(input))
        {
            continue;
        }

        inputShapes[input.name] = new TensorShape(input.shape);
    }
    ValidationHelper.AppendWarning(inputShapes.Count == modelTemp.inputs.Count, "model", "Input Shape: unknown non batch dimension", ref warnings);

    var shapeInferencePass = new IRShapeInferenceAndConstantFusing();
    shapeInferencePass.Run(ref modelTemp);

    IDictionary<string, int?> ranksByName;
    IRShapeInferenceHelper.RankInference.ListTemporaryTensorRanks(modelTemp, out ranksByName);
    IDictionary<string, TensorShape?> shapesByName;
    IRShapeInferenceHelper.ShapeInference.ListTemporaryTensorShapesNCHW(modelTemp, inputShapes, ranksByName, out shapesByName);

    // Null ranks compare false against 0, so only known negative ranks are counted.
    int negativeRanks = ranksByName.Values.Count(x => x < 0);
    ValidationHelper.AppendWarning(negativeRanks == 0, "model", $"StaticRankInference: {negativeRanks} negative rank(s) found!", ref warnings, MessageType.Warning);

    int knownRanks = ranksByName.Count(x => x.Value != null);
    int knownShapes = shapesByName.Count(x => x.Value != null);
    ValidationHelper.AppendWarning(knownRanks == knownShapes, "model", "StaticShape/RankInference: known ranks # != known shape #", ref warnings);

    foreach (var input in modelTemp.inputs)
    {
        AppendInferenceWarnings(input.name, "input", ref warnings);
    }
    foreach (var layer in modelTemp.layers)
    {
        AppendInferenceWarnings(layer.name, "layer", ref warnings);
    }

    // Appends, for one named tensor, warnings when it is missing from the rank/shape results
    // or present with an unknown (null) rank/shape. `kind` is "input" or "layer" and only
    // affects message text. The list is passed by ref explicitly because a ref parameter
    // cannot be captured by a local function.
    void AppendInferenceWarnings(string tensorName, string kind, ref List<Model.ImporterWarning> sink)
    {
        int? rank;
        bool hasRank = ranksByName.TryGetValue(tensorName, out rank);
        ValidationHelper.AppendWarning(hasRank, tensorName, $"StaticRankInference: did not find {kind}", ref sink);
        if (hasRank)
        {
            ValidationHelper.AppendWarning(rank != null, tensorName, $"StaticRankInference: unknown {kind} rank at compile time", ref sink);
        }

        TensorShape? tensorShape;
        bool hasShape = shapesByName.TryGetValue(tensorName, out tensorShape);
        ValidationHelper.AppendWarning(hasShape, tensorName, $"StaticShapeInference: did not find {kind}", ref sink);
        if (hasShape)
        {
            ValidationHelper.AppendWarning(tensorShape != null, tensorName, $"StaticShapeInference: unknown {kind} shape at compile time", ref sink);
        }
    }
}