Ejemplo n.º 1
0
        // TODO: this test doesn't work due to bugs in Elementwise
        // Builds loss = Sum(Apply(x, v => (v > 0) * b + v + b)) from a shared vector "x"
        // (initialised with NN.Range<float>(4)) and a scalar "b". It then checks that
        // compiling a function for d(loss)/d(b) succeeds and that the gradient w.r.t. b
        // passes the library's gradient check.
        public void PushCoherentGradientOnComplexAbstraction()
        {
            var values = T.Shared(NN.Range<float>(4), "x");
            var bias   = T.Scalar<float>("b");

            // Element-wise expression: b is used both under a comparison mask and additively,
            // so the gradient w.r.t. b must combine both contributions.
            var loss = T.Sum(T.Apply(values, v => (v > 0f) * bias + v + bias));

            // The gradient graph is built inside the lambda so that a failure while building
            // or compiling it is reported with the message below.
            AssertArray.WithMessage("Can't compile the gradient.",
                () => T.Function(input: bias, output: T.Grad(loss, bias)));

            AssertTensor.PassesGradientCheck(loss, bias);
        }
Ejemplo n.º 2
0
        // Gradient check of a recurrent T.Scan accumulation w.r.t. the shared bias "b".
        //
        // The scan computes acc_t = acc_{t-1} + x_t + b, starting from the shared vector "zero".
        // The last step therefore equals zero + sum_t(x_t) + T * b, where T = xs.Shape[0].
        // The test runs two checks, each over many random input sequences:
        //   1. "checkManually" compares finite differences against a hand-written gradient
        //      2 * T * sum. That formula assumes T.Norm2 is the squared L2 norm (TODO confirm).
        //   2. "checkGrad" compares finite differences against the library's own backward pass.
        public void ScanPassesGradientCheckOnRec()
        {
            // Shared test parameters; both checks use the same values so they stay comparable.
            const int   iterations  = 50;     // random trials per check
            const int   seqLength   = 10;     // rows of each random input sequence
            const float epsilon     = 0.001f; // finite-difference step
            const float relativeErr = 1e-3f;
            const float absErr      = 1e-4f;

            var n    = 5;
            var zero = T.Shared(0.2f * NN.Random.Uniform(-1.0f, 1.0f, n).As<float>(), "zero");
            var b    = T.Shared(0.2f * NN.Random.Uniform(-1.0f, 1.0f, n).As<float>(), "b");
            // -1 leaves the sequence length unspecified (variable number of rows, n columns).
            var xs   = T.Matrix<float>(-1, n, "xs");

            // [-1] selects the final accumulator of the scan.
            var sum   = T.Scan((x, acc) => acc + x + b, sequence: xs, outputsInfo: zero)[-1];
            var norm2 = T.Norm2(sum);

            // Check 1: finite differences vs. the analytic gradient 2 * T * sum.
            var checkManually = T.RandomGradientCheck(xs, norm2, b, computed: 2 * xs.Shape[0].As<float>() * sum);

            for (int iteration = 0; iteration < iterations; ++iteration)
            {
                var xs_            = NN.Random.Uniform(-1, 1, seqLength, n).As<float>();
                var checkRes       = checkManually(xs_, epsilon);
                var finite         = checkRes.Item1;
                var backpropagated = checkRes.Item2;
                AssertArray.WithMessage("GradientCheck isn't precise enough", () =>
                                        AssertArray.AreAlmostEqual(finite, backpropagated, relativeErr: relativeErr, absErr: absErr)
                                        );
            }

            // Check 2: finite differences vs. the library's backpropagated gradient.
            // Created only after check 1 finishes, preserving the original order of
            // construction (and of any random draws it may make).
            var checkGrad = T.RandomGradientCheck(xs, norm2, b);

            for (int iteration = 0; iteration < iterations; ++iteration)
            {
                var xs_            = NN.Random.Uniform(-1, 1, seqLength, n).As<float>();
                var checkRes       = checkGrad(xs_, epsilon);
                var finite         = checkRes.Item1;
                var backpropagated = checkRes.Item2;
                AssertArray.WithMessage("Backward isn't precise enough", () =>
                                        AssertArray.AreAlmostEqual(finite, backpropagated, relativeErr: relativeErr, absErr: absErr)
                                        );
            }
        }