/// <summary>
/// Initializes the Brain with the Model that it will use when selecting actions for
/// the agents.
/// </summary>
/// <param name="model"> The Barracuda model to load. </param>
/// <param name="brainParameters"> The parameters of the Brain used to generate the
/// placeholder tensors. </param>
/// <param name="inferenceDevice"> Inference execution device. CPU is the fastest
/// option for most ML-Agents models. </param>
/// <param name="seed"> The seed that will be used to initialize the RandomNormal
/// and Multinomial objects used when running inference. </param>
public ModelRunner(
    NNModel model,
    BrainParameters brainParameters,
    InferenceDevice inferenceDevice = InferenceDevice.CPU,
    int seed = 0)
{
    Model barracudaModel;
    m_Model = model;
    m_InferenceDevice = inferenceDevice;
    m_TensorAllocator = new TensorCachingAllocator();
    if (model != null)
    {
#if BARRACUDA_VERBOSE
        m_Verbose = true;
#endif
        D.logEnabled = m_Verbose;

        barracudaModel = ModelLoader.Load(model);
        // GPU inference uses the precompiled compute-shader worker; everything
        // else falls back to the C# reference worker.
        var executionDevice = inferenceDevice == InferenceDevice.GPU
            ? WorkerFactory.Type.ComputePrecompiled
            : WorkerFactory.Type.CSharp;
        m_Engine = WorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose);
    }
    else
    {
        // No model provided: construct the ModelRunner without an inference engine.
        barracudaModel = null;
        m_Engine = null;
    }

    m_InferenceInputs = BarracudaModelParamLoader.GetInputTensors(barracudaModel);
    m_OutputNames = BarracudaModelParamLoader.GetOutputNames(barracudaModel);
    m_TensorGenerator = new TensorGenerator(
        seed, m_TensorAllocator, m_Memories, barracudaModel);
    m_TensorApplier = new TensorApplier(
        brainParameters, seed, m_TensorAllocator, m_Memories, barracudaModel);
}
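// Minimal usage sketch (illustrative only, not part of this class): a caller
// would typically construct one ModelRunner per (model, device) pair and drive
// it once per decision step. `myModel` is a hypothetical NNModel asset, and the
// per-step entry point name is an assumption; it may differ by version.
//
//     var runner = new ModelRunner(myModel, brainParameters, InferenceDevice.CPU, seed: 0);
//     runner.DecideAction();   // assumed per-step inference call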
/// <summary>
/// Checks whether this ModelRunner is already running the given model on the
/// given inference device, so an existing runner can be reused instead of
/// creating a new one.
/// </summary>
/// <param name="other"> The model to compare against. </param>
/// <param name="otherInferenceDevice"> The inference device to compare against. </param>
public bool HasModel(NNModel other, InferenceDevice otherInferenceDevice)
{
    return m_Model == other && m_InferenceDevice == otherInferenceDevice;
}
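// Sketch of the caching pattern HasModel enables (hypothetical caller code;
// `m_ModelRunners` is an assumed List<ModelRunner> field on the caller, not
// part of this class):
//
//     ModelRunner GetOrCreateModelRunner(
//         NNModel model, BrainParameters brainParameters, InferenceDevice device)
//     {
//         var runner = m_ModelRunners.Find(r => r.HasModel(model, device));
//         if (runner == null)
//         {
//             runner = new ModelRunner(model, brainParameters, device);
//             m_ModelRunners.Add(runner);
//         }
//         return runner;
//     }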