/// <summary>
/// Check if the model contains all the expected input/output tensors.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="failedModelChecks">Output list of failure messages</param>
/// <param name="deterministicInference">
/// Inference only: set to true if the action selection from model should be deterministic.
/// </param>
/// <returns>True if the model contains all the expected tensors.</returns>
/// <remarks>
/// Checks stop at the first failure: exactly one warning is appended to
/// <paramref name="failedModelChecks"/> and false is returned.
/// </remarks>
/// TODO: add checks for deterministic actions
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks, bool deterministicInference = false)
{
    // Appends a single warning and reports failure.
    bool Fail(string message)
    {
        failedModelChecks.Add(FailedCheck.Warning(message));
        return false;
    }

    // Builds the warning for a model lacking the action output tensor that matches the
    // requested inference mode. The "Uncheck `Deterministic inference` flag." hint is only
    // meaningful when that flag is actually set, so it is omitted for stochastic inference.
    string MissingActionOutputMessage(string actionKind)
    {
        return deterministicInference
            ? $"The model uses deterministic inference but does not contain Deterministic {actionKind} Action Output Tensor. Uncheck `Deterministic inference` flag."
            : $"The model uses stochastic inference but does not contain {actionKind} Action Output Tensor.";
    }

    // Check the presence of model version
    if (model.GetTensorByName(TensorNames.VersionNumber) == null)
    {
        return Fail($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.");
    }

    // Check the presence of memory size
    if (model.GetTensorByName(TensorNames.MemorySize) == null)
    {
        return Fail($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.");
    }

    // Check the presence of at least one action output tensor: the deprecated single output,
    // or any stochastic / deterministic continuous or discrete output.
    if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
        !model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
        !model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
        !model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) &&
        !model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput))
    {
        return Fail("The model does not contain any Action Output Node.");
    }

    if (!model.SupportsContinuousAndDiscrete())
    {
        // Deprecated model format: requires the action output shape constant and the
        // continuous-control flag constant.
        if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
        {
            return Fail("The model does not contain any Action Output Shape Node.");
        }
        if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
        {
            return Fail($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
                "not found in the model file. " +
                "This is only required for model that uses a deprecated model format.");
        }
        return true;
    }

    // Newer model format: each action kind whose stochastic output is present needs its shape
    // constant. HasContinuousOutputs / HasDiscreteOutputs are then asked whether the model
    // provides outputs for the requested inference mode (their exact criteria live elsewhere).
    if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
    {
        if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
        {
            return Fail("The model uses continuous action but does not contain Continuous Action Output Shape Node.");
        }
        if (!model.HasContinuousOutputs(deterministicInference))
        {
            return Fail(MissingActionOutputMessage("Continuous"));
        }
    }

    if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
    {
        if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
        {
            return Fail("The model uses discrete action but does not contain Discrete Action Output Shape Node.");
        }
        if (!model.HasDiscreteOutputs(deterministicInference))
        {
            return Fail(MissingActionOutputMessage("Discrete"));
        }
    }

    return true;
}
/// <summary>
/// Validates that the model exposes the constants and output nodes required for inference.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="failedModelChecks">
/// Receives a warning describing the first missing tensor, if any.
/// </param>
/// <returns>
/// True if the model contains all the expected tensors; false as soon as one check fails.
/// </returns>
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks)
{
    // Records one warning and signals failure; validation stops at the first problem.
    bool Reject(string message)
    {
        failedModelChecks.Add(FailedCheck.Warning(message));
        return false;
    }

    // The model version constant must be present.
    if (model.GetTensorByName(TensorNames.VersionNumber) == null)
    {
        return Reject($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.");
    }

    // The memory size constant must be present.
    if (model.GetTensorByName(TensorNames.MemorySize) == null)
    {
        return Reject($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.");
    }

    // At least one action output node (deprecated, continuous or discrete) is required.
    var outputNames = model.outputs;
    var hasAnyActionOutput = outputNames.Contains(TensorNames.ActionOutputDeprecated)
        || outputNames.Contains(TensorNames.ContinuousActionOutput)
        || outputNames.Contains(TensorNames.DiscreteActionOutput);
    if (!hasAnyActionOutput)
    {
        return Reject("The model does not contain any Action Output Node.");
    }

    if (model.SupportsContinuousAndDiscrete())
    {
        // Newer format: every action output that is present needs its matching shape constant.
        var continuousShapeMissing = outputNames.Contains(TensorNames.ContinuousActionOutput)
            && model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null;
        if (continuousShapeMissing)
        {
            return Reject("The model uses continuous action but does not contain Continuous Action Output Shape Node.");
        }

        var discreteShapeMissing = outputNames.Contains(TensorNames.DiscreteActionOutput)
            && model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null;
        if (discreteShapeMissing)
        {
            return Reject("The model uses discrete action but does not contain Discrete Action Output Shape Node.");
        }

        return true;
    }

    // Deprecated format: the action output shape constant and the continuous-control
    // flag constant are both required.
    if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
    {
        return Reject("The model does not contain any Action Output Shape Node.");
    }

    if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
    {
        return Reject($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
            "not found in the model file. " +
            "This is only required for model that uses a deprecated model format.");
    }

    return true;
}