/// <summary>
/// Reads a serialized <c>Onnx.TensorProto</c> from a file on disk and hands the parsed
/// tensor to <c>LoadTensorPb</c> to build the resulting <c>NamedOnnxValue</c>.
/// </summary>
/// <param name="filename">Path of the protobuf file; opened read-only with shared read access.</param>
/// <param name="nodeMetaDict">Node metadata forwarded unchanged to <c>LoadTensorPb</c>.</param>
/// <returns>Whatever <c>LoadTensorPb</c> returns for the parsed tensor.</returns>
internal static NamedOnnxValue LoadTensorFromFilePb(string filename, IReadOnlyDictionary<string, NodeMetadata> nodeMetaDict)
{
    // 4MB buffer for the underlying file stream.
    const int bufferSizeInBytes = 4 * 1024 * 1024;

    Onnx.TensorProto parsedTensor;
    using (var input = new FileStream(filename, FileMode.Open, FileAccess.Read, FileShare.Read, bufferSizeInBytes))
    {
        parsedTensor = Onnx.TensorProto.Parser.ParseFrom(input);
    }

    // The stream is already closed here; only the in-memory message is passed on.
    return LoadTensorPb(parsedTensor, nodeMetaDict);
}
/// <summary>
/// Reads a serialized <c>Onnx.TensorProto</c> from a manifest resource embedded in the
/// assembly that defines <c>TestDataLoader</c>, then hands the parsed tensor to
/// <c>LoadTensorPb</c>.
/// </summary>
/// <param name="path">
/// Resource path relative to the <c>TestData</c> folder; the full resource name is built as
/// <c>{assembly name}.TestData.{path}</c>.
/// </param>
/// <param name="nodeMetaDict">Node metadata forwarded unchanged to <c>LoadTensorPb</c>.</param>
/// <returns>Whatever <c>LoadTensorPb</c> returns for the parsed tensor.</returns>
/// <exception cref="FileNotFoundException">
/// No embedded resource with the computed name exists in the assembly.
/// </exception>
internal static NamedOnnxValue LoadTensorFromEmbeddedResourcePb(string path, IReadOnlyDictionary<string, NodeMetadata> nodeMetaDict)
{
    var assembly = typeof(TestDataLoader).Assembly;
    string resourceName = $"{assembly.GetName().Name}.TestData.{path}";

    Onnx.TensorProto tensor;
    using (Stream stream = assembly.GetManifestResourceStream(resourceName))
    {
        // GetManifestResourceStream signals a missing resource by returning null instead of
        // throwing, so fail here with the resource name rather than inside the parser.
        if (stream == null)
        {
            throw new FileNotFoundException(
                $"Embedded resource '{resourceName}' was not found in assembly '{assembly.FullName}'.",
                resourceName);
        }

        tensor = Onnx.TensorProto.Parser.ParseFrom(stream);
    }

    return LoadTensorPb(tensor, nodeMetaDict);
}
/// <summary>
/// Parses every file matching <c>input_*.pb</c> directly inside <paramref name="testDataPath"/>
/// as an <c>Onnx.TensorProto</c> and converts each one with
/// <c>CreateNamedOnnxValueFromTensorProto</c>.
/// </summary>
/// <param name="testDataPath">Directory that is searched for the input files.</param>
/// <param name="inputMeta">Input metadata forwarded to <c>CreateNamedOnnxValueFromTensorProto</c>.</param>
/// <returns>One <c>NamedOnnxValue</c> per matching file, in enumeration order.</returns>
static List<NamedOnnxValue> LoadTestDataFromProtobuf(string testDataPath, IReadOnlyDictionary<string, NodeMetadata> inputMeta)
{
    var loadedInputs = new List<NamedOnnxValue>();

    foreach (var pbPath in Directory.EnumerateFiles(testDataPath, "input_*.pb"))
    {
        Onnx.TensorProto parsed;
        using (var pbStream = File.OpenRead(pbPath))
        {
            parsed = Onnx.TensorProto.Parser.ParseFrom(pbStream);
        }

        // Conversion happens after the file handle has been released.
        loadedInputs.Add(CreateNamedOnnxValueFromTensorProto(parsed, inputMeta));
    }

    return loadedInputs;
}
/// <summary>
/// Converts a parsed <c>Onnx.TensorProto</c> into a <c>NamedOnnxValue</c>, using
/// <paramref name="nodeMetaDict"/> to choose the node name and to validate the tensor's
/// element type and shape.
/// </summary>
/// <remarks>
/// Node selection works as follows:
/// <list type="bullet">
/// <item>exactly one metadata entry: that entry is used, whatever the tensor's name is;</item>
/// <item>several entries and a non-empty <c>tensor.Name</c>: the entry is looked up by that name;</item>
/// <item>several entries and an empty name: the first entry whose element type and rank match
/// and whose every dimension is either -1 or equal to the tensor's dimension is used.</item>
/// </list>
/// </remarks>
/// <param name="tensor">The deserialized tensor; its <c>RawData</c> supplies the element bytes.</param>
/// <param name="nodeMetaDict">Candidate node metadata keyed by node name.</param>
/// <returns>The value created by <c>CreateNamedOnnxValueFromRawData</c> for the selected node.</returns>
/// <exception cref="Exception">
/// The dictionary is empty, no entry matches, the selected node is not a tensor, its type or
/// shape disagrees with the serialized tensor, or the element type is not handled here.
/// </exception>
static NamedOnnxValue LoadTensorPb(Onnx.TensorProto tensor, IReadOnlyDictionary<string, NodeMetadata> nodeMetaDict)
{
    Type tensorElemType = null;
    int width = 0;
    GetTypeAndWidth((Tensors.TensorElementType)tensor.DataType, out tensorElemType, out width);

    var intDims = new int[tensor.Dims.Count];
    for (int i = 0; i < tensor.Dims.Count; i++)
    {
        intDims[i] = (int)tensor.Dims[i];
    }

    if (nodeMetaDict.Count == 0)
    {
        throw new Exception("While reading the serliazed tensor specified, metaDataDict has 0 elements");
    }

    NodeMetadata nodeMeta = null;
    string nodeName = string.Empty;

    if (nodeMetaDict.Count == 1)
    {
        // valid for single node input
        var onlyEntry = nodeMetaDict.First();
        nodeName = onlyEntry.Key;
        nodeMeta = onlyEntry.Value;
    }
    else if (tensor.Name.Length > 0)
    {
        // NOTE(review): an unknown name surfaces whatever the dictionary indexer throws
        // (presumably KeyNotFoundException) - unchanged from the previous behaviour.
        nodeMeta = nodeMetaDict[tensor.Name];
        nodeName = tensor.Name;
    }
    else
    {
        // try to find from matching type and shape
        bool matchFound = false;
        foreach (var entry in nodeMetaDict)
        {
            var meta = entry.Value;
            if (tensorElemType != meta.ElementType || tensor.Dims.Count != meta.Dimensions.Length)
            {
                continue;
            }

            bool shapeMatches = true;
            for (int i = 0; i < meta.Dimensions.Length; i++)
            {
                // A metadata dimension of -1 is accepted for any tensor dimension.
                if (meta.Dimensions[i] != -1 && meta.Dimensions[i] != intDims[i])
                {
                    shapeMatches = false;
                    break;
                }
            }

            if (shapeMatches)
            {
                matchFound = true;
                nodeMeta = meta;
                nodeName = entry.Key;
                break;
            }
        }

        if (!matchFound)
        {
            throw new Exception("No Matching Tensor found in InputOutputMetadata corresponding to the serialized tensor specified");
        }
    }

    if (!nodeMeta.IsTensor)
    {
        throw new Exception("LoadTensorFromFile can load Tensor types only");
    }

    if (tensorElemType != nodeMeta.ElementType)
    {
        throw new Exception($"{nameof(tensorElemType)} is expected to be equal to {nameof(nodeMeta.ElementType)}");
    }

    if (nodeMeta.Dimensions.Length != tensor.Dims.Count)
    {
        throw new Exception($"{nameof(nodeMeta.Dimensions.Length)} is expected to be equal to {nameof(tensor.Dims.Count)}");
    }

    for (int i = 0; i < nodeMeta.Dimensions.Length; i++)
    {
        if ((nodeMeta.Dimensions[i] != -1) && (nodeMeta.Dimensions[i] != intDims[i]))
        {
            throw new Exception($"{nameof(nodeMeta.Dimensions)}[{i}] is expected to either be -1 or {nameof(intDims)}[{i}]");
        }
    }

    // Every supported branch needs the same byte copy, so take it once.
    var rawData = tensor.RawData.ToArray();
    Type elemType = nodeMeta.ElementType;

    if (elemType == typeof(float))
    {
        return CreateNamedOnnxValueFromRawData<float>(nodeName, rawData, sizeof(float), intDims);
    }
    else if (elemType == typeof(double))
    {
        return CreateNamedOnnxValueFromRawData<double>(nodeName, rawData, sizeof(double), intDims);
    }
    else if (elemType == typeof(int))
    {
        return CreateNamedOnnxValueFromRawData<int>(nodeName, rawData, sizeof(int), intDims);
    }
    else if (elemType == typeof(uint))
    {
        return CreateNamedOnnxValueFromRawData<uint>(nodeName, rawData, sizeof(uint), intDims);
    }
    else if (elemType == typeof(long))
    {
        return CreateNamedOnnxValueFromRawData<long>(nodeName, rawData, sizeof(long), intDims);
    }
    else if (elemType == typeof(ulong))
    {
        return CreateNamedOnnxValueFromRawData<ulong>(nodeName, rawData, sizeof(ulong), intDims);
    }
    else if (elemType == typeof(short))
    {
        return CreateNamedOnnxValueFromRawData<short>(nodeName, rawData, sizeof(short), intDims);
    }
    else if (elemType == typeof(ushort))
    {
        return CreateNamedOnnxValueFromRawData<ushort>(nodeName, rawData, sizeof(ushort), intDims);
    }
    else if (elemType == typeof(byte))
    {
        return CreateNamedOnnxValueFromRawData<byte>(nodeName, rawData, sizeof(byte), intDims);
    }
    else if (elemType == typeof(bool))
    {
        return CreateNamedOnnxValueFromRawData<bool>(nodeName, rawData, sizeof(bool), intDims);
    }
    else if (elemType == typeof(Float16))
    {
        // Element width is passed as sizeof(ushort) - presumably because Float16 is 16 bits wide.
        return CreateNamedOnnxValueFromRawData<Float16>(nodeName, rawData, sizeof(ushort), intDims);
    }
    else if (elemType == typeof(BFloat16))
    {
        // Same width convention as Float16 above.
        return CreateNamedOnnxValueFromRawData<BFloat16>(nodeName, rawData, sizeof(ushort), intDims);
    }
    else
    {
        //TODO: Add support for remaining types
        // Report the actual element type; nameof(...) would only print the text "ElementType".
        throw new Exception($"Tensors of type {elemType} not currently supported in the LoadTensorFromEmbeddedResource");
    }
}
/// <summary>
/// Converts a parsed <c>Onnx.TensorProto</c> into a <c>NamedOnnxValue</c> named after the
/// tensor, after checking that <paramref name="inputMeta"/> has an entry with the tensor's
/// name and the same element type.
/// </summary>
/// <param name="tensorProto">
/// The deserialized tensor; its <c>Name</c> selects the metadata entry and its
/// <c>RawData</c> supplies the element bytes.
/// </param>
/// <param name="inputMeta">Input metadata keyed by node name.</param>
/// <returns>The value created by <c>CreateNamedOnnxValueFromRawData</c>.</returns>
/// <exception cref="Exception">
/// No metadata entry exists for the tensor's name, its element type differs from the
/// tensor's, or the element type is not handled here.
/// </exception>
static NamedOnnxValue CreateNamedOnnxValueFromTensorProto(Onnx.TensorProto tensorProto, IReadOnlyDictionary<string, NodeMetadata> inputMeta)
{
    Type tensorElemType = null;
    int elemWidth = 0;
    GetElementTypeAndWidth((TensorElementType)tensorProto.DataType, out tensorElemType, out elemWidth);

    var dims = tensorProto.Dims.ToList().ConvertAll(x => (int)x);

    NodeMetadata nodeMeta = null;
    if (!inputMeta.TryGetValue(tensorProto.Name, out nodeMeta) || nodeMeta.ElementType != tensorElemType)
    {
        throw new Exception("No Matching Tensor found from serialized tensor");
    }

    // Every supported branch needs the same name and byte copy, so take them once.
    string name = tensorProto.Name;
    var rawData = tensorProto.RawData.ToArray();
    Type elemType = nodeMeta.ElementType;

    if (elemType == typeof(float))
    {
        return CreateNamedOnnxValueFromRawData<float>(name, rawData, sizeof(float), dims);
    }
    else if (elemType == typeof(double))
    {
        return CreateNamedOnnxValueFromRawData<double>(name, rawData, sizeof(double), dims);
    }
    else if (elemType == typeof(int))
    {
        return CreateNamedOnnxValueFromRawData<int>(name, rawData, sizeof(int), dims);
    }
    else if (elemType == typeof(uint))
    {
        return CreateNamedOnnxValueFromRawData<uint>(name, rawData, sizeof(uint), dims);
    }
    else if (elemType == typeof(long))
    {
        return CreateNamedOnnxValueFromRawData<long>(name, rawData, sizeof(long), dims);
    }
    else if (elemType == typeof(ulong))
    {
        return CreateNamedOnnxValueFromRawData<ulong>(name, rawData, sizeof(ulong), dims);
    }
    else if (elemType == typeof(short))
    {
        return CreateNamedOnnxValueFromRawData<short>(name, rawData, sizeof(short), dims);
    }
    else if (elemType == typeof(ushort))
    {
        return CreateNamedOnnxValueFromRawData<ushort>(name, rawData, sizeof(ushort), dims);
    }
    else if (elemType == typeof(byte))
    {
        return CreateNamedOnnxValueFromRawData<byte>(name, rawData, sizeof(byte), dims);
    }
    else if (elemType == typeof(bool))
    {
        return CreateNamedOnnxValueFromRawData<bool>(name, rawData, sizeof(bool), dims);
    }
    else
    {
        // Report the actual element type; nameof(...) would only print the text "ElementType".
        throw new Exception($"Tensors of type {elemType} not currently supported in this tool");
    }
}