Exemplo n.º 1
0
        /// <summary>
        /// Flips a deterministic 6-D map tensor along axis 3 with <c>Flip</c> and compares the
        /// output against an expected array computed directly from the input values.
        /// </summary>
        public void ReferenceTest()
        {
            Shape shape  = new Shape(ShapeType.Map, 3, 4, 2, 7, 8, 5);
            int   length = shape.Length;
            int   axis   = 3;

            float[] xval            = (new float[length]).Select((_, idx) => (float)(((idx * 4969 % 17 + 3) * (idx * 6577 % 13 + 5) + idx) % 8)).ToArray();
            OverflowCheckedTensor x = new OverflowCheckedTensor(shape, xval);
            OverflowCheckedTensor y = new OverflowCheckedTensor(shape);

            Flip ope = new Flip(shape, axis: axis);

            ope.Execute(x, y);

            float[] y_actual = y.State;

            // Build the expected result by reversing xval along 'axis'.
            // Layout assumption (the same one ExecuteTest's index arithmetic uses): axis 0 is the
            // fastest-varying dimension, so an axis' stride is the product of the sizes below it.
            int axisLength = shape[axis];
            int stride     = 1;
            for (int i = 0; i < axis; i++)
            {
                stride *= shape[i];
            }

            float[] y_expect = new float[length];
            for (int idx = 0; idx < length; idx++)
            {
                int inner = idx % stride;                      // position within the lower axes
                int k     = idx / stride % axisLength;         // coordinate along the flipped axis
                int outer = idx / (stride * axisLength);       // position within the higher axes

                y_expect[idx] = xval[inner + (axisLength - 1 - k) * stride + outer * stride * axisLength];
            }

            AssertError.Tolerance(y_expect, y_actual, 1e-7f, 1e-5f, "not equal");
        }
Exemplo n.º 2
0
        /// <summary>
        /// For every axis of a 7-D map tensor, runs <c>Flip</c> and checks that:
        /// the input tensor is left unmodified, every 1-D line along the axis is reversed
        /// in the output, and the line traversal spans exactly <c>shape.Length</c> elements.
        /// </summary>
        public void ExecuteTest()
        {
            Shape shape = new Shape(ShapeType.Map, 17, 8, 2, 4, 1, 3, 67);

            for (int axis = 0; axis < shape.Ndim; axis++)
            {
                // Stride of 'axis' = product of the sizes of all lower axes (axis 0 varies fastest).
                int stride = 1, length = shape[axis];
                for (int i = 0; i < axis; i++)
                {
                    stride *= shape[i];
                }

                // Flat index of element 0 of the i-th 1-D line running along 'axis'.
                int LineOffset(int line) => line / stride * stride * length + line % stride;

                float[] x = (new float[shape.Length]).Select((_, idx) => (float)(((idx * 4969 % 17 + 3) * (idx * 6577 % 13 + 5) + idx) % 8)).ToArray();

                OverflowCheckedTensor inval  = new OverflowCheckedTensor(shape, x);
                OverflowCheckedTensor outval = new OverflowCheckedTensor(shape);

                Flip ope = new Flip(shape, axis);

                ope.Execute(inval, outval);

                // The operation must not write into its input.
                CollectionAssert.AreEqual(x, inval.State);

                float[] y     = outval.State;
                int     lines = shape.Length / length;

                for (int i = 0; i < lines; i++)
                {
                    int p = LineOffset(i);

                    for (int j = 0; j < length; j++)
                    {
                        // Exact comparison is intended: Flip only moves values, it never computes new ones.
                        if (y[p + (length - j - 1) * stride] != x[p + j * stride])
                        {
                            Assert.Fail($"axis:{axis} outval");
                        }
                    }
                }

                // The offset one past the last line must land exactly at the end of the tensor.
                Assert.AreEqual(shape.Length, LineOffset(lines));

                Console.WriteLine($"pass : axis{axis}");
            }
        }