예제 #1
0
        /// <summary>
        /// Shut down the Academy.
        /// </summary>
        /// <remarks>
        /// Teardown steps, in order:
        /// <list type="number">
        /// <item>Stops automatic stepping.</item>
        /// <item>Notifies <c>DestroyAction</c> listeners.</item>
        /// <item>Disposes and drops the communicator.</item>
        /// <item>Disposes the environment parameters and stats recorder.</item>
        /// <item>Unregisters all side channels.</item>
        /// <item>Disposes the model runners.</item>
        /// <item>Clears the action delegates and saves the timer data.</item>
        /// <item>Marks the Academy as uninitialized and replaces the lazy singleton factory.</item>
        /// </list>
        /// NOTE(review): callers (e.g. TestAcademy) invoke this twice on the same instance. The
        /// second call re-runs m_EnvironmentParameters.Dispose() / m_StatsRecorder.Dispose(),
        /// which are not nulled here. Confirm that unregistering an already-removed channel is safe.
        /// </remarks>
        public void Dispose()
        {
            DisableAutomaticStepping();

            // Signal to listeners that the academy is being destroyed now
            DestroyAction?.Invoke();

            // Communicator may be null (e.g. running without a trainer); drop the reference after disposing.
            Communicator?.Dispose();
            Communicator = null;

            m_EnvironmentParameters.Dispose();
            m_StatsRecorder.Dispose();
            SideChannelsManager.UnregisterAllSideChannels();  // unregister custom side channels

            if (m_ModelRunners != null)
            {
                foreach (var mr in m_ModelRunners)
                {
                    mr.Dispose();
                }

                m_ModelRunners = null;
            }

            // Clear out the actions so we're not keeping references to any old objects
            ResetActions();

            // TODO - Pass worker ID or some other identifier,
            // so that multiple envs won't overwrite each others stats.
            TimerStack.Instance.SaveJsonTimers();
            m_Initialized = false;

            // Reset the Lazy instance so that a later access can construct a fresh Academy
            // (presumably Academy.Instance reads s_Lazy.Value — not visible here).
            s_Lazy = new Lazy <Academy>(() => new Academy());
        }
예제 #2
0
        /// <summary>
        /// Performs a single environment update of the Academy and Agent
        /// objects within the environment.
        /// </summary>
        /// <remarks>
        /// Callbacks are invoked in this fixed order:
        /// pre-step, step-counter increment, state sending, decision, acting.
        /// Each delegate may be null (no subscribers), hence the null-conditional invocations.
        /// The send/decide/act stages are each wrapped in a <c>TimerStack</c> scope of the same name.
        /// </remarks>
        public void EnvironmentStep()
        {
            // Ensure a full reset has happened before the first step.
            // NOTE(review): presumably ForcedFullReset sets m_HadFirstReset — not visible here.
            if (!m_HadFirstReset)
            {
                ForcedFullReset();
            }

            // Listeners receive the step count as it was before this step's increment.
            AgentPreStep?.Invoke(m_StepCount);

            // Both counters advance by one per call. How they differ (e.g. which one is reset
            // between episodes) is not visible here — TODO confirm.
            m_StepCount      += 1;
            m_TotalStepCount += 1;
            AgentIncrementStep?.Invoke();

            using (TimerStack.Instance.Scoped("AgentSendState"))
            {
                AgentSendState?.Invoke();
            }

            using (TimerStack.Instance.Scoped("DecideAction"))
            {
                DecideAction?.Invoke();
            }

            // If the communicator is not on, we need to clear the SideChannel sending queue.
            // The returned message is deliberately discarded; only the side effect of
            // consuming the queued outgoing data is wanted.
            if (!IsCommunicatorOn)
            {
                SideChannelsManager.GetSideChannelMessage();
            }

            using (TimerStack.Instance.Scoped("AgentAct"))
            {
                AgentAct?.Invoke();
            }
        }
