/// <summary>
/// Evaluates a node whose inputs are all constants at import time, and returns the node's output
/// as a constant <see cref="ONNXTensor"/>.
/// </summary>
/// <param name="opImportAction">Callback that adds the layer(s) for <paramref name="node"/> to the given builder.</param>
/// <param name="node">Node to bake; every input must be present in <c>constantTensors</c>.</param>
/// <returns>The node's output, converted back to ONNX (NCHW) layout and squeezed.</returns>
private ONNXTensor BakeNodeIntoConstant(Action<ModelBuilder, ONNXNodeWrapper> opImportAction, ONNXNodeWrapper node)
{
    var model = new Model();
    var net = new ModelBuilder(model);

    // add all inputs as constants
    Debug.Assert(node.AreAllInputsConst);
    for (var i = 0; i < node.InputCount; ++i)
    {
        // The first input is converted assuming an NCHW layout; the remaining inputs use the "CONST" layout.
        var assumeOnnxLayout = i == 0 ? "NCHW" : "CONST";
        var input = node.Inputs[i];
        net.Const(input, constantTensors[input].ToBarracuda(assumeOnnxLayout));
    }

    // add node that we are going to bake into the constant
    opImportAction(net, node);

    // bake: the model has no runtime inputs, since everything it reads is a constant
    var noInputs = new Dictionary<string, Tensor>();
    var useCPUforBaking = WorkerFactory.Device.CPU;

    // The worker exists only for this single evaluation. Dispose it so its resources are released
    // instead of leaking once per baked node. The result is consumed inside the scope, so it is never
    // read after the worker has been torn down.
    using (var worker = WorkerFactory.CreateWorker(model, useCPUforBaking))
    {
        var result = worker.ExecuteAndWaitForCompletion(noInputs);

        // convert from Barracuda back into ONNX layout
        var onnxData = ONNXTensor.Permute(result, new int[] { 0, 3, 1, 2 }); // NHWC -> NCHW
        var onnxShape = onnxData.shape.ToArray().Select(x => (long)x).ToArray();
        return new ONNXTensor(onnxData, onnxShape).SqueezeAll();
    }
}
/// <summary>
/// Transposes channels-first data into channels-last order inside a MatMul/GEMM weight tensor.
/// The rows of <paramref name="weights"/> (its flatHeight) are treated as <paramref name="featureCount"/>
/// feature blocks, each spanning flatHeight / featureCount implicit spatial positions. Those rows are
/// reordered from feature-major (C, spatial) to spatial-major (spatial, C) order. The tensor's shape is
/// left unchanged.
/// </summary>
/// <param name="weights">Weight tensor whose flatHeight is a multiple of <paramref name="featureCount"/>.</param>
/// <param name="featureCount">Number of features (channels) of the data feeding this MatMul/GEMM.</param>
/// <returns>
/// <paramref name="weights"/> itself when featureCount equals its flatHeight; otherwise a reordered tensor
/// with the same shape.
/// </returns>
internal static Tensor SwapSpatialDimensionsAndFeaturesInMatMulWeights(Tensor weights, int featureCount)
{
    Debug.Assert(featureCount <= weights.flatHeight);

    // There is no implicit spatial extent to move, so nothing needs reordering.
    if (featureCount == weights.flatHeight)
    {
        return weights;
    }

    var originalShape = weights.shape;
    var spatialSize = originalShape.flatHeight / featureCount;
    Debug.Assert(originalShape.flatHeight % featureCount == 0);

    // C__K -> CHWK: split the flat row axis into (features, spatial).
    var featureMajor = weights.Reshape(new TensorShape(featureCount, spatialSize, 1, originalShape.channels));

    // CHWK -> HWCK: swap the feature and spatial axes.
    var spatialMajor = ONNXTensor.Permute(featureMajor, new int[] { 1, 0, 2, 3 }); // @TODO: use Permute(, onnxLayout:CHWK)

    // HWCK -> C__K: flatten back to the caller's original shape.
    return spatialMajor.Reshape(originalShape);
}