/// <summary>
/// Infers a channel order for every name reachable from the layers that are necessarily NCHW
/// in ONNX (works on an IRModel).
/// </summary>
/// <remarks>
/// Shape inference is run first, seeded from the model inputs. Then a flood fill propagates layout:
/// <list type="bullet">
/// <item>Layers for which <c>IsLayerNecessarilyNCHWOnnx</c> holds are seeded as <see cref="ChannelsOrder.NativeNCHW"/>.</item>
/// <item>Each visited layer forwards its order upstream (to <c>layer.inputs</c>) and downstream (to the layers returned by
/// <c>ModelAnalyzer.FindLayerOutputs</c>).</item>
/// <item>A layer recognised as a layout change (T = to NHWC, T-1 = to NCHW) is marked as a transpose, and the order
/// seen on its other side is flipped.</item>
/// <item>Propagation into a name stops when it brings the same order again, when the name is already native NCHW
/// (native has priority), or when the name is already NHWC and the incoming order is not native NCHW
/// (heuristic against ping-pong, since NHWC implies fewer transposes).</item>
/// </list>
/// </remarks>
/// <param name="model">The model to analyse.</param>
/// <param name="layerChannelOrder">
/// Receives the inferred order for each name the propagation visited. Entries recorded as
/// <see cref="ChannelsOrder.NativeNCHW"/> are reported as <see cref="ChannelsOrder.NCHW"/>.
/// Names never reached by the propagation have no entry.
/// </param>
/// <returns><c>true</c> when at least one entry ended up as <see cref="ChannelsOrder.NHWC"/>.</returns>
public bool InferAllLayersChannelOrder(Model model, out Dictionary<string, ChannelsOrder> layerChannelOrder)
{
    layerChannelOrder = new Dictionary<string, ChannelsOrder>();

    // Seed shape inference from the model inputs: every input contributes its rank,
    // but only inputs whose shape is known well enough contribute a shape.
    IDictionary<string, TensorShape?> shapesByName = new Dictionary<string, TensorShape?>();
    IDictionary<string, int?> ranksByName = new Dictionary<string, int?>();
    foreach (var modelInput in model.inputs)
    {
        ranksByName[modelInput.name] = modelInput.rank;
        if (!ModelAnalyzer.IsInputShapeAcceptablyKnowForShapeInference(modelInput))
        {
            continue;
        }
        shapesByName[modelInput.name] = new TensorShape(modelInput.shape);
    }

    var shapeInferencePass = new IRShapeInferenceAndConstantFusing();
    shapeInferencePass.InferAllShapes(model, ref shapesByName, ref ranksByName);

    // Work queue of (name, order being propagated into it, direction it arrived from).
    var layersToInferLayout = new Queue<(string, ChannelsOrder, FlowDirection)>();
    foreach (var seedLayer in model.layers)
    {
        if (IsLayerNecessarilyNCHWOnnx(seedLayer))
        {
            layersToInferLayout.Enqueue((seedLayer.name, ChannelsOrder.NativeNCHW, FlowDirection.Seed));
        }
    }

    while (layersToInferLayout.Count > 0)
    {
        var (name, deducedChannelOrder, flowDirection) = layersToInferLayout.Dequeue();

        if (!layerChannelOrder.TryGetValue(name, out ChannelsOrder previousOrder))
        {
            layerChannelOrder[name] = deducedChannelOrder;
        }
        else if (deducedChannelOrder == previousOrder)
        {
            // Same order already propagated through this name.
            continue;
        }
        else if (previousOrder == ChannelsOrder.NativeNCHW)
        {
            // Native NCHW has priority and is never overridden.
            continue;
        }
        else if (previousOrder == ChannelsOrder.NHWC && deducedChannelOrder != ChannelsOrder.NativeNCHW)
        {
            // Heuristic to stop a ping-pong loop: prefer NHWC over NCHW (fewer transposes),
            // but an incoming NativeNCHW is always propagated.
            // TODO: count # of transpose swaps
            continue;
        }

        // NOTE(review): the bool returned by FindLayerByName is ignored. Names enqueued from
        // `layer.inputs` may be model inputs rather than layers. Whatever FindLayerByName puts in
        // `layer` for such names is used below (including `layer.name`) — confirm it is never null.
        ModelAnalyzer.FindLayerByName(model, name, out Layer layer);

        bool arrivesAsNCHW = deducedChannelOrder == ChannelsOrder.NCHW || deducedChannelOrder == ChannelsOrder.NativeNCHW;
        bool arrivesAsNHWC = deducedChannelOrder == ChannelsOrder.NHWC;
        bool isDownstream = flowDirection == FlowDirection.Downstream;
        bool isUpstream = flowDirection == FlowDirection.Upstream;

        if (IsLayerChangingLayoutToNHWC(layer, shapesByName, ranksByName))
        {
            // NCHW -> T -> NHWC   or   NCHW <- T <- NHWC
            if ((arrivesAsNCHW && isDownstream) || (arrivesAsNHWC && isUpstream))
            {
                deducedChannelOrder = ChannelsOrder.TransposeToNHWC;
            }
        }
        else if (IsLayerChangingLayoutToNCHW(layer, shapesByName, ranksByName))
        {
            // NHWC -> T-1 -> NCHW   or   NHWC <- T-1 <- NCHW
            if ((arrivesAsNHWC && isDownstream) || (arrivesAsNCHW && isUpstream))
            {
                deducedChannelOrder = ChannelsOrder.TransposeToNCHW;
            }
        }

        // A transpose layer already marked with the same transpose has nothing new to propagate.
        bool isTranspose = deducedChannelOrder == ChannelsOrder.TransposeToNCHW
                        || deducedChannelOrder == ChannelsOrder.TransposeToNHWC;
        if (isTranspose && deducedChannelOrder == layerChannelOrder[name])
        {
            continue;
        }

        layerChannelOrder[name] = deducedChannelOrder;

        // Inputs are enqueued before outputs; the FIFO order affects which order wins.
        ChannelsOrder orderForInputs = OrderSeenUpstream(deducedChannelOrder);
        foreach (var inputName in layer.inputs)
        {
            layersToInferLayout.Enqueue((inputName, orderForInputs, FlowDirection.Upstream));
        }

        ChannelsOrder orderForOutputs = OrderSeenDownstream(deducedChannelOrder);
        foreach (var outputName in ModelAnalyzer.FindLayerOutputs(model, layer.name))
        {
            layersToInferLayout.Enqueue((outputName, orderForOutputs, FlowDirection.Downstream));
        }
    }

    // NativeNCHW only matters for priority during propagation; report it as plain NCHW.
    bool modelExportedAsNHWC = layerChannelOrder.Values.Any(order => order == ChannelsOrder.NHWC);
    foreach (string key in layerChannelOrder.Keys.ToList())
    {
        if (layerChannelOrder[key] == ChannelsOrder.NativeNCHW)
        {
            layerChannelOrder[key] = ChannelsOrder.NCHW;
        }
    }
    return modelExportedAsNHWC;

    // Order observed by a layer's inputs: the side before a transpose sees the pre-transpose layout.
    ChannelsOrder OrderSeenUpstream(ChannelsOrder order) =>
        order == ChannelsOrder.TransposeToNCHW ? ChannelsOrder.NHWC :
        order == ChannelsOrder.TransposeToNHWC ? ChannelsOrder.NCHW :
        order;

    // Order observed by a layer's outputs: the side after a transpose sees the post-transpose layout.
    ChannelsOrder OrderSeenDownstream(ChannelsOrder order) =>
        order == ChannelsOrder.TransposeToNCHW ? ChannelsOrder.NCHW :
        order == ChannelsOrder.TransposeToNHWC ? ChannelsOrder.NHWC :
        order;
}