public void FirstBranchMask()
        {
            // Three discrete branches of sizes 4, 5 and 6 -> 15 flat mask entries.
            // Branch 0 occupies flat indices 0-3.
            var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
            var masker    = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
            var mask      = masker.GetMask();

            // Before anything is written, GetMask returns null.
            Assert.IsNull(mask);

            masker.WriteMask(0, new[] { 1, 2, 3 });
            mask = masker.GetMask();

            // NUnit's AreEqual takes (expected, actual); check the length first so a
            // wrong-sized mask fails with a clear message instead of an index error.
            Assert.AreEqual(15, mask.Length);
            for (var i = 0; i < mask.Length; i++)
            {
                // Only actions 1-3 of branch 0 (flat indices 1-3) were masked.
                var expectMasked = i >= 1 && i <= 3;
                Assert.AreEqual(expectMasked, mask[i], $"mask[{i}]");
            }
        }
        public void NullMask()
        {
            // A masker built over no actuators and zero branches has nothing to mask,
            // so GetMask is expected to return null.
            var noActuators = new List <IActuator>();
            var emptyMasker = new ActuatorDiscreteActionMask(noActuators, 0, 0);

            Assert.IsNull(emptyMasker.GetMask());
        }
        public void ThrowsError()
        {
            // Branch sizes 4, 5 and 6 (15 actions across 3 branches).
            var branchSizes = new[] { 4, 5, 6 };
            var actuator    = new TestActuator(ActionSpec.MakeDiscrete(branchSizes), "actuator1");
            var masker      = new ActuatorDiscreteActionMask(new IActuator[] { actuator }, 15, 3);

            // Action 5 is out of range for branch 0 (size 4) and branch 1 (size 5)...
            Assert.Catch <UnityAgentsException>(() => masker.WriteMask(0, new[] { 5 }));
            Assert.Catch <UnityAgentsException>(() => masker.WriteMask(1, new[] { 5 }));
            // ...but in range for branch 2 (size 6).
            masker.WriteMask(2, new[] { 5 });
            // There is no branch 3.
            Assert.Catch <UnityAgentsException>(() => masker.WriteMask(3, new[] { 1 }));

            masker.GetMask();
            masker.ResetMask();

            // Masking every action of branch 0 must make GetMask throw.
            masker.WriteMask(0, new[] { 0, 1, 2, 3 });
            Assert.Catch <UnityAgentsException>(() => masker.GetMask());
        }
        // Named distinctly: a second parameterless ThrowsError() in the same class is a
        // duplicate member (CS0111). This variant exercises the SetActionEnabled API.
        public void ThrowsErrorWithSetActionEnabled()
        {
            // Branch sizes 4, 5 and 6 (15 actions across 3 branches).
            var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
            var masker    = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);

            // Action 5 is out of range for branch 0 (size 4) and branch 1 (size 5)...
            Assert.Catch <UnityAgentsException>(
                () => masker.SetActionEnabled(0, 5, false));
            Assert.Catch <UnityAgentsException>(
                () => masker.SetActionEnabled(1, 5, false));
            // ...but in range for branch 2 (size 6).
            masker.SetActionEnabled(2, 5, false);
            // There is no branch 3.
            Assert.Catch <UnityAgentsException>(
                () => masker.SetActionEnabled(3, 1, false));

            masker.GetMask();
            masker.ResetMask();

            // Disable every action of branch 0; GetMask must then throw.
            for (var action = 0; action < 4; action++)
            {
                masker.SetActionEnabled(0, action, false);
            }
            Assert.Catch <UnityAgentsException>(
                () => masker.GetMask());
        }
        public void CanOverwriteMask()
        {
            // Branch sizes 4, 5 and 6; branch 0 starts at flat index 0, so action 1 is mask[1].
            var actuator = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
            var masker   = new ActuatorDiscreteActionMask(new IActuator[] { actuator }, 15, 3);

            // Disabling the action marks it as masked.
            masker.SetActionEnabled(0, 1, false);
            var liveMask = masker.GetMask();
            Assert.IsTrue(liveMask[1]);

            // Re-enabling it clears the entry. liveMask is intentionally not re-fetched:
            // NOTE(review): this only passes if GetMask hands back the masker's live buffer
            // rather than a copy.
            masker.SetActionEnabled(0, 1, true);
            Assert.IsFalse(liveMask[1]);
        }
        public void MaskReset()
        {
            const int totalActions = 15;

            var actuator = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
            var masker   = new ActuatorDiscreteActionMask(new IActuator[] { actuator }, totalActions, 3);

            // Mask some actions, then reset; every entry must read as unmasked afterwards.
            masker.WriteMask(1, new[] { 1, 2, 3 });
            masker.ResetMask();
            var afterReset = masker.GetMask();

            for (var index = 0; index < totalActions; ++index)
            {
                Assert.IsFalse(afterReset[index]);
            }
        }
        public void SecondBranchMask()
        {
            // Branch sizes 4, 5, 6: branch 1 starts at flat index 4, so its actions 1-3
            // land on flat indices 5-7.
            var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
            // Explicit IActuator[] (matching the other tests) rather than an inferred,
            // covariant TestActuator[].
            var masker    = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);

            masker.WriteMask(1, new[] { 1, 2, 3 });
            var mask = masker.GetMask();

            // Check the full mask, not a sample of it, so stray writes into other branches are caught.
            Assert.AreEqual(15, mask.Length);
            for (var i = 0; i < mask.Length; i++)
            {
                var expectMasked = i >= 5 && i <= 7;
                Assert.AreEqual(expectMasked, mask[i], $"mask[{i}]");
            }
        }
        public void TestWriteDiscreteActionMask()
        {
            // Branch sizes 1, 2, 3 -> 6 flat entries; branch offsets are 0, 1 and 3.
            var ar   = new TestActionReceiver();
            var va   = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");
            var bdam = new ActuatorDiscreteActionMask(new IActuator[] { va }, 6, 3);

            // Branch 1 action 0 -> flat index 1; branch 2 actions 1, 2 -> flat indices 4, 5.
            var groundTruthMask = new[] { false, true, false, false, true, true };

            // The actuator forwards whatever branch/mask the receiver is configured with.
            ar.Branch = 1;
            ar.Mask   = new[] { 0 };
            va.WriteDiscreteActionMask(bdam);
            ar.Branch = 2;
            ar.Mask   = new[] { 1, 2 };
            va.WriteDiscreteActionMask(bdam);

            // NUnit compares arrays element-wise and, unlike IsTrue(SequenceEqual(...)),
            // reports the first differing index (and fails cleanly on a null mask).
            Assert.AreEqual(groundTruthMask, bdam.GetMask());
        }
        public void MultipleMaskEdit()
        {
            var actuator = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
            var masker   = new ActuatorDiscreteActionMask(new IActuator[] { actuator }, 15, 3);

            // Repeated writes to the same branch accumulate: the second write to branch 0
            // must not clear actions 0 and 1.
            masker.WriteMask(0, new[] { 0, 1 });
            masker.WriteMask(0, new[] { 3 });
            masker.WriteMask(2, new[] { 1 });
            var actual = masker.GetMask();

            // Expected layout: branch 0 actions 0, 1, 3 and branch 2 (offset 9) action 1.
            var expected = new bool[15];
            expected[0]  = true;
            expected[1]  = true;
            expected[3]  = true;
            expected[10] = true;

            for (var idx = 0; idx < expected.Length; idx++)
            {
                if (expected[idx])
                {
                    Assert.IsTrue(actual[idx]);
                }
                else
                {
                    Assert.IsFalse(actual[idx]);
                }
            }
        }