Exemplo n.º 1
0
        /// <summary>
        /// Reads a serialized ONNX TensorProto from a file on disk and converts it into a
        /// NamedOnnxValue by resolving it against <paramref name="nodeMetaDict"/>.
        /// </summary>
        /// <param name="filename">Path of the serialized tensor file.</param>
        /// <param name="nodeMetaDict">Candidate node metadata used to resolve name, type and shape.</param>
        /// <returns>The value produced by LoadTensorPb for the parsed tensor.</returns>
        internal static NamedOnnxValue LoadTensorFromFilePb(string filename, IReadOnlyDictionary <string, NodeMetadata> nodeMetaDict)
        {
            // 4MB read buffer for the underlying FileStream.
            const int bufferBytes = 4 * 1024 * 1024;

            Onnx.TensorProto parsed;
            using (var stream = new FileStream(filename, FileMode.Open, FileAccess.Read, FileShare.Read, bufferBytes))
            {
                parsed = Onnx.TensorProto.Parser.ParseFrom(stream);
            }

            return LoadTensorPb(parsed, nodeMetaDict);
        }
Exemplo n.º 2
0
        /// <summary>
        /// Reads a serialized ONNX TensorProto from a resource embedded in the assembly that
        /// defines TestDataLoader, and converts it into a NamedOnnxValue by resolving it
        /// against <paramref name="nodeMetaDict"/>.
        /// </summary>
        /// <param name="path">
        /// Resource path appended to "{AssemblyName}.TestData." to form the manifest resource name.
        /// </param>
        /// <param name="nodeMetaDict">Candidate node metadata used to resolve name, type and shape.</param>
        /// <returns>The value produced by LoadTensorPb for the parsed tensor.</returns>
        /// <exception cref="FileNotFoundException">No embedded resource exists under the computed name.</exception>
        internal static NamedOnnxValue LoadTensorFromEmbeddedResourcePb(string path, IReadOnlyDictionary <string, NodeMetadata> nodeMetaDict)
        {
            var assembly     = typeof(TestDataLoader).Assembly;
            var resourceName = $"{assembly.GetName().Name}.TestData.{path}";

            Onnx.TensorProto tensor;
            using (Stream stream = assembly.GetManifestResourceStream(resourceName))
            {
                // GetManifestResourceStream returns null instead of throwing when the resource
                // does not exist; report that explicitly rather than failing inside the parser.
                if (stream == null)
                {
                    throw new FileNotFoundException($"Embedded resource '{resourceName}' was not found in assembly '{assembly.FullName}'.", resourceName);
                }

                tensor = Onnx.TensorProto.Parser.ParseFrom(stream);
            }

            return LoadTensorPb(tensor, nodeMetaDict);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Parses every file matching "input_*.pb" in <paramref name="testDataPath"/> as an ONNX
        /// TensorProto and converts each one into a NamedOnnxValue using <paramref name="inputMeta"/>.
        /// </summary>
        /// <param name="testDataPath">Directory searched (non-recursively) for input files.</param>
        /// <param name="inputMeta">Model input metadata, looked up by the serialized tensor's name.</param>
        /// <returns>
        /// One value per matched file, in the order Directory.EnumerateFiles yields them
        /// (that order is not guaranteed to be sorted).
        /// </returns>
        static List <NamedOnnxValue> LoadTestDataFromProtobuf(string testDataPath, IReadOnlyDictionary <string, NodeMetadata> inputMeta)
        {
            var inputs = new List <NamedOnnxValue>();

            foreach (var path in Directory.EnumerateFiles(testDataPath, "input_*.pb"))
            {
                Onnx.TensorProto proto;
                using (var stream = File.OpenRead(path))
                {
                    proto = Onnx.TensorProto.Parser.ParseFrom(stream);
                }

                inputs.Add(CreateNamedOnnxValueFromTensorProto(proto, inputMeta));
            }

            return inputs;
        }
Exemplo n.º 4
0
        /// <summary>
        /// Converts a deserialized ONNX TensorProto into a NamedOnnxValue, choosing the target
        /// node from <paramref name="nodeMetaDict"/> and validating type and shape against it.
        /// </summary>
        /// <remarks>
        /// Node resolution:
        ///  - exactly one metadata entry: that entry is used regardless of the tensor's name;
        ///  - several entries and the tensor is named: the entry with that key is used;
        ///  - several entries and the tensor is unnamed: the first entry whose element type and rank
        ///    match, and whose dimensions match (a -1 metadata dimension matches any size), is used.
        /// Only tensor.RawData is read when building the value. NOTE(review): tensors whose payload
        /// is stored outside raw_data are presumably not supported here — confirm.
        /// </remarks>
        /// <param name="tensor">The parsed tensor.</param>
        /// <param name="nodeMetaDict">Candidate node metadata keyed by node name.</param>
        /// <returns>A NamedOnnxValue named after the resolved node.</returns>
        /// <exception cref="KeyNotFoundException">The tensor is named but no metadata entry has that name.</exception>
        /// <exception cref="Exception">
        /// The dictionary is empty, no node matches, the node is not a tensor, the element type or
        /// shape disagree, or the element type is not supported.
        /// </exception>
        static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, IReadOnlyDictionary <string, NodeMetadata> nodeMetaDict)
        {
            Type tensorElemType = null;
            int  width          = 0;

            GetTypeAndWidth((Tensors.TensorElementType)tensor.DataType, out tensorElemType, out width);
            var intDims = new int[tensor.Dims.Count];

            for (int i = 0; i < tensor.Dims.Count; i++)
            {
                intDims[i] = (int)tensor.Dims[i];
            }

            NodeMetadata nodeMeta = null;
            string       nodeName = string.Empty;

            if (nodeMetaDict.Count == 1)
            {
                // With a single candidate node the serialized tensor's own name is not consulted.
                nodeMeta = nodeMetaDict.Values.First();
                nodeName = nodeMetaDict.Keys.First();
            }
            else if (nodeMetaDict.Count > 1)
            {
                if (tensor.Name.Length > 0)
                {
                    if (!nodeMetaDict.TryGetValue(tensor.Name, out nodeMeta))
                    {
                        throw new KeyNotFoundException($"Serialized tensor name '{tensor.Name}' was not found in the node metadata dictionary");
                    }
                    nodeName = tensor.Name;
                }
                else
                {
                    bool matchfound = false;
                    // Unnamed tensor: pick the first node whose element type, rank and
                    // fixed (non -1) dimensions all agree with the serialized tensor.
                    foreach (var entry in nodeMetaDict)
                    {
                        var meta = entry.Value;
                        if (tensorElemType == meta.ElementType && tensor.Dims.Count == meta.Dimensions.Length)
                        {
                            int i = 0;
                            for (; i < meta.Dimensions.Length; i++)
                            {
                                if (meta.Dimensions[i] != -1 && meta.Dimensions[i] != intDims[i])
                                {
                                    break;
                                }
                            }
                            if (i >= meta.Dimensions.Length)
                            {
                                matchfound = true;
                                nodeMeta   = meta;
                                nodeName   = entry.Key;
                                break;
                            }
                        }
                    }
                    if (!matchfound)
                    {
                        throw new Exception("No Matching Tensor found in InputOutputMetadata corresponding to the serialized tensor specified");
                    }
                }
            }
            else
            {
                throw new Exception("While reading the serialized tensor specified, metaDataDict has 0 elements");
            }

            if (!nodeMeta.IsTensor)
            {
                throw new Exception("LoadTensorFromFile can load Tensor types only");
            }

            if (tensorElemType != nodeMeta.ElementType)
            {
                throw new Exception($"Serialized tensor element type {tensorElemType} is expected to be equal to node '{nodeName}' element type {nodeMeta.ElementType}");
            }

            if (nodeMeta.Dimensions.Length != tensor.Dims.Count)
            {
                throw new Exception($"Serialized tensor rank {tensor.Dims.Count} is expected to be equal to node '{nodeName}' rank {nodeMeta.Dimensions.Length}");
            }

            for (int i = 0; i < nodeMeta.Dimensions.Length; i++)
            {
                if ((nodeMeta.Dimensions[i] != -1) && (nodeMeta.Dimensions[i] != intDims[i]))
                {
                    throw new Exception($"{nameof(nodeMeta.Dimensions)}[{i}] is expected to either be -1 or {intDims[i]}, but was {nodeMeta.Dimensions[i]}");
                }
            }

            var elemType = nodeMeta.ElementType;
            var rawData  = tensor.RawData.ToArray();

            if (elemType == typeof(float))
            {
                return CreateNamedOnnxValueFromRawData <float>(nodeName, rawData, sizeof(float), intDims);
            }
            else if (elemType == typeof(double))
            {
                return CreateNamedOnnxValueFromRawData <double>(nodeName, rawData, sizeof(double), intDims);
            }
            else if (elemType == typeof(int))
            {
                return CreateNamedOnnxValueFromRawData <int>(nodeName, rawData, sizeof(int), intDims);
            }
            else if (elemType == typeof(uint))
            {
                return CreateNamedOnnxValueFromRawData <uint>(nodeName, rawData, sizeof(uint), intDims);
            }
            else if (elemType == typeof(long))
            {
                return CreateNamedOnnxValueFromRawData <long>(nodeName, rawData, sizeof(long), intDims);
            }
            else if (elemType == typeof(ulong))
            {
                return CreateNamedOnnxValueFromRawData <ulong>(nodeName, rawData, sizeof(ulong), intDims);
            }
            else if (elemType == typeof(short))
            {
                return CreateNamedOnnxValueFromRawData <short>(nodeName, rawData, sizeof(short), intDims);
            }
            else if (elemType == typeof(ushort))
            {
                return CreateNamedOnnxValueFromRawData <ushort>(nodeName, rawData, sizeof(ushort), intDims);
            }
            else if (elemType == typeof(byte))
            {
                return CreateNamedOnnxValueFromRawData <byte>(nodeName, rawData, sizeof(byte), intDims);
            }
            else if (elemType == typeof(bool))
            {
                return CreateNamedOnnxValueFromRawData <bool>(nodeName, rawData, sizeof(bool), intDims);
            }
            else if (elemType == typeof(Float16))
            {
                // 16-bit element; the width passed is sizeof(ushort).
                return CreateNamedOnnxValueFromRawData <Float16>(nodeName, rawData, sizeof(ushort), intDims);
            }
            else if (elemType == typeof(BFloat16))
            {
                // 16-bit element; the width passed is sizeof(ushort).
                return CreateNamedOnnxValueFromRawData <BFloat16>(nodeName, rawData, sizeof(ushort), intDims);
            }
            else
            {
                //TODO: Add support for remaining types
                throw new Exception($"Tensors of type {elemType} not currently supported in {nameof(LoadTensorPb)}");
            }
        }
