Exemplo n.º 1
0
        /// <summary>
        /// Evaluates a node whose inputs are all constants at import time and returns the
        /// computed output as a new ONNX constant tensor.
        /// </summary>
        /// <param name="opImportAction">Callback that appends the layer(s) for <paramref name="node"/> to the model builder.</param>
        /// <param name="node">Node to bake; each of its inputs is looked up in <c>constantTensors</c>.</param>
        /// <returns>The baked output permuted from NHWC back to NCHW, wrapped as an ONNX tensor with size-1 dimensions squeezed.</returns>
        private ONNXTensor BakeNodeIntoConstant(Action<ModelBuilder, ONNXNodeWrapper> opImportAction, ONNXNodeWrapper node)
        {
            // Build a throw-away model containing only this node and its constant inputs.
            var model = new Model();
            var net   = new ModelBuilder(model);

            // add all inputs as constants
            Debug.Assert(node.AreAllInputsConst);
            for (var i = 0; i < node.InputCount; ++i)
            {
                // The first input is converted assuming NCHW layout; the remaining ones as plain constants.
                var assumeOnnxLayout = i == 0 ? "NCHW" : "CONST";
                var input            = node.Inputs[i];
                net.Const(input, constantTensors[input].ToBarracuda(assumeOnnxLayout));
            }

            // add node that we are going to bake into the constant
            opImportAction(net, node);

            // bake: the model has no external inputs, only the constants added above
            var noInputs = new Dictionary<string, Tensor>();

            var useCPUforBaking = WorkerFactory.Device.CPU;

            // The worker is IDisposable; previously it was never released, leaking one worker per
            // baked node. All reads of the worker's output happen inside this scope.
            using (var worker = WorkerFactory.CreateWorker(model, useCPUforBaking))
            {
                var result = worker.ExecuteAndWaitForCompletion(noInputs);

                // convert from Barracuda back into ONNX layout
                // NOTE(review): assumes Permute returns a tensor independent of the worker-owned
                // output, so it remains valid after the worker is disposed — confirm.
                var onnxData  = ONNXTensor.Permute(result, new int[] { 0, 3, 1, 2 }); // NHWC -> NCHW
                var onnxShape = onnxData.shape.ToArray().Select(x => (long)x).ToArray();

                return new ONNXTensor(onnxData, onnxShape).SqueezeAll();
            }
        }
Exemplo n.º 2
0
 /// <summary>
 /// Transposes channels-first data to channels-last data in a MatMul/GEMM weight tensor.
 /// The rows of <paramref name="weights"/> (its flatHeight) are treated as
 /// <paramref name="featureCount"/> features times an implicit spatial extent, and are
 /// reordered so that the spatial index becomes the outer one.
 /// </summary>
 /// <param name="weights">Weight tensor to reorder.</param>
 /// <param name="featureCount">Number of features folded into the rows of <paramref name="weights"/>.</param>
 /// <returns>
 /// <paramref name="weights"/> unchanged when featureCount equals its flatHeight; otherwise a
 /// reordered tensor reshaped back to the original shape of <paramref name="weights"/>.
 /// </returns>
 internal static Tensor SwapSpatialDimensionsAndFeaturesInMatMulWeights(Tensor weights, int featureCount)
 {
     Debug.Assert(featureCount <= weights.flatHeight);

     // Rows already correspond one-to-one with features: no hidden spatial axis to move.
     if (featureCount == weights.flatHeight)
     {
         return weights;
     }

     var originalShape = weights.shape;
     var spatialSize   = originalShape.flatHeight / featureCount;
     Debug.Assert(originalShape.flatHeight % featureCount == 0);

     // C__K -> CHWK: split the row axis into (featureCount, spatialSize).
     var channelsFirst = weights.Reshape(
         new TensorShape(featureCount, spatialSize, 1, originalShape.channels));

     // CHWK -> HWCK: swap the first two axes.
     // @TODO: use Permute(, onnxLayout:CHWK)
     var channelsLast = ONNXTensor.Permute(channelsFirst, new[] { 1, 0, 2, 3 });

     // HWCK -> C__K: fold back into the caller's original shape.
     return channelsLast.Reshape(originalShape);
 }