private ONNXTensor BakeNodeIntoConstant(Action<ModelBuilder, ONNXNodeWrapper> opImportAction, ONNXNodeWrapper node)
{
    var model = new Model();
    var net = new ModelBuilder(model);

    // add all inputs as constants
    Debug.Assert(node.AreAllInputsConst);
    for (var i = 0; i < node.InputCount; ++i)
    {
        var assumeOnnxLayout = i == 0 ? "NCHW" : "CONST";
        var input = node.Inputs[i];
        net.Const(input,
            constantTensors[input].ToBarracuda(assumeOnnxLayout));
    }

    // add node that we are going to bake into the constant
    opImportAction(net, node);

    // bake
    var noInputs = new Dictionary<string, Tensor>();
    var useCPUforBaking = WorkerFactory.Device.CPU;
    var worker = WorkerFactory.CreateWorker(model, useCPUforBaking);
    var result = worker.ExecuteAndWaitForCompletion(noInputs);

    // convert from Barracuda back into ONNX layout
    var onnxData = ONNXTensor.Permute(result, new int[] { 0, 3, 1, 2 }); // NHWC -> NCHW
    var onnxShape = onnxData.shape.ToArray().Select(x => (long)x).ToArray();
    return new ONNXTensor(onnxData, onnxShape).SqueezeAll();
}
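// Illustrative sketch (not part of the importer): the permutation above maps the
// Barracuda NHWC result back to ONNX NCHW order. PermuteShapeExample below is a
// hypothetical helper showing how the index map { 0, 3, 1, 2 } reorders a shape,
// e.g. NHWC [2, 5, 7, 3] -> NCHW [2, 3, 5, 7].
private static int[] PermuteShapeExample(int[] shape, int[] permutation)
{
    // output dimension i takes its size from input dimension permutation[i]
    var permuted = new int[permutation.Length];
    for (var i = 0; i < permutation.Length; ++i)
        permuted[i] = shape[permutation[i]];
    return permuted; // PermuteShapeExample(new[] {2,5,7,3}, new[] {0,3,1,2}) == {2,3,5,7}
}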
// Transposes "channels first" MatMul/GEMM weights to the "channels last" layout Barracuda expects
internal static Tensor SwapSpatialDimensionsAndFeaturesInMatMulWeights(Tensor weights, int featureCount)
{
    Debug.Assert(featureCount <= weights.flatHeight);
    if (featureCount != weights.flatHeight)
    {
        var shape = weights.shape;
        var implicitSpatialDimensionsInWeights = shape.flatHeight / featureCount;
        Debug.Assert(shape.flatHeight % featureCount == 0);
        // reshape: C__K -> CHWK
        weights = weights.Reshape(
            new TensorShape(featureCount, implicitSpatialDimensionsInWeights, 1, shape.channels));
        // permute: CHWK -> HWCK
        weights = ONNXTensor.Permute(weights, new int[] { 1, 0, 2, 3 }); // @TODO: use Permute(, onnxLayout:CHWK)
        // reshape: HWCK -> C__K
        weights = weights.Reshape(shape);
    }
    return weights;
}
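// Hypothetical illustration of the swap above on a plain 2D array: ONNX flattens an
// NCHW activation so that weight-row index = c * HW + hw, while Barracuda's NHWC
// flatten yields row index = hw * C + c. Reordering the weight rows accordingly keeps
// the matrix multiply result identical. A sketch for clarity, not importer code.
private static float[,] ReorderDenseWeightRowsExample(float[,] weights, int featureCount)
{
    int rows = weights.GetLength(0), cols = weights.GetLength(1);
    int spatial = rows / featureCount; // HW: implicit spatial dimensions folded into the rows
    var result = new float[rows, cols];
    for (var c = 0; c < featureCount; ++c)
        for (var hw = 0; hw < spatial; ++hw)
            for (var k = 0; k < cols; ++k)
                result[hw * featureCount + c, k] = weights[c * spatial + hw, k]; // CHW-order row -> HWC-order row
    return result;
}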
// Helpers to keep track of model tensors
private void Output(string name, ONNXTensor onnxTensor)
{
    m_ModelTensors.AddVariable(name, onnxTensor);
}

private void Const(string name, ONNXTensor onnxTensor)
{
    m_ModelTensors.AddConstant(name, onnxTensor);
}

private void Const(ONNXNodeWrapper node, ONNXTensor onnxTensor)
{
    m_ModelTensors.AddConstant(node.Name, onnxTensor);
}
// ONNX parser declaration
// Add operator handlers here
public ONNXModelImporter()
{
    // TODO: setup m_NodeImporters via initializer list
    // TODO: simplify code to avoid passing node.Name over and over again

    Add("Constant", (net, node) => {
        node.UnsupportedAttribute("sparse_value");
        Const(node, node.ValueAsTensor);
    });
    Add("Reshape", (net, node) => {
        long[] onnxShape;
        if (node.InputCount > 1) // Reshape-5
        {
            onnxShape = node.Input1Constant(onnxLayout: "C", name: "shape").AsLongs();
        }
        else // Reshape-1
        {
            onnxShape = node.Shape;
        }

        if (node.IsInput0Const)
        {
            // reshape constant source tensor and store it as the new constant
            var reshapedTensor = constantTensors[node.Input0].Reshape(onnxShape);
            Const(node, reshapedTensor);
        }
        else
        {
            var symbolicShape = ONNXLayout.ConvertSymbolicShapeToBarracuda(onnxShape, "NCHW");
            if (patchReshapeToSupportBatchSize)
            {
                symbolicShape[0] = 0; // force keep batch size
            }
            net.Reshape(node.Name, node.Input0, symbolicShape);
            Output(node, rank: symbolicShape.Length);
        }
    });
    Add("Shape", (net, node) => {
        // @TODO: dynamic implementation that would return the real shape during execution of the model
        if (!node.IsInput0Const)
        {
            throw new OnnxLayerImportException(
                $"Currently only constant inputs for node of type {node.OperatorType} are supported. Instead the input of {node.Name} points to non-constant node {node.Input0}.");
        }
        var shapeTensor = new ONNXTensor(
            data: new Tensor(4, 1, new[] { -1f, (float)node.Input0Rank, -1f, -1f }),
            onnxShape: new[] { 4L });
        Const(node, shapeTensor);
        Output(node, rank: 1);
    });
    Add("Unsqueeze", (net, node) => {
        if (node.IsInput0Const)
        {
            var unsqueezed = constantTensors[node.Input0].Unsqueeze(node.Axes);
            Const(node, unsqueezed);
        }
        else
        {
            // NOTE: axis=0 or 1 will require a Transpose between channels and other spatial dimensions
            // when converted to Barracuda layout. As we have different layouts between ONNX and Barracuda,
            // Unsqueeze might require an actual Transpose, not just a Reshape!
            // ONNX pseudocode here:
            //   a = Tensor [2, 10]           # NC   -> barracuda N11C
            //   b = Unsqueeze(a, axis=0)
            //   # b is now Tensor [1, 2, 10] # NCHW -> barracuda NHWC
            // Because ONNX layout is NCHW (and in general not guaranteed) while Barracuda is strictly NHWC,
            // we end up with: `a` would be [2, 1, 1, 10], but `b` would have to be [1, 10, 1, 2].
            // Note the actual data swap in channels!
            bool mightNeedTranspose = node.Axes.Any(axis => axis <= 1);
            if (mightNeedTranspose)
            {
                Warn(net, node, "Unsqueeze on an axis next to the batch dimension for non-constant tensors might lead to unexpected results.");
            }
            net.Identity(node.Name, node.Input0);
        }
    });
    Add("Squeeze", (net, node) => {
        if (node.IsInput0Const)
        {
            var squeezed = constantTensors[node.Input0].Squeeze(node.Axes);
            Const(node, squeezed);
        }
        else
        {
            // See Unsqueeze above for explanation
            bool mightNeedTranspose = node.Axes.Any(axis => axis <= 1);
            if (mightNeedTranspose)
            {
                Warn(net, node, "Squeeze on an axis next to the batch dimension for non-constant tensors might lead to unexpected results.");
            }
            net.Identity(node.Name, node.Input0);
        }
    });
    Add("Flatten", (net, node) => {
        node.UnsupportedAttribute("axis", 1);
        net.Flatten(node.Name, node.Input0);
        Output(node, rank: 2);
    });
    Add("Concat", (net, node) => {
        int axis = node.AxisOptional(0);
        // TODO: write a dedicated ONNXTensor.Concat() so that the output shape matches ONNX exactly
        // if (node.AreAllInputsConst) Const(node, ONNXTensor.Concat(node.Inputs.Select(i => constantTensors[i]), axis));
        // e.g. for a rank-4 NCHW input the axes map to Barracuda NHWC as 0->0, 1->3, 2->1, 3->2
        axis = ONNXLayout.ConvertAxisToBarracuda(axis, onnxRank: node.Input0Rank, onnxLayout: "NCHW");
        net.Concat(node.Name, node.Inputs, axis);

        bool lastAxis = (axis == -1 || axis == node.Input0Rank - 1); // last axis in Barracuda is feature axis
        if (lastAxis)
        {
            var featuresConcatenated = node.Inputs.Sum(i => variableTensors[i].features);
            Output(node, features: featuresConcatenated);
        }
    });
    Add("Slice", (net, node) => {
        int[] starts, ends, axes, steps;
        if (node.InputCount > 1) // Slice-10
        {
            var constStarts = node.Input1Constant(onnxLayout: "C", name: "starts");
            var constEnds = node.Input2Constant(onnxLayout: "C", name: "ends");
            var defaultAxes = new Tensor(constStarts.shape,
                Enumerable.Range(0, constStarts.length).Select(v => (float)v).ToArray());
            var constAxes = node.Input3ConstantOptional(defaultAxes, onnxLayout: "C", name: "axes");
            var constSteps = node.Input4ConstantOptional(constStarts.shape, 1.0f, onnxLayout: "C", name: "steps");
            starts = constStarts.AsInts();
            ends = constEnds.AsInts();
            axes = constAxes.AsInts();
            steps = constSteps.AsInts();
        }
        else // Slice-1
        {
            starts = node.Starts;
            ends = node.Ends;
            axes = node.AxesOptional(Enumerable.Range(0, starts.Length).ToArray());
            steps = Enumerable.Repeat(1, starts.Length).ToArray();
        }
        Debug.Assert(starts.Length == ends.Length);

        var onnxRank = node.Input0Rank;
        var onnxLast = (long)int.MaxValue;
        var onnxStarts = Enumerable.Repeat(0L, onnxRank).ToArray();
        var onnxEnds = Enumerable.Repeat(onnxLast, onnxRank).ToArray(); // by default copy the whole axis till the end
        var onnxSteps = Enumerable.Repeat(1L, onnxRank).ToArray();

        // NOTE: begin=0, end=0,   stride=1 <= full range from the existing axis
        //       begin=0, end=inf, stride=1 <= full range from the existing axis
        //       begin=0, end=X,   stride=1 <= full range from the existing axis, if X==last element on this axis
        //       begin=0, end=0,   stride=0 <= new axis OR shrink axis to a single 1st element
        //       begin=N, end=N,   stride=0 <= shrink axis to a single Nth element
        // These notes are copied from TensorExtensions.ApplyStridedSlice(...)
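        // Worked example (illustrative, values assumed): for a rank-4 input, ONNX
        // Slice(starts=[1], ends=[3], axes=[2]) expands via the loop below into
        //   onnxStarts = [0, 0, 1, 0], onnxEnds = [max, max, 3, max], onnxSteps = [1, 1, 1, 1]
        // i.e. unspecified axes default to the full range.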
        for (int i = 0; i < axes.Length; ++i)
        {
            var axis = axes[i];
            if (axis < 0)
            {
                axis += onnxRank;
            }
            axis = Math.Min(Math.Max(axis, 0), onnxRank - 1); // clamp to a valid index into the per-axis arrays
            onnxStarts[axis] = starts[i];
            onnxEnds[axis] = ends[i];
            onnxSteps[axis] = steps[i];
        }

        if (node.IsInput0Const)
        {
            var slicedTensor = constantTensors[node.Input0].Slice(
                starts: onnxStarts.Select(x => (int)x).ToArray(),
                ends: onnxEnds.Select(x => (int)x).ToArray(),
                steps: onnxSteps.Select(x => (int)x).ToArray());
            Const(node, slicedTensor);
        }
        else
        {
            net.StridedSlice(node.Name, node.Input0,
                starts: ONNXLayout.ConvertSymbolicShapeToBarracuda(onnxStarts, onnxLayout: "NCHW"),
                ends: ONNXLayout.ConvertSymbolicShapeToBarracuda(onnxEnds, onnxLayout: "NCHW"),
                strides: ONNXLayout.ConvertSymbolicShapeToBarracuda(onnxSteps, onnxLayout: "NCHW"));
        }
    });
    Add("Gather", (net, node) => {
        int axis = node.AxisOptional(0);
        if (node.IsInput0Const)
        {
            var indices = node.Input1Constant(onnxLayout: "C", name: "indices").AsInts();
            var gatheredTensor = constantTensors[node.Input0].Gather(axis, indices);
            Const(node, gatheredTensor);
        }
        else
        {
            axis = ONNXLayout.ConvertAxisToBarracuda(axis, onnxRank: node.Input0Rank, onnxLayout: "NCHW");
            net.Gather(node.Name, node.Input0, node.Input1, axis);
        }
    });
    Add("OneHot", (net, node) => {
        node.UnsupportedAttribute("axis", -1);
        var defaultOffOn = new Tensor(1, 1, 1, 2, new float[] { 0, 1 });
        var depth = (int)node.Input1Constant(onnxLayout: "C", name: "depth")[0];
        var offon = node.Input2ConstantOptional(defaultOffOn, onnxLayout: "C", name: "values");
        net.OneHot(node.Name, node.Input0, depth, (int)offon[1], (int)offon[0]);
    });

    // Activation ops
    Add("Relu", (net, node) => { net.Relu(node.Name, node.Input0); });
    Add("Softmax", (net, node) => { net.Softmax(node.Name, node.Input0); node.UnsupportedAttribute("axis", 1); });
    Add("Tanh", (net, node) => { net.Tanh(node.Name, node.Input0); });
    Add("Sqrt", (net, node) => { net.Sqrt(node.Name, node.Input0); });
    Add("Sigmoid", (net, node) => { net.Sigmoid(node.Name, node.Input0); });
    Add("Elu", (net, node) => { net.Elu(node.Name, node.Input0, node.AlphaOptional(1f)); });
    Add("LeakyRelu", (net, node) => { net.LeakyRelu(node.Name, node.Input0, node.AlphaOptional(0.01f)); });
    Add("Selu", (net, node) => { net.Selu(node.Name, node.Input0, node.AlphaOptional(1.67326f), node.GammaOptional(1.0507f)); });
    Add("Swish", (net, node) => { net.Swish(node.Name, node.Input0); });
    Add("PRelu", (net, node) => { net.PRelu(node.Name, node.Input0, node.Input1); });
    Add("LogSoftmax", (net, node) => { net.LogSoftmax(node.Name, node.Input0); node.UnsupportedAttribute("axis", 1); });
    // TODO: Add("Hardmax", (net, node) => { net.Hardmax(node.Name, node.Input0); node.UnsupportedAttribute("axis", 1); });
    // TODO: Add("Softplus", (net, node) => { net.Softplus(node.Name, node.Input0); });
    // TODO: Add("Softsign", (net, node) => { net.Softsign(node.Name, node.Input0); });
    // TODO: Add("HardSigmoid", (net, node) => { net.HardSigmoid(node.Name, node.Input0, node.AlphaOptional(0.2f), node.BetaOptional(0.5f)); });
    Add("Exp", (net, node) => { net.Exp(node.Name, node.Input0); });
    Add("Log", (net, node) => { net.Log(node.Name, node.Input0); });
    Add("Reciprocal", (net, node) => { net.Reciprocal(node.Name, node.Input0); });
    Add("Abs", (net, node) => { net.Abs(node.Name, node.Input0); });
    Add("Neg", (net, node) => { net.Neg(node.Name, node.Input0); });
    Add("Ceil", (net, node) => { net.Ceil(node.Name, node.Input0); });
    Add("Floor", (net, node) => { net.Floor(node.Name, node.Input0); });
    Add("Round", (net, node) => { net.Round(node.Name, node.Input0); });
}); Add("Clip", (net, node) => { net.Clip(node.Name, node.Input0, node.MinOptional(float.MinValue), node.MaxOptional(float.MaxValue)); }); // Broadcast ops Add("Add", (net, node) => { net.Add(node.Name, node.Inputs); }); Add("Sum", (net, node) => { net.Add(node.Name, node.Inputs); }); // Sum is implemented via Add Add("Sub", (net, node) => { net.Sub(node.Name, node.Inputs); }); Add("Mul", (net, node) => { net.Mul(node.Name, node.Inputs); }); Add("Div", (net, node) => { net.Div(node.Name, node.Inputs); }); Add("Pow", (net, node) => { net.Pow(node.Name, node.Inputs); }); Add("Min", (net, node) => { net.Min(node.Name, node.Inputs); }); Add("Max", (net, node) => { net.Max(node.Name, node.Inputs); }); Add("Mean", (net, node) => { net.Mean(node.Name, node.Inputs); }); // Logical ops Add("Greater", (net, node) => { net.Greater(node.Name, node.Input0, node.Input1); }); Add("Less", (net, node) => { net.Less(node.Name, node.Input0, node.Input1); }); Add("Equal", (net, node) => { net.Equal(node.Name, node.Input0, node.Input1); }); Add("Or", (net, node) => { net.LogicalOr(node.Name, node.Input0, node.Input1); }); Add("And", (net, node) => { net.LogicalAnd(node.Name, node.Input0, node.Input1); }); Add("Not", (net, node) => { net.LogicalNot(node.Name, node.Input0); }); Add("Xor", (net, node) => { net.LogicalXor(node.Name, node.Input0, node.Input1); }); // Padding ops Add("Pad", (net, node) => { var mode = node.GetOptionalString("mode", "constant"); switch (mode) { case "constant": net.Border2D(node.Name, node.Input0, node.Pads, node.GetOptionalFloat("value", 0.0f)); break; case "reflect": net.Pad2DReflect(node.Name, node.Input0, node.Pads); break; case "edge": net.Pad2DEdge(node.Name, node.Input0, node.Pads); break; } }); // Pooling ops Add("AveragePool", (net, node) => { node.UnsupportedAttribute("ceil_mode", 0); node.UnsupportedAttribute("count_include_pad", 0); net.AvgPool2D(node.Name, node.Input0, node.KernelShape, node.Strides, node.Pads); }); Add("MaxPool", (net, node) => { node.UnsupportedAttribute("ceil_mode", 0); node.UnsupportedAttribute("dilations", new[] { 1, 1 }); node.UnsupportedAttribute("storage_order", 0); net.MaxPool2D(node.Name, node.Input0, node.KernelShape, node.Strides, node.Pads); }); Add("GlobalAveragePool", (net, node) => { net.GlobalAvgPool2D(node.Name, node.Input0); }); Add("GlobalMaxPool", (net, node) => { net.GlobalMaxPool2D(node.Name, node.Input0); }); Add("Upsample", (net, node) => { node.UnsupportedAttribute("mode", "nearest"); float[] scales; if (node.InputCount > 1) { scales = node.Input1Constant(onnxLayout: "C", name: "scales").AsFloats(); } else { scales = node.Scales; } if (scales.Length < 2 || scales.Length > 5) { throw new OnnxLayerImportException( $"Input scales of unsupported length {scales.Length} in {node.Name} ot fype {node.OperatorType}."); } if ((scales[0] != 1) || (scales[1] != 1)) { Warn(net, node, $"Unsupported scaling, only H and W scaling are supported. 
Value will be ignored and defaulted to 1."); } // skip NC from onnx NCHW layout scales = scales.Skip(2).ToArray(); if (scales.Length == 1) // append default H, if 1D { scales = new[] { 1f, scales[0] } } ; Resize2D(net, node, scales); }); Add("Resize", (net, node) => { node.UnsupportedAttribute("mode", "nearest"); float[] scales; if (node.InputCount > 2) // Resize-11 { node.UnsupportedAttribute("coordinate_transformation_mode", "half_pixel"); node.UnsupportedAttribute("cubic_coeff_a", -0.75f); node.UnsupportedAttribute("exclude_outside", 0); node.UnsupportedAttribute("extrapolation_value", 0f); node.UnsupportedAttribute("nearest_mode", "round_prefer_floor"); // Inputs (3 - 4) // X : T1 // roi : T2, It only takes effect when coordinate_transformation_mode is "tf_crop_and_resize" // scales : tensor(float) // sizes (optional) : tensor(int64) // TODO: cropping via roi input // TODO: support sizes scales = node.Input2Constant(onnxLayout: "C", name: "scales").AsFloats(); } else // Resize-10 { scales = node.Input1Constant(onnxLayout: "C", name: "scales").AsFloats(); } if (scales.Length < 2 || scales.Length > 5) { throw new OnnxLayerImportException( $"Input scales of unsupported length {scales.Length} in {node.Name} ot fype {node.OperatorType}."); } if ((scales[0] != 1) || (scales[1] != 1)) { Warn(net, node, $"Unsupported scaling, only H and W scaling are supported. Value will be ignored and defaulted to 1."); } // skip NC from onnx NCHW layout scales = scales.Skip(2).ToArray(); Resize2D(net, node, scales); }); Add("Transpose", (net, node) => { if (!node.IsInput0Const) { throw new OnnxLayerImportException( $"Currently only constant inputs for node of type {node.OperatorType} are supported. Instead input of {node.Name} is pointing to non-constant node {node.Input0}."); } var permutations = node.GetOptionalIntArray("perm", new [] { 3, 2, 1, 0 }); var transposedTensor = constantTensors[node.Input0].Permute(permutations); Const(node, transposedTensor); }); // Tensor ops Add("Gemm", (net, node) => { node.UnsupportedAttribute("alpha", 1.0f); node.UnsupportedAttribute("beta", 1.0f); node.UnsupportedAttribute("transA", 0); var onnxLayout = node.TransBOptional() ? 
"KC" : "CK"; var weights = node.Input1Constant(onnxLayout, name: "B"); var biases = node.Input2ConstantOptional(Bias(weights.shape), 0.0f, onnxLayout: "C", name: "C"); // Change data layout from "channels first" to "channels last" weights = SwapSpatialDimensionsAndFeaturesInMatMulWeights(weights, node.Input0Features); net.Dense(node.Name, node.Input0, weights, biases); Output(node, features: weights.channels, rank: 2); // Gemm forces flatten of the input to rank 2 }); Add("MatMul", (net, node) => { var weights = node.Input1Constant(onnxLayout: "CK", name: "B"); var biases = node.DefaultTensor(Bias(weights.shape), 0.0f); // Change data layout from "channels first" to "channels last" weights = SwapSpatialDimensionsAndFeaturesInMatMulWeights(weights, node.Input0Features); net.Dense(node.Name, node.Input0, weights, biases); Output(node, features: weights.channels, rank: 2); // MatMul forces flatten of the input to rank 2 }); Add("Conv", (net, node) => { node.UnsupportedAttribute("dilations", new[] { 1, 1 }); node.IgnoredAttribute("kernel_shape", "Kernel shape is derived from K tensor weights instead"); var kernels = node.Input1Constant(onnxLayout: "KCHW", name: "W"); var biases = node.Input2ConstantOptional(Bias(kernels.shape), 0.0f, onnxLayout: "C", name: "B"); if (node.GroupOptional() > 1) { net.DepthwiseConv2D(node.Name, node.Input0, node.Strides, node.Pads, kernels, biases); } else { net.Conv2D(node.Name, node.Input0, node.Strides, node.Pads, kernels, biases); } Output(node, features: kernels.channels); }); Add("ConvTranspose", (net, node) => { node.UnsupportedAttribute("dilations", new[] { 1, 1 }); node.UnsupportedAttribute("group", 1); node.UnsupportedAttribute("output_shape", new int[0]); node.IgnoredAttribute("kernel_shape", "Kernel shape is derived from K tensor weights instead"); var kernels = node.Input1Constant(onnxLayout: "CKHW", name: "W"); var biases = node.Input2ConstantOptional(Bias(kernels.shape), 0.0f, onnxLayout: "C", name: "B"); net.Conv2DTrans(node.Name, node.Input0, node.Strides, node.Pads, node.OutputPadding, kernels, biases); Output(node, features: kernels.channels); }); Add("BatchNormalization", (net, node) => { var variance = node.Input4Constant(onnxLayout: "C", name: "var"); var scale = node.Input1ConstantOptional(variance.shape, 1.0f, onnxLayout: "C", name: "scale"); var bias = node.Input2ConstantOptional(variance.shape, 0.0f, onnxLayout: "C", name: "B"); var mean = node.Input3ConstantOptional(variance.shape, 0.0f, onnxLayout: "C", name: "mean"); var fusedData = FuseBatchNormWeights(scale, bias, mean, variance, node.EpsilonOptional()); net.ScaleBias(node.Name, node.Input0, fusedData.Item1, fusedData.Item2); }); Add("InstanceNormalization", (net, node) => { var scale = node.Input1Constant(onnxLayout: "C", name: "scale"); var bias = node.Input2ConstantOptional(scale.shape, 0.0f, onnxLayout: "C", name: "B"); net.Normalization(node.Name, node.Input0, scale, bias, node.EpsilonOptional()); }); // random ops Add("RandomNormal", (net, node) => { float mean = node.GetOptionalFloat("mean", 1.0f); float scale = node.GetOptionalFloat("scale", 0.0f); float seed = node.GetOptionalFloat("seed", 0.0f); int[] shape = node.GetRequiredIntArray("shape"); net.RandomNormal(node.Name, mean, scale, seed, shape); }); Add("RandomNormalLike", (net, node) => { float mean = node.GetOptionalFloat("mean", 1.0f); float scale = node.GetOptionalFloat("scale", 0.0f); float seed = node.GetOptionalFloat("seed", 0.0f); net.RandomNormal(node.Name, mean, scale, seed, node.Input0); }); 
Add("RandomUniform", (net, node) => { float high = node.GetOptionalFloat("high", 1.0f); float low = node.GetOptionalFloat("low", 0.0f); float seed = node.GetOptionalFloat("seed", 0.0f); int[] shape = node.GetRequiredIntArray("shape"); net.RandomUniform(node.Name, low, high, seed, shape); }); Add("RandomUniformLike", (net, node) => { float high = node.GetOptionalFloat("high", 1.0f); float low = node.GetOptionalFloat("low", 0.0f); float seed = node.GetOptionalFloat("seed", 0.0f); net.RandomUniform(node.Name, low, high, seed, node.Input0); }); Add("ReduceMax", (net, node) => { Reduce(net, node, Layer.Type.ReduceMax); }); Add("ReduceMean", (net, node) => { Reduce(net, node, Layer.Type.ReduceMean); }); Add("ReduceMin", (net, node) => { Reduce(net, node, Layer.Type.ReduceMin); }); Add("ReduceProd", (net, node) => { Reduce(net, node, Layer.Type.ReduceProd); }); Add("ReduceSum", (net, node) => { Reduce(net, node, Layer.Type.ReduceSum); }); // Ignore, noop during inference Add("Identity", (net, node) => { net.Identity(node.Name, node.Input0); }); Add("Cast", (net, node) => { net.Identity(node.Name, node.Input0); }); Add("Dropout", (net, node) => { net.Identity(node.Name, node.Input0); }); }