예제 #3
0
        /// <summary>
        /// Round-trips two raw byte payloads from a sender <see cref="RawBytesChannel"/> to a
        /// receiver with the same channel id, and checks that both arrive in send order.
        /// </summary>
        public void TestRawBytesSideChannel()
        {
            var str1 = "Test string";
            var str2 = "Test string, second";

            // Sender and receiver share one id; each dictionary is keyed by ChannelId.
            var channelId   = new Guid("9a5b8954-4f82-11ea-b238-784f4387d1f7");
            var strSender   = new RawBytesChannel(channelId);
            var strReceiver = new RawBytesChannel(channelId);
            var dictSender  = new Dictionary <Guid, SideChannel> {
                { strSender.ChannelId, strSender }
            };
            var dictReceiver = new Dictionary <Guid, SideChannel> {
                { strReceiver.ChannelId, strReceiver }
            };

            strSender.SendRawBytes(Encoding.ASCII.GetBytes(str1));
            strSender.SendRawBytes(Encoding.ASCII.GetBytes(str2));

            byte[] fakeData = SideChannelsManager.GetSideChannelMessage(dictSender);
            SideChannelsManager.ProcessSideChannelData(dictReceiver, fakeData);

            var messages = strReceiver.GetAndClearReceivedMessages();

            // NUnit's Assert.AreEqual is (expected, actual); keeping that order makes
            // failure messages report "Expected: 2 But was: N" the right way round.
            Assert.AreEqual(2, messages.Count);
            Assert.AreEqual(str1, Encoding.ASCII.GetString(messages[0]));
            Assert.AreEqual(str2, Encoding.ASCII.GetString(messages[1]));
        }
 /// <summary>
 /// Detaches the string log channel. It stops forwarding Unity log messages to the channel
 /// and, while the Academy reports it is initialized, unregisters the channel.
 /// </summary>
 public void OnDestroy()
 {
     // Stop routing Debug.Log output into the channel.
     Application.logMessageReceived -= stringChannel.SendDebugStatementToPython;

     // Only unregister while an Academy is initialized.
     if (!Academy.IsInitialized)
     {
         return;
     }

     SideChannelsManager.UnregisterSideChannel(stringChannel);
 }
예제 #5
0
 /// <summary>
 /// Fixture setup: reuses the registered <see cref="EnvironmentParametersChannel"/> if one
 /// exists, otherwise creates and registers a new one.
 /// </summary>
 public SamplerTests()
 {
     m_Channel = SideChannelsManager.GetSideChannel <EnvironmentParametersChannel>();
     if (m_Channel != null)
     {
         return;
     }

     // No channel registered yet — happens when this fixture runs on its own.
     m_Channel = new EnvironmentParametersChannel();
     SideChannelsManager.RegisterSideChannel(m_Channel);
 }
    /// <summary>
    /// Creates the string log side channel, forwards Unity log messages to it and registers
    /// it with the <see cref="SideChannelsManager"/>.
    /// </summary>
    public void Awake()
    {
        var channel = new StringLogSideChannel();
        stringChannel = channel;

        // Every Debug.Log message raised from now on is handed to the channel.
        Application.logMessageReceived += channel.SendDebugStatementToPython;

        // The channel has to be registered with the SideChannelsManager before it can be used.
        SideChannelsManager.RegisterSideChannel(channel);
    }
