/// <summary>
/// Runs a randomized gradient check on an expression that depends on a single input variable.
/// Every trial draws a fresh value for <paramref name="input"/> and compares the finite-difference
/// estimate with the backpropagated gradient, both obtained from <c>T.RandomGradientCheck</c>.
/// Any dimension of the input shape that is not a <c>Scalar&lt;int&gt;.Const</c> is given size 10.
/// </summary>
/// <typeparam name="X">Element type of <paramref name="input"/>.</typeparam>
/// <param name="input">Input variable that receives a new value on every trial.</param>
/// <param name="expr">Scalar expression whose gradient is checked.</param>
/// <param name="W">Tensor handed to the gradient checker as the variable to differentiate against (presumably the parameter of interest — confirm).</param>
/// <param name="epsilon">Step size forwarded unchanged to the gradient checker on every trial.</param>
/// <param name="relativeErr">Relative tolerance forwarded to <c>AssertArray.CheckAreAlmostEqual</c>.</param>
/// <param name="absErr">Absolute tolerance forwarded to <c>AssertArray.CheckAreAlmostEqual</c>.</param>
/// <param name="repeat">Number of random trials to run.</param>
/// <param name="init">
/// Optional generator of input values. When null, values come from
/// <c>NN.Random.Uniform(-1f, 1f, shape)</c> using the shape described above.
/// </param>
/// <exception cref="Exception">
/// Thrown after all trials when at least one trial failed; the message lists every mismatching trial.
/// </exception>
public static void PassesGradientCheck<X>(Tensor<X>.Var input, Scalar<float> expr, Tensor<float> W,
    float epsilon = 0.001f, float relativeErr = 1e-3f, float absErr = 1e-4f, int repeat = 50,
    Func<Array<X>> init = null)
{
    // Dimensions that are not known constants fall back to a size of 10.
    var sampleShape = input.Shape
        .Select(dim => (dim as Scalar<int>.Const)?.Value ?? 10)
        .ToArray();
    var gradientCheck = T.RandomGradientCheck(new[] { input }, expr, W);
    var sample = init ?? (() => NN.Random.Uniform(-1f, 1f, sampleShape).As<X>());

    var failures = 0;
    var report = "";
    for (var trial = 0; trial < repeat; trial++)
    {
        var outcome = gradientCheck(sample(), epsilon);
        var finite = outcome.Item1;
        var backpropagated = outcome.Item2;

        if (AssertArray.CheckAreAlmostEqual(finite, backpropagated, relativeErr, absErr))
        {
            continue;
        }

        // Only reached when the values differ, so the denominator below is non-zero for finite inputs.
        var abs = Math.Abs(finite - backpropagated);
        var relative = 2 * abs / (Math.Abs(finite) + Math.Abs(backpropagated));
        report += $"Expected: {finite}, actual {backpropagated}, diff {abs}, relative {relative}.\n";
        failures++;
    }

    if (failures == 0)
    {
        return;
    }

    throw new Exception($"The computed gradient of {W.Name} doesn't match finite difference (failed {failures} times over {repeat}).\n{report}");
}
/// <summary>
/// Runs a gradient check on an expression that has no input variables.
/// Trials alternate between a positive and a negative step. The step magnitude is multiplied by 10
/// after every second trial, so with the defaults the steps are ±0.001, ±0.01 and ±0.1.
/// </summary>
/// <param name="expr">Scalar expression whose gradient is checked.</param>
/// <param name="W">Scalar handed to the gradient checker as the variable to differentiate against.</param>
/// <param name="epsilon">Initial step magnitude; it grows tenfold every two trials.</param>
/// <param name="relativeErr">Relative tolerance forwarded to <c>AssertArray.CheckAreAlmostEqual</c>.</param>
/// <param name="absErr">Absolute tolerance forwarded to <c>AssertArray.CheckAreAlmostEqual</c>.</param>
/// <param name="repeat">
/// Number of trials. Because the step keeps growing, large values produce very coarse steps.
/// </param>
/// <exception cref="Exception">
/// Thrown after all trials when at least one trial failed; the message lists every mismatching step.
/// </exception>
public static void PassesGradientCheck(Scalar<float> expr, Scalar<float> W,
    float epsilon = 0.001f, float relativeErr = 1e-3f, float absErr = 1e-4f, int repeat = 6)
{
    var gradientCheck = T.RandomGradientCheck(EmptyArray<IVar>.Value, expr, W);
    var magnitude = epsilon;
    var failures = 0;
    var report = "";

    for (var trial = 0; trial < repeat; trial++)
    {
        var isEvenTrial = trial % 2 == 0;
        float step;
        if (isEvenTrial)
        {
            step = magnitude;
        }
        else
        {
            step = -magnitude;
            // After each +/- pair, move on to a step ten times larger.
            magnitude *= 10;
        }

        var outcome = gradientCheck(step);
        var finite = outcome.Item1;
        var backpropagated = outcome.Item2;

        if (AssertArray.CheckAreAlmostEqual(finite, backpropagated, relativeErr, absErr))
        {
            continue;
        }

        var abs = Math.Abs(finite - backpropagated);
        var relative = 2 * abs / (Math.Abs(finite) + Math.Abs(backpropagated));
        report += $"For epsilon {step} expected: {finite}, actual {backpropagated}, diff {abs}, relative {relative}.\n";
        failures++;
    }

    if (failures == 0)
    {
        return;
    }

    throw new Exception($"The computed gradient of {W.ToString()} doesn't match finite difference (failed {failures} times over {repeat}).\n{report}");
}