Exemplo n.º 5
0
        /// <summary>
        /// Converts a deserialized ONNX TensorProto into a NamedOnnxValue, looking up the target
        /// model input by the tensor's name in <paramref name="inputMeta"/>.
        /// </summary>
        /// <param name="tensorProto">The parsed tensor; its Name must be a key of <paramref name="inputMeta"/>.</param>
        /// <param name="inputMeta">Model input metadata keyed by input name.</param>
        /// <returns>A NamedOnnxValue named after the tensor, built from tensorProto.RawData.</returns>
        /// <exception cref="Exception">
        /// No metadata entry has the tensor's name, the element types differ, or the element type
        /// is not supported.
        /// </exception>
        static NamedOnnxValue CreateNamedOnnxValueFromTensorProto(Onnx.TensorProto tensorProto, IReadOnlyDictionary <string, NodeMetadata> inputMeta)
        {
            Type tensorElemType = null;
            int  elemWidth      = 0;

            GetElementTypeAndWidth((TensorElementType)tensorProto.DataType, out tensorElemType, out elemWidth);
            var dims = tensorProto.Dims.ToList().ConvertAll(x => (int)x);

            NodeMetadata nodeMeta = null;
            var          name     = tensorProto.Name;

            if (!inputMeta.TryGetValue(name, out nodeMeta) ||
                nodeMeta.ElementType != tensorElemType)
            {
                throw new Exception($"No Matching Tensor found from serialized tensor '{name}'");
            }

            var elemType = nodeMeta.ElementType;
            var rawData  = tensorProto.RawData.ToArray();

            if (elemType == typeof(float))
            {
                return CreateNamedOnnxValueFromRawData <float>(name, rawData, sizeof(float), dims);
            }
            else if (elemType == typeof(double))
            {
                return CreateNamedOnnxValueFromRawData <double>(name, rawData, sizeof(double), dims);
            }
            else if (elemType == typeof(int))
            {
                return CreateNamedOnnxValueFromRawData <int>(name, rawData, sizeof(int), dims);
            }
            else if (elemType == typeof(uint))
            {
                return CreateNamedOnnxValueFromRawData <uint>(name, rawData, sizeof(uint), dims);
            }
            else if (elemType == typeof(long))
            {
                return CreateNamedOnnxValueFromRawData <long>(name, rawData, sizeof(long), dims);
            }
            else if (elemType == typeof(ulong))
            {
                return CreateNamedOnnxValueFromRawData <ulong>(name, rawData, sizeof(ulong), dims);
            }
            else if (elemType == typeof(short))
            {
                return CreateNamedOnnxValueFromRawData <short>(name, rawData, sizeof(short), dims);
            }
            else if (elemType == typeof(ushort))
            {
                return CreateNamedOnnxValueFromRawData <ushort>(name, rawData, sizeof(ushort), dims);
            }
            else if (elemType == typeof(byte))
            {
                return CreateNamedOnnxValueFromRawData <byte>(name, rawData, sizeof(byte), dims);
            }
            else if (elemType == typeof(bool))
            {
                return CreateNamedOnnxValueFromRawData <bool>(name, rawData, sizeof(bool), dims);
            }
            else
            {
                throw new Exception($"Tensors of type {elemType} not currently supported in this tool");
            }
        }