예제 #7
0
        /// <summary>
        /// Exercises <see cref="FloatPropertiesChannel"/>. It checks local get/set on the sender,
        /// propagation of values to the receiver, and that receiver callbacks fire only after a
        /// message is processed. It also checks the key sets on both the receiver and the sender.
        /// </summary>
        public void TestFloatPropertiesSideChannel()
        {
            var k1        = "gravity";
            var k2        = "length";
            int wasCalled = 0;

            var propA        = new FloatPropertiesChannel();
            var propB        = new FloatPropertiesChannel();
            var dictReceiver = new Dictionary <Guid, SideChannel> {
                { propA.ChannelId, propA }
            };
            var dictSender = new Dictionary <Guid, SideChannel> {
                { propB.ChannelId, propB }
            };

            propA.RegisterCallback(k1, f => { wasCalled++; });

            // An unset key yields the supplied default; Set is visible locally right away.
            var tmp = propB.GetWithDefault(k2, 3.0f);
            Assert.AreEqual(3.0f, tmp);
            propB.Set(k2, 1.0f);
            tmp = propB.GetWithDefault(k2, 3.0f);
            Assert.AreEqual(1.0f, tmp);

            byte[] fakeData = SideChannelsManager.GetSideChannelMessage(dictSender);
            SideChannelsManager.ProcessSideChannelData(dictReceiver, fakeData);

            tmp = propA.GetWithDefault(k2, 3.0f);
            Assert.AreEqual(1.0f, tmp);

            // propA's callback for k1 must fire only when the receiver processes the message,
            // not when the sender sets the value.
            Assert.AreEqual(0, wasCalled);
            propB.Set(k1, 1.0f);
            Assert.AreEqual(0, wasCalled);
            fakeData = SideChannelsManager.GetSideChannelMessage(dictSender);
            SideChannelsManager.ProcessSideChannelData(dictReceiver, fakeData);
            Assert.AreEqual(1, wasCalled);

            var keysA = propA.Keys();

            Assert.AreEqual(2, keysA.Count);
            Assert.IsTrue(keysA.Contains(k1));
            Assert.IsTrue(keysA.Contains(k2));

            // The sender must also know both keys: each was Set on propB above.
            var keysB = propB.Keys();

            Assert.AreEqual(2, keysB.Count);
            Assert.IsTrue(keysB.Contains(k1));
            Assert.IsTrue(keysB.Contains(k2));
        }
        /// <summary>
        /// Verifies that disposing the Academy and initializing it again yields different
        /// instances of the built-in side channels.
        /// </summary>
        public void TestAcademyDispose()
        {
            // Channels registered by the current Academy.
            var firstEnv    = SideChannelsManager.GetSideChannel <EnvironmentParametersChannel>();
            var firstEngine = SideChannelsManager.GetSideChannel <EngineConfigurationChannel>();
            var firstStats  = SideChannelsManager.GetSideChannel <StatsSideChannel>();

            Academy.Instance.Dispose();

            // Channels registered after re-initialization.
            Academy.Instance.LazyInitialize();
            var secondEnv    = SideChannelsManager.GetSideChannel <EnvironmentParametersChannel>();
            var secondEngine = SideChannelsManager.GetSideChannel <EngineConfigurationChannel>();
            var secondStats  = SideChannelsManager.GetSideChannel <StatsSideChannel>();

            Academy.Instance.Dispose();

            Assert.AreNotEqual(firstEnv, secondEnv);
            Assert.AreNotEqual(firstEngine, secondEngine);
            Assert.AreNotEqual(firstStats, secondStats);
        }
예제 #9
0
        /// <summary>
        /// Sends a Gaussian sampler definition for "parameter2" through the environment
        /// parameters channel. It then checks the first two values read back with the fixed seed.
        /// </summary>
        public void GaussianSamplerTest()
        {
            const string key = "parameter2";
            float gaussianMean   = 3.0f;
            float gaussianStdDev = 0.2f;

            using (var msg = new OutgoingMessage())
            {
                msg.WriteString(key);
                // 1 marks this message as a sampler definition.
                msg.WriteInt32(1);
                msg.WriteInt32(k_Seed);
                msg.WriteInt32((int)SamplerType.Gaussian);
                msg.WriteFloat32(gaussianMean);
                msg.WriteFloat32(gaussianStdDev);
                SideChannelsManager.ProcessSideChannelData(GetByteMessage(m_Channel, msg));
            }

            // Successive reads return successive samples, hence two different expected values.
            Assert.AreEqual(2.936162f, m_Channel.GetWithDefault(key, 1.0f), k_Epsilon);
            Assert.AreEqual(2.951348f, m_Channel.GetWithDefault(key, 1.0f), k_Epsilon);
        }
