Beispiel #1
0
        /// <summary>
        /// Generates the Tensor inputs that are expected to be present in the Model.
        /// </summary>
        /// <remarks>
        /// One floating-point <see cref="TensorProxy"/> is produced for each entry of
        /// <c>model.inputs</c>. One more is produced for each entry of <c>model.memories</c>,
        /// named after that memory's input. No tensor data is attached: every proxy is
        /// created with <c>data = null</c>.
        /// </remarks>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters. May be null.
        /// </param>
        /// <returns>
        /// The expected input tensors, sorted by name using a culture-invariant comparison.
        /// The list is empty when <paramref name="model"/> is null.
        /// </returns>
        public static IReadOnlyList <TensorProxy> GetInputTensors(Model model)
        {
            var tensors = new List <TensorProxy>();

            if (model == null)
            {
                return(tensors);
            }

            foreach (var input in model.inputs)
            {
                tensors.Add(new TensorProxy
                {
                    name      = input.name,
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data      = null,
                    // Barracuda stores dimensions as int; TensorProxy expects long.
                    shape     = input.shape.Select(i => (long)i).ToArray()
                });
            }

            foreach (var mem in model.memories)
            {
                tensors.Add(new TensorProxy
                {
                    name      = mem.input,
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data      = null,
                    shape     = TensorUtils.TensorShapeFromBarracuda(mem.shape)
                });
            }

            // Use an explicit culture-invariant comparison so the order does not depend on
            // the thread's current culture (string.CompareTo would use it). string.Compare
            // also tolerates null names instead of throwing.
            tensors.Sort((el1, el2) => string.Compare(el1.name, el2.name, System.StringComparison.InvariantCulture));

            return(tensors);
        }
Beispiel #2
0
        /// <summary>
        /// Builds the list of input tensors that the given model expects to receive.
        /// </summary>
        /// <remarks>
        /// Every entry of <c>model.inputs</c> yields one floating-point proxy. When the model's
        /// version (as reported by <c>GetVersion()</c>) is lower than
        /// <c>BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0</c>, every entry of
        /// <c>model.memories</c> also yields one proxy, named after that memory's input.
        /// No tensor data is attached: each proxy is created with <c>data = null</c>.
        /// </remarks>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters. May be null.
        /// </param>
        /// <returns>
        /// A read-only list of <see cref="TensorProxy"/> ordered by name using an
        /// invariant-culture string comparison. The list is empty when
        /// <paramref name="model"/> is null.
        /// </returns>
        public static IReadOnlyList <TensorProxy> GetInputTensors(this Model model)
        {
            var result = new List<TensorProxy>();
            if (model == null)
            {
                return result;
            }

            // Creates a float proxy that carries only a name and a shape.
            TensorProxy Placeholder(string tensorName, long[] tensorShape)
            {
                return new TensorProxy
                {
                    name      = tensorName,
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data      = null,
                    shape     = tensorShape
                };
            }

            foreach (var modelInput in model.inputs)
            {
                // Barracuda dimensions are widened to long for TensorProxy.
                var widenedShape = modelInput.shape.Select(dim => (long)dim).ToArray();
                result.Add(Placeholder(modelInput.name, widenedShape));
            }

            // Memory inputs are only reported for models older than the 2.0 model API.
            var includeMemories =
                model.GetVersion() < (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0;
            if (includeMemories)
            {
                foreach (var memory in model.memories)
                {
                    var memoryShape = TensorUtils.TensorShapeFromBarracuda(memory.shape);
                    result.Add(Placeholder(memory.input, memoryShape));
                }
            }

            result.Sort((left, right) => string.Compare(left.name, right.name, StringComparison.InvariantCulture));
            return result;
        }