/// <summary>
/// Generates the Tensor inputs that are expected to be present in the Model.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <returns>
/// A read-only list of <see cref="TensorProxy"/> containing one entry per element of
/// <c>model.inputs</c> followed by one entry per element of <c>model.memories</c>
/// (named after the memory's <c>input</c>), sorted by name. Every entry is
/// floating-point with no data attached. Returns an empty list when
/// <paramref name="model"/> is null.
/// </returns>
public static IReadOnlyList<TensorProxy> GetInputTensors(Model model)
{
    var tensors = new List<TensorProxy>();
    if (model == null)
    {
        return tensors;
    }

    foreach (var input in model.inputs)
    {
        tensors.Add(new TensorProxy
        {
            name = input.name,
            valueType = TensorProxy.TensorType.FloatingPoint,
            data = null,
            // Widen each dimension of the Barracuda shape to long.
            shape = input.shape.Select(i => (long)i).ToArray()
        });
    }

    foreach (var mem in model.memories)
    {
        tensors.Add(new TensorProxy
        {
            name = mem.input,
            valueType = TensorProxy.TensorType.FloatingPoint,
            data = null,
            shape = TensorUtils.TensorShapeFromBarracuda(mem.shape)
        });
    }

    // string.CompareTo compares using the *current* culture, so the resulting order
    // could differ between machines with different locale settings. Compare with the
    // invariant culture instead; string.Compare also tolerates null names, which
    // CompareTo would dereference.
    tensors.Sort((el1, el2) =>
        string.Compare(el1.name, el2.name, System.StringComparison.InvariantCulture));
    return tensors;
}
/// <summary>
/// Builds the list of input tensors the given Barracuda model expects to receive.
/// </summary>
/// <param name="model">
/// The Barracuda engine model to inspect; may be null.
/// </param>
/// <returns>
/// A read-only list of <see cref="TensorProxy"/> sorted by name (invariant culture).
/// It always contains one floating-point, data-less entry per element of
/// <c>model.inputs</c>. It additionally contains one entry per element of
/// <c>model.memories</c> only when the model's version is below
/// <c>ModelApiVersion.MLAgents2_0</c>. Empty when <paramref name="model"/> is null.
/// </returns>
public static IReadOnlyList<TensorProxy> GetInputTensors(this Model model)
{
    var result = new List<TensorProxy>();
    if (model == null)
    {
        return result;
    }

    // Declared model inputs, with each Barracuda dimension widened to long.
    result.AddRange(model.inputs.Select(declared => new TensorProxy
    {
        name = declared.name,
        valueType = TensorProxy.TensorType.FloatingPoint,
        data = null,
        shape = declared.shape.Select(dim => (long)dim).ToArray()
    }));

    // Models older than the MLAgents2_0 API also expose their memories as inputs.
    var exposesMemoriesAsInputs =
        model.GetVersion() < (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0;
    if (exposesMemoriesAsInputs)
    {
        result.AddRange(model.memories.Select(memory => new TensorProxy
        {
            name = memory.input,
            valueType = TensorProxy.TensorType.FloatingPoint,
            data = null,
            shape = TensorUtils.TensorShapeFromBarracuda(memory.shape)
        }));
    }

    result.Sort((lhs, rhs) => string.Compare(lhs.name, rhs.name, StringComparison.InvariantCulture));
    return result;
}