diff --git a/test/testsuite/enzyme/orthnull.jl b/test/testsuite/enzyme/orthnull.jl index 9fae30e5..280e546e 100644 --- a/test/testsuite/enzyme/orthnull.jl +++ b/test/testsuite/enzyme/orthnull.jl @@ -17,7 +17,7 @@ end """ test_enzyme_left_orth(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`) +Test the Enzyme forward- and reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`) algorithms, and their in-place variants. """ function test_enzyme_left_orth( @@ -44,6 +44,9 @@ function test_enzyme_left_orth( VC, ΔVC = ad_left_orth_setup(A) test_reverse(left_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC) test_reverse(call_and_zero!, RT, (left_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC) + A = instantiate_matrix(T, sz) + test_forward(left_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (left_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm) end end end @@ -52,7 +55,7 @@ end """ test_enzyme_right_orth(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`) +Test the Enzyme forward- and reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`) algorithms, and their in-place variants. """ function test_enzyme_right_orth( @@ -78,6 +81,9 @@ function test_enzyme_right_orth( CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) test_reverse(right_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ) test_reverse(call_and_zero!, RT, (right_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ) + A = instantiate_matrix(T, sz) + test_forward(right_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (right_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm) end end end diff --git a/test/testsuite/mooncake/orthnull.jl b/test/testsuite/mooncake/orthnull.jl index e8970086..afbac9d9 100644 --- a/test/testsuite/mooncake/orthnull.jl +++ b/test/testsuite/mooncake/orthnull.jl @@ -17,7 +17,7 @@ end """ test_mooncake_left_orth(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`) +Test the Mooncake forward- and reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`) algorithms, and their in-place variants. """ function test_mooncake_left_orth( @@ -51,11 +51,11 @@ function test_mooncake_left_orth( Mooncake.TestUtils.test_rule( rng, left_orth, A, alg; - mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol + output_tangent, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, left_orth!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol + output_tangent, is_primitive = false, atol, rtol ) end end @@ -65,7 +65,7 @@ end """ test_mooncake_right_orth(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`) +Test the Mooncake forward- and reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`) algorithms, and their in-place variants. """ function test_mooncake_right_orth( @@ -99,11 +99,11 @@ function test_mooncake_right_orth( Mooncake.TestUtils.test_rule( rng, right_orth, A, alg; - mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol + output_tangent, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, right_orth!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol + output_tangent, is_primitive = false, atol, rtol ) end end