示例#1
0
        /// <summary>
        /// Overwrites the first entry of <paramref name="tensor"/>'s shape with
        /// <paramref name="batchSize"/>, replaces its data with a freshly allocated
        /// <c>float[batchSize, lastDim]</c> buffer (lastDim = last entry of the shape),
        /// and fills that buffer through <c>_randomNormal</c>.
        /// </summary>
        /// <remarks><paramref name="agentInfo"/> is not read by this method.</remarks>
        public void Generate(Tensor tensor, int batchSize, Dictionary <Agent, AgentInfo> agentInfo)
        {
            tensor.Shape[0] = batchSize;

            // NOTE(review): the buffer is always two-dimensional; only the last shape
            // entry determines its width. If Shape can have more than two entries,
            // Data's rank will not match Shape — confirm this is intended.
            var width = tensor.Shape[tensor.Shape.Length - 1];
            var buffer = new float[batchSize, width];
            tensor.Data = buffer;

            _randomNormal.FillTensor(tensor);
        }
示例#2
0
        public void RandomNormalTestDataNull()
        {
            // Arrange: a floating-point proxy whose Data is never assigned.
            var generator = new RandomNormal(1982);
            var proxy = new TensorProxy { ValueType = TensorProxy.TensorType.FloatingPoint };

            // Act + Assert: filling a proxy that has no Data must be rejected.
            Assert.Throws<ArgumentNullException>(() => generator.FillTensor(proxy));
        }
示例#3
0
        public void RandomNormalTestTensorInt()
        {
            // Arrange: a proxy declared as integer-typed (no Data assigned).
            var generator = new RandomNormal(1982);
            var proxy = new TensorProxy { ValueType = TensorProxy.TensorType.Integer };

            // Act + Assert: integer tensors are expected to be unsupported by FillTensor.
            Assert.Throws<NotImplementedException>(() => generator.FillTensor(proxy));
        }
        public void RandomNormalTestTensor()
        {
            RandomNormal rn = new RandomNormal(1982);
            Tensor       t  = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint,
                Data      = Array.CreateInstance(typeof(float), new long[3] {
                    3, 4, 2
                })
            };

            rn.FillTensor(t);

            // Expected values for seed 1982, listed in the order foreach enumerates a
            // multi-dimensional array (rightmost index varies fastest).
            float[] reference = new float[]
            {
                -0.2139822f,
                0.5051259f,
                -0.5640336f,
                -0.3357787f,
                -0.2055894f,
                -0.09432302f,
                -0.01419199f,
                0.53621f,
                -0.5507085f,
                -0.2651141f,
                0.09315512f,
                -0.04918706f,
                -0.179625f,
                0.2280539f,
                0.1883962f,
                0.4047216f,
                0.1704049f,
                0.5050544f,
                -0.3365685f,
                0.3542781f,
                0.5951571f,
                0.03460682f,
                -0.5537263f,
                -0.4378373f,
            };

            // Pin the element count first: without it, a short buffer would pass
            // silently and a long one would fail with IndexOutOfRangeException.
            Assert.AreEqual(reference.Length, t.Data.Length);

            int i = 0;

            foreach (float f in t.Data)
            {
                // NUnit signature is AreEqual(expected, actual, delta).
                Assert.AreEqual(reference[i], f, 0.0001);
                ++i;
            }
        }
示例#5
0
        public void RandomNormalTestTensor()
        {
            RandomNormal rn = new RandomNormal(1982);
            TensorProxy  t  = new TensorProxy
            {
                ValueType = TensorProxy.TensorType.FloatingPoint,
                Data      = new Tensor(1, 3, 4, 2)
            };

            // The Barracuda tensor owns resources, so release it even if an assertion fails.
            try
            {
                rn.FillTensor(t);

                // Expected values for seed 1982, indexed linearly via t.Data[i].
                float[] reference = new float[]
                {
                    -0.4315872f,
                    -1.11074f,
                    0.3414804f,
                    -1.130287f,
                    0.1413168f,
                    -0.5105762f,
                    -0.3027347f,
                    -0.2645015f,
                    1.225356f,
                    -0.02921959f,
                    0.3716498f,
                    -1.092338f,
                    0.9561074f,
                    -0.5018106f,
                    1.167787f,
                    -0.7763879f,
                    -0.07491868f,
                    0.5396146f,
                    -0.1377991f,
                    0.3331701f,
                    0.06144788f,
                    0.9520947f,
                    1.088157f,
                    -1.177194f,
                };

                // Pin the element count so a size mismatch fails clearly instead of
                // silently passing (too few) or throwing IndexOutOfRangeException (too many).
                Assert.AreEqual(reference.Length, t.Data.length);

                for (var i = 0; i < t.Data.length; i++)
                {
                    // NUnit signature is AreEqual(expected, actual, delta).
                    Assert.AreEqual(reference[i], t.Data[i], 0.0001);
                }
            }
            finally
            {
                t.Data.Dispose();
            }
        }
示例#6
0
 /// <summary>
 /// Resizes <paramref name="tensorProxy"/> for <paramref name="batchSize"/> through
 /// <c>TensorUtils.ResizeTensor</c> (using this instance's allocator), then fills it
 /// through <c>_randomNormal</c>.
 /// </summary>
 /// <remarks><paramref name="agentInfo"/> is not read by this method.</remarks>
 public void Generate(TensorProxy tensorProxy, int batchSize, Dictionary <Agent, AgentInfo> agentInfo)
 {
     var allocator = this._allocator;
     TensorUtils.ResizeTensor(tensorProxy, batchSize, allocator);

     this._randomNormal.FillTensor(tensorProxy);
 }