예제 #1
0
        /// <summary>
        /// Initializes the Brain with the Model that it will use when selecting actions for
        /// the agents.
        /// </summary>
        /// <remarks>
        /// Any inference worker created by a previous call is disposed first. This happens
        /// both when a new model is loaded and when <c>model</c> is null, in which case the
        /// engine and the loaded Barracuda model are cleared.
        /// </remarks>
        /// <param name="seed"> The seed that will be used to initialize the RandomNormal
        /// and Multinomial objects used when running inference.</param>
        /// <exception cref="UnityAgentsException">Documented as thrown when the model is null.
        /// NOTE(review): this method does not throw for a null model itself; it clears the
        /// engine and delegates to <c>BarracudaModelParamLoader.GetLoaderAndCheck</c> —
        /// confirm whether that call raises.
        /// </exception>
        public void ReloadModel(int seed = 0)
        {
            if (m_TensorAllocator == null)
            {
                m_TensorAllocator = new TensorCachingAllocator();
            }

            // Release the previous worker unconditionally. If the model has been cleared,
            // the engine is about to be dropped, and skipping this would leak it.
            // Null the field right away so a failure while loading the new model
            // never leaves a reference to a disposed worker behind.
            if (m_Engine != null)
            {
                m_Engine.Dispose();
                m_Engine = null;
            }

            if (model != null)
            {
#if BARRACUDA_VERBOSE
                // Must target the same field that is read below; a differently named field
                // would either fail to compile or silently leave verbosity off.
                m_Verbose = true;
#endif

                D.logEnabled = m_Verbose;

                m_BarracudaModel = ModelLoader.Load(model.Value);
                var executionDevice = inferenceDevice == InferenceDevice.GPU
                    ? BarracudaWorkerFactory.Type.ComputePrecompiled
                    : BarracudaWorkerFactory.Type.CSharp;

                m_Engine = BarracudaWorkerFactory.CreateWorker(executionDevice, m_BarracudaModel, m_Verbose);
            }
            else
            {
                // No model: clear the loaded model (the engine was already released above).
                m_BarracudaModel = null;
            }

            m_ModelParamLoader = BarracudaModelParamLoader.GetLoaderAndCheck(m_Engine, m_BarracudaModel, brainParameters);
            m_InferenceInputs  = m_ModelParamLoader.GetInputTensors();
            m_OutputNames      = m_ModelParamLoader.GetOutputNames();
            m_TensorGenerator  = new TensorGenerator(brainParameters, seed, m_TensorAllocator, m_BarracudaModel);
            m_TensorApplier    = new TensorApplier(brainParameters, seed, m_TensorAllocator, m_BarracudaModel);
        }
        public void TestCheckModelThrowsVectorObservationHybrid()
        {
            var model = ModelLoader.Load(hybridONNXModel);

            // Case 1: a vector observation size that does not match the hybrid model.
            var brainParameters = GetHybridBrainParameters();
            brainParameters.VectorObservationSize = 9; // Invalid observation
            var errors = BarracudaModelParamLoader.CheckModel(
                model, brainParameters,
                new SensorComponent[] { }, new ActuatorComponent[0]
                );
            Assert.Greater(errors.Count(), 0);

            // Case 2: invalid stacking. Rebuild the parameters from the same hybrid factory as
            // case 1, not the continuous-model parameters, so that the stacking change is the
            // mismatch being exercised against this model.
            brainParameters = GetHybridBrainParameters();
            brainParameters.NumStackedVectorObservations = 2; // Invalid stacking
            errors = BarracudaModelParamLoader.CheckModel(
                model, brainParameters,
                new SensorComponent[] { }, new ActuatorComponent[0]
                );
            Assert.Greater(errors.Count(), 0);
        }
        public void TestCheckModelThrowsVectorObservationContinuous(bool useDeprecatedNNModel)
        {
            // The same checks run against either the deprecated NN asset or the ONNX asset.
            var model = useDeprecatedNNModel
                ? ModelLoader.Load(continuousNNModel)
                : ModelLoader.Load(continuousONNXModel);

            // Case 1: vector observation size disagrees with the model.
            var wrongSizeParams = GetContinuous2vis8vec2actionBrainParameters();
            wrongSizeParams.VectorObservationSize = 9;
            var wrongSizeErrors = BarracudaModelParamLoader.CheckModel(
                model,
                wrongSizeParams,
                new SensorComponent[] { sensor_21_20_3, sensor_20_22_3 },
                new ActuatorComponent[0]);
            Assert.Greater(wrongSizeErrors.Count(), 0);

            // Case 2: vector observations are stacked when the model does not expect it.
            var wrongStackParams = GetContinuous2vis8vec2actionBrainParameters();
            wrongStackParams.NumStackedVectorObservations = 2;
            var wrongStackErrors = BarracudaModelParamLoader.CheckModel(
                model,
                wrongStackParams,
                new SensorComponent[] { sensor_21_20_3, sensor_20_22_3 },
                new ActuatorComponent[0]);
            Assert.Greater(wrongStackErrors.Count(), 0);
        }
 public void TestCheckModelThrowsNoModel()
 {
     // With no model at all, CheckModel must report at least one error.
     // The arguments mirror the other CheckModel tests, including the empty actuator list.
     var brainParameters = GetContinuous2vis8vec2actionBrainParameters();
     var errors = BarracudaModelParamLoader.CheckModel(
         null, brainParameters,
         new SensorComponent[] { sensor_21_20_3, sensor_20_22_3 }, new ActuatorComponent[0]
         );
     Assert.Greater(errors.Count(), 0);
 }