コード例 #1
0
        /// <summary>
        /// Check if the model contains all the expected input/output tensors.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <param name="failedModelChecks">
        /// Output list of failure messages. Checks stop at the first failure, so at most one
        /// warning is appended per call.
        /// </param>
        /// <param name="deterministicInference">
        /// Inference only: set to true if the action selection from model should be deterministic.
        /// </param>
        /// <returns>True if the model contains all the expected tensors.</returns>
        /// <remarks>
        /// The check that the model provides the action output matching
        /// <paramref name="deterministicInference"/> only runs for action types whose stochastic
        /// output node (continuous or discrete) is present.
        /// TODO: models that expose only deterministic action outputs are not shape-checked yet.
        /// </remarks>
        public static bool CheckExpectedTensors(this Model model, List <FailedCheck> failedModelChecks, bool deterministicInference = false)
        {
            // Check the presence of model version
            var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
            if (modelApiVersionTensor == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
                    );
                return(false);
            }

            // Check the presence of memory size
            var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
            if (memorySizeTensor == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
                    );
                return(false);
            }

            // Check the presence of action output tensor (legacy, stochastic or deterministic)
            if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
                !model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
                !model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
                !model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) &&
                !model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput))
            {
                failedModelChecks.Add(
                    FailedCheck.Warning("The model does not contain any Action Output Node.")
                    );
                return(false);
            }

            // Check the presence of action output shape tensor
            if (!model.SupportsContinuousAndDiscrete())
            {
                // Legacy format: a single action output described by a shape tensor and a continuous-control flag.
                if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning("The model does not contain any Action Output Shape Node.")
                        );
                    return(false);
                }
                if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
                                            "not found in the model file. " +
                                            "This is only required for model that uses a deprecated model format.")
                        );
                    return(false);
                }
                return(true);
            }

            if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
            {
                if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
                        );
                    return(false);
                }

                // The selected inference mode needs its own continuous output. Only suggest
                // unchecking the flag when it is actually set; otherwise that advice is meaningless.
                if (!model.HasContinuousOutputs(deterministicInference))
                {
                    var message = deterministicInference
                        ? "The model uses deterministic inference but does not contain Deterministic Continuous Action Output Tensor. Uncheck `Deterministic inference` flag."
                        : "The model uses stochastic inference but does not contain Continuous Action Output Tensor.";
                    failedModelChecks.Add(FailedCheck.Warning(message));
                    return(false);
                }
            }

            if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
            {
                if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
                        );
                    return(false);
                }

                // Same mode-specific check as for continuous actions above.
                if (!model.HasDiscreteOutputs(deterministicInference))
                {
                    var message = deterministicInference
                        ? "The model uses deterministic inference but does not contain Deterministic Discrete Action Output Tensor. Uncheck `Deterministic inference` flag."
                        : "The model uses stochastic inference but does not contain Discrete Action Output Tensor.";
                    failedModelChecks.Add(FailedCheck.Warning(message));
                    return(false);
                }
            }
            return(true);
        }
コード例 #2
0
        /// <summary>
        /// Validates that <paramref name="model"/> exposes every tensor the inference pipeline
        /// expects. Validation stops at the first missing tensor.
        /// </summary>
        /// <param name="model">
        /// The Barracuda engine model for loading static parameters.
        /// </param>
        /// <param name="failedModelChecks">
        /// Receives a single warning describing the first missing tensor, if any.
        /// </param>
        /// <returns>True if the model contains all the expected tensors, false otherwise.</returns>
        public static bool CheckExpectedTensors(this Model model, List <FailedCheck> failedModelChecks)
        {
            // Mandatory constants, checked in order; they share one message template.
            var requiredConstants = new[] { TensorNames.VersionNumber, TensorNames.MemorySize };
            foreach (var constantName in requiredConstants)
            {
                if (model.GetTensorByName(constantName) != null)
                {
                    continue;
                }
                failedModelChecks.Add(
                    FailedCheck.Warning($"Required constant \"{constantName}\" was not found in the model file."));
                return false;
            }

            // At least one action output node must exist, in either the legacy or the split format.
            var hasAnyActionOutput =
                model.outputs.Contains(TensorNames.ActionOutputDeprecated) ||
                model.outputs.Contains(TensorNames.ContinuousActionOutput) ||
                model.outputs.Contains(TensorNames.DiscreteActionOutput);
            if (!hasAnyActionOutput)
            {
                failedModelChecks.Add(FailedCheck.Warning("The model does not contain any Action Output Node."));
                return false;
            }

            if (model.SupportsContinuousAndDiscrete())
            {
                // Split format: every action output that is present needs its matching shape tensor.
                var hasContinuousOutput = model.outputs.Contains(TensorNames.ContinuousActionOutput);
                if (hasContinuousOutput && model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
                {
                    failedModelChecks.Add(FailedCheck.Warning(
                        "The model uses continuous action but does not contain Continuous Action Output Shape Node."));
                    return false;
                }

                var hasDiscreteOutput = model.outputs.Contains(TensorNames.DiscreteActionOutput);
                if (hasDiscreteOutput && model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
                {
                    failedModelChecks.Add(FailedCheck.Warning(
                        "The model uses discrete action but does not contain Discrete Action Output Shape Node."));
                    return false;
                }

                return true;
            }

            // Legacy format: a single action output described by a shape tensor and a continuous-control flag.
            if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
            {
                failedModelChecks.Add(FailedCheck.Warning("The model does not contain any Action Output Shape Node."));
                return false;
            }

            if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
            {
                var legacyMessage = $"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
                    "not found in the model file. " +
                    "This is only required for model that uses a deprecated model format.";
                failedModelChecks.Add(FailedCheck.Warning(legacyMessage));
                return false;
            }

            return true;
        }