/// <summary>
/// Runs <c>ActionMaskInputGenerator.Generate</c> over the first two fake agents and
/// checks selected entries of the resulting 2x5 floating-point tensor:
/// [0,0] == 1, [0,4] == 1, [1,0] == 0, [1,4] == 1.
/// </summary>
public void GenerateActionMaskInput()
{
    var inputTensor = new TensorProxy
    {
        shape = new long[] { 2, 5 },
        valueType = TensorProxy.TensorType.FloatingPoint
    };
    // NOTE(review): batchSize (4) is larger than both the number of agents passed in (2)
    // and the tensor's first dimension (2); presumably the generator only writes rows for
    // the supplied agents — confirm against the generator implementation.
    const int batchSize = 4;
    var agentInfos = GetFakeAgents();
    var alloc = new TensorCachingAllocator();
    try
    {
        var generator = new ActionMaskInputGenerator(alloc);
        var agent0 = agentInfos[0];
        var agent1 = agentInfos[1];
        var inputs = new List<AgentInfoSensorsPair>
        {
            new AgentInfoSensorsPair { agentInfo = agent0._Info, sensors = agent0.sensors },
            new AgentInfoSensorsPair { agentInfo = agent1._Info, sensors = agent1.sensors },
        };

        generator.Generate(inputTensor, batchSize, inputs);

        Assert.IsNotNull(inputTensor.data);
        // NUnit's Assert.AreEqual takes (expected, actual); keeping that order makes
        // failure messages report expected and actual values the right way round.
        Assert.AreEqual(1, inputTensor.data[0, 0]);
        Assert.AreEqual(1, inputTensor.data[0, 4]);
        Assert.AreEqual(0, inputTensor.data[1, 0]);
        Assert.AreEqual(1, inputTensor.data[1, 4]);
    }
    finally
    {
        // Release the allocator even when generation or an assertion throws.
        alloc.Dispose();
    }
}
/// <summary>
/// Runs <c>ActionMaskInputGenerator.Generate</c> over the fake agent infos and checks
/// selected entries of the resulting 2x5 floating-point tensor:
/// [0,0] == 1, [0,4] == 1, [1,0] == 0, [1,4] == 1.
/// </summary>
public void GenerateActionMaskInput()
{
    var inputTensor = new Tensor()
    {
        Shape = new long[] { 2, 5 },
        ValueType = Tensor.TensorType.FloatingPoint
    };
    // NOTE(review): batchSize (4) is larger than the tensor's first dimension (2);
    // presumably the generator sizes its output from agentInfos — confirm.
    var batchSize = 4;
    var agentInfos = GetFakeAgentInfos();
    var generator = new ActionMaskInputGenerator();

    generator.Generate(inputTensor, batchSize, agentInfos);

    // Cast Data to float[,] once; a null result (missing data or a different element
    // type) fails the IsNotNull check before any indexing happens.
    var data = inputTensor.Data as float[,];
    Assert.IsNotNull(data);
    // NUnit's Assert.AreEqual takes (expected, actual).
    Assert.AreEqual(1, data[0, 0]);
    Assert.AreEqual(1, data[0, 4]);
    Assert.AreEqual(0, data[1, 0]);
    Assert.AreEqual(1, data[1, 4]);
}
/// <summary>
/// Runs <c>ActionMaskInputGenerator.Generate</c> over the fake agent infos and checks
/// selected entries of the resulting 2x5 floating-point tensor:
/// [0,0] == 1, [0,4] == 1, [1,0] == 0, [1,4] == 1.
/// </summary>
public void GenerateActionMaskInput()
{
    var inputTensor = new TensorProxy()
    {
        Shape = new long[] { 2, 5 },
        ValueType = TensorProxy.TensorType.FloatingPoint
    };
    // NOTE(review): batchSize (4) is larger than the tensor's first dimension (2);
    // presumably the generator sizes its output from agentInfos — confirm.
    var batchSize = 4;
    var agentInfos = GetFakeAgentInfos();
    var alloc = new TensorCachingAllocator();
    try
    {
        var generator = new ActionMaskInputGenerator(alloc);

        generator.Generate(inputTensor, batchSize, agentInfos);

        Assert.IsNotNull(inputTensor.Data);
        // NUnit's Assert.AreEqual takes (expected, actual); keeping that order makes
        // failure messages report expected and actual values the right way round.
        Assert.AreEqual(1, inputTensor.Data[0, 0]);
        Assert.AreEqual(1, inputTensor.Data[0, 4]);
        Assert.AreEqual(0, inputTensor.Data[1, 0]);
        Assert.AreEqual(1, inputTensor.Data[1, 4]);
    }
    finally
    {
        // Release the allocator even when generation or an assertion throws.
        alloc.Dispose();
    }
}