예제 #10
0
        /// <summary>
        /// Sends a uniform sampler definition over [1, 2] for "parameter1" through the
        /// environment parameters channel. It then checks the first two values read back with
        /// the fixed seed.
        /// </summary>
        public void UniformSamplerTest()
        {
            const string key = "parameter1";
            float minValue = 1.0f;
            float maxValue = 2.0f;

            using (var msg = new OutgoingMessage())
            {
                msg.WriteString(key);
                // 1 marks this message as a sampler definition.
                msg.WriteInt32(1);
                msg.WriteInt32(k_Seed);
                msg.WriteInt32((int)SamplerType.Uniform);
                msg.WriteFloat32(minValue);
                msg.WriteFloat32(maxValue);
                SideChannelsManager.ProcessSideChannelData(GetByteMessage(m_Channel, msg));
            }

            // Successive reads return successive samples, hence two different expected values.
            Assert.AreEqual(1.208888f, m_Channel.GetWithDefault(key, 1.0f), k_Epsilon);
            Assert.AreEqual(1.118017f, m_Channel.GetWithDefault(key, 1.0f), k_Epsilon);
        }
예제 #11
0
        /// <summary>
        /// Sends three integers from one <see cref="TestSideChannel"/> to another and checks
        /// that the receiver got exactly those values, in send order.
        /// </summary>
        public void TestIntegerSideChannel()
        {
            var intSender   = new TestSideChannel();
            var intReceiver = new TestSideChannel();
            var dictSender  = new Dictionary <Guid, SideChannel> {
                { intSender.ChannelId, intSender }
            };
            var dictReceiver = new Dictionary <Guid, SideChannel> {
                { intReceiver.ChannelId, intReceiver }
            };

            intSender.SendInt(4);
            intSender.SendInt(5);
            intSender.SendInt(6);

            byte[] fakeData = SideChannelsManager.GetSideChannelMessage(dictSender);
            SideChannelsManager.ProcessSideChannelData(dictReceiver, fakeData);

            // Compare the whole received sequence (expected first, per NUnit convention) so that
            // missing or extra messages fail the test, not only a wrong value in slots 0-2.
            CollectionAssert.AreEqual(new[] { 4, 5, 6 }, intReceiver.messagesReceived);
        }
        /// <summary>
        /// Checks the Academy singleton lifecycle:
        /// <list type="bullet">
        /// <item>first access initializes it;</item>
        /// <item>LazyInitialize and Dispose tolerate repeated calls;</item>
        /// <item>counters start at zero;</item>
        /// <item>the built-in side channels are registered;</item>
        /// <item>Dispose clears the initialized flag.</item>
        /// </list>
        /// </summary>
        public void TestAcademy()
        {
            Assert.IsFalse(Academy.IsInitialized);
            var academy = Academy.Instance;

            Assert.IsTrue(Academy.IsInitialized);

            // Repeated LazyInitialize calls must be harmless.
            academy.LazyInitialize();
            academy.LazyInitialize();

            Assert.AreEqual(0, academy.EpisodeCount);
            Assert.AreEqual(0, academy.StepCount);
            Assert.AreEqual(0, academy.TotalStepCount);
            Assert.IsNotNull(SideChannelsManager.GetSideChannel <EnvironmentParametersChannel>());
            Assert.IsNotNull(SideChannelsManager.GetSideChannel <EngineConfigurationChannel>());
            Assert.IsNotNull(SideChannelsManager.GetSideChannel <StatsSideChannel>());

            // Repeated Dispose calls must be harmless too.
            academy.Dispose();
            Assert.IsFalse(Academy.IsInitialized);
            academy.Dispose();
        }
예제 #13
0
        /// <summary>
        /// Sends a multi-range uniform sampler definition for "parameter3" through the
        /// environment parameters channel. It then checks the first two values read back with
        /// the fixed seed.
        /// </summary>
        public void MultiRangeUniformSamplerTest()
        {
            const string key = "parameter3";
            // Flat list of interval bounds, sent as-is via WriteFloatList.
            float[] intervals = { 1.2f, 2f, 3.2f, 4.1f };

            using (var msg = new OutgoingMessage())
            {
                msg.WriteString(key);
                // 1 marks this message as a sampler definition.
                msg.WriteInt32(1);
                msg.WriteInt32(k_Seed);
                msg.WriteInt32((int)SamplerType.MultiRangeUniform);
                msg.WriteFloatList(intervals);
                SideChannelsManager.ProcessSideChannelData(GetByteMessage(m_Channel, msg));
            }

            // Successive reads return successive samples, hence two different expected values.
            Assert.AreEqual(3.387999f, m_Channel.GetWithDefault(key, 1.0f), k_Epsilon);
            Assert.AreEqual(1.294413f, m_Channel.GetWithDefault(key, 1.0f), k_Epsilon);
        }
