/// <summary>
/// Initializes the Brain with the Model that it will use when selecting actions for
/// the agents.
/// </summary>
/// <remarks>
/// Any previously created inference worker is disposed before being replaced, including
/// when <c>model</c> is null. The loader, input tensors, output names, tensor generator
/// and tensor applier are rebuilt on every call.
/// </remarks>
/// <param name="seed"> The seed that will be used to initialize the RandomNormal
/// and Multinomial objects used when running inference.</param>
/// <exception cref="UnityAgentsException">Documented as thrown when the model is null.
/// This method does not throw it directly; presumably it is raised during model
/// validation in BarracudaModelParamLoader.GetLoaderAndCheck (verify).
/// </exception>
public void ReloadModel(int seed = 0)
{
    // The tensor allocator is created lazily once and reused across reloads.
    if (m_TensorAllocator == null)
    {
        m_TensorAllocator = new TensorCachingAllocator();
    }

    // Release the previous worker before replacing it. Previously this only happened when
    // a new model was supplied, so clearing the model leaked the old worker.
    if (m_Engine != null)
    {
        m_Engine.Dispose();
        m_Engine = null;
    }

    if (model != null)
    {
#if BARRACUDA_VERBOSE
        // Was `_verbose = true;`, which did not affect the m_Verbose flag read below.
        m_Verbose = true;
#endif

        D.logEnabled = m_Verbose;

        m_BarracudaModel = ModelLoader.Load(model.Value);
        var executionDevice = inferenceDevice == InferenceDevice.GPU
            ? BarracudaWorkerFactory.Type.ComputePrecompiled
            : BarracudaWorkerFactory.Type.CSharp;
        m_Engine = BarracudaWorkerFactory.CreateWorker(executionDevice, m_BarracudaModel, m_Verbose);
    }
    else
    {
        // No model: leave both the model and the worker unset (the worker was already
        // disposed above).
        m_BarracudaModel = null;
    }

    m_ModelParamLoader = BarracudaModelParamLoader.GetLoaderAndCheck(m_Engine, m_BarracudaModel, brainParameters);
    m_InferenceInputs = m_ModelParamLoader.GetInputTensors();
    m_OutputNames = m_ModelParamLoader.GetOutputNames();
    m_TensorGenerator = new TensorGenerator(brainParameters, seed, m_TensorAllocator, m_BarracudaModel);
    m_TensorApplier = new TensorApplier(brainParameters, seed, m_TensorAllocator, m_BarracudaModel);
}
public void TestCheckModelThrowsVectorObservationHybrid()
{
    var model = ModelLoader.Load(hybridONNXModel);

    // Scenario 1: a vector observation size that does not match the hybrid model.
    var brainParameters = GetHybridBrainParameters();
    brainParameters.VectorObservationSize = 9; // Invalid observation
    var errors = BarracudaModelParamLoader.CheckModel(
        model, brainParameters,
        new SensorComponent[] { }, new ActuatorComponent[0]
    );
    Assert.Greater(errors.Count(), 0);

    // Scenario 2: invalid stacking. Start from the same hybrid parameters as scenario 1,
    // so that stacking is the only thing that differs. The previous code used the
    // continuous "2vis" parameters here with no sensors, which presumably fail
    // validation for unrelated reasons and would mask the stacking check.
    brainParameters = GetHybridBrainParameters();
    brainParameters.NumStackedVectorObservations = 2; // Invalid stacking
    errors = BarracudaModelParamLoader.CheckModel(
        model, brainParameters,
        new SensorComponent[] { }, new ActuatorComponent[0]
    );
    Assert.Greater(errors.Count(), 0);
}
public void TestCheckModelThrowsVectorObservationContinuous(bool useDeprecatedNNModel)
{
    // The same continuous model is available both as a legacy .nn asset and as ONNX.
    var model = useDeprecatedNNModel
        ? ModelLoader.Load(continuousNNModel)
        : ModelLoader.Load(continuousONNXModel);

    // A vector observation size of 9 does not match the model and must be reported.
    var oversizedVectorParams = GetContinuous2vis8vec2actionBrainParameters();
    oversizedVectorParams.VectorObservationSize = 9;
    var sizeErrors = BarracudaModelParamLoader.CheckModel(
        model,
        oversizedVectorParams,
        new SensorComponent[] { sensor_21_20_3, sensor_20_22_3 },
        new ActuatorComponent[0]);
    Assert.Greater(sizeErrors.Count(), 0);

    // Stacking two vector observations is also expected to be rejected for this model.
    var stackedVectorParams = GetContinuous2vis8vec2actionBrainParameters();
    stackedVectorParams.NumStackedVectorObservations = 2;
    var stackingErrors = BarracudaModelParamLoader.CheckModel(
        model,
        stackedVectorParams,
        new SensorComponent[] { sensor_21_20_3, sensor_20_22_3 },
        new ActuatorComponent[0]);
    Assert.Greater(stackingErrors.Count(), 0);
}
public void TestCheckModelThrowsNoModel()
{
    var brainParameters = GetContinuous2vis8vec2actionBrainParameters();

    // With no model at all, CheckModel must still report at least one failed check.
    // Pass the actuator components explicitly, like the other CheckModel calls in these
    // tests, rather than relying on an optional parameter or overload.
    var errors = BarracudaModelParamLoader.CheckModel(
        null, brainParameters,
        new SensorComponent[] { sensor_21_20_3, sensor_20_22_3 },
        new ActuatorComponent[0]
    );
    Assert.Greater(errors.Count(), 0);
}