예제 #14
0
        /// <summary>
        /// Initializes the environment, configures it and initializes the Academy.
        /// </summary>
        /// <remarks>
        /// Setup steps:
        /// <list type="bullet">
        /// <item>Records version metadata on the timer stack and enables automatic stepping.</item>
        /// <item>Registers an <see cref="EngineConfigurationChannel"/>.</item>
        /// <item>Creates the environment parameters and the stats recorder.</item>
        /// <item>If a positive port is read from the launch arguments, tries to connect a
        /// communicator to the trainer.</item>
        /// </list>
        /// If the handshake throws, the failure (including the exception message) is logged and
        /// the Academy continues without a communicator.
        /// </remarks>
        void InitializeEnvironment()
        {
            TimerStack.Instance.AddMetadata("communication_protocol_version", k_ApiVersion);
            TimerStack.Instance.AddMetadata("com.unity.ml-agents_version", k_PackageVersion);

            EnableAutomaticStepping();

            SideChannelsManager.RegisterSideChannel(new EngineConfigurationChannel());
            m_EnvironmentParameters = new EnvironmentParameters();
            m_StatsRecorder         = new StatsRecorder();

            // Try to launch the communicator by using the arguments passed at launch
            var port = ReadPortFromArgs();

            if (port > 0)
            {
                Communicator = new RpcCommunicator(
                    new CommunicatorInitParameters
                    {
                        port = port
                    });
            }

            if (Communicator != null)
            {
                // We try to exchange the first message with Python. If this fails, it means
                // no Python process is ready to train the environment. In this case, the
                // environment must use inference.
                try
                {
                    var unityRlInitParameters = Communicator.Initialize(
                        new CommunicatorInitParameters
                        {
                            unityCommunicationVersion = k_ApiVersion,
                            unityPackageVersion       = k_PackageVersion,
                            name = "AcademySingleton",
                        });
                    UnityEngine.Random.InitState(unityRlInitParameters.seed);
                    // We might have inference-only Agents, so set the seed for them too.
                    m_InferenceSeed = unityRlInitParameters.seed;
                }
                catch (System.Exception ex)
                {
                    // Deliberate best-effort fallback: keep running without a trainer. The
                    // exception message is included so the cause of the failed handshake is
                    // not lost.
                    Debug.Log(
                        $"Couldn't connect to trainer on port {port} using API version {k_ApiVersion}. " +
                        "Will perform inference instead. " +
                        $"Reason: {ex.Message}"
                        );
                    Communicator = null;
                }

                if (Communicator != null)
                {
                    Communicator.QuitCommandReceived  += OnQuitCommandReceived;
                    Communicator.ResetCommandReceived += OnResetCommand;
                }
            }

            // If a communicator is enabled/provided, then we assume we are in
            // training mode. In the absence of a communicator, we assume we are
            // in inference mode.

            ResetActions();
        }
예제 #15
0
 /// <summary>
 /// Constructor. Creates an <see cref="EnvironmentParametersChannel"/>, stores it in
 /// <c>m_Channel</c> and registers it with the <see cref="SideChannelsManager"/>.
 /// </summary>
 internal EnvironmentParameters()
 {
     var channel = new EnvironmentParametersChannel();
     m_Channel = channel;
     SideChannelsManager.RegisterSideChannel(channel);
 }
예제 #16
0
 /// <summary>
 /// Constructor. Creates a <see cref="StatsSideChannel"/>, stores it in <c>m_Channel</c>
 /// and registers it with the <see cref="SideChannelsManager"/>.
 /// </summary>
 internal StatsRecorder()
 {
     var channel = new StatsSideChannel();
     m_Channel = channel;
     SideChannelsManager.RegisterSideChannel(channel);
 }
예제 #17
0
 /// <summary>
 /// Unregisters <c>m_Channel</c> from the <see cref="SideChannelsManager"/>.
 /// </summary>
 internal void Dispose() => SideChannelsManager.UnregisterSideChannel(m_Channel);