From 718af7060a5189c1e84085af5cd7798a7e1bbae3 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 5 Jun 2026 16:45:44 +0200 Subject: [PATCH] Test orthnull with polar in fwd mode --- test/testsuite/enzyme/orthnull.jl | 10 ++++++++-- test/testsuite/mooncake/orthnull.jl | 12 ++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/test/testsuite/enzyme/orthnull.jl b/test/testsuite/enzyme/orthnull.jl index 9fae30e5e..280e546eb 100644 --- a/test/testsuite/enzyme/orthnull.jl +++ b/test/testsuite/enzyme/orthnull.jl @@ -17,7 +17,7 @@ end """ test_enzyme_left_orth(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`) +Test the Enzyme forward- and reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`) algorithms, and their in-place variants. """ function test_enzyme_left_orth( @@ -44,6 +44,9 @@ function test_enzyme_left_orth( VC, ΔVC = ad_left_orth_setup(A) test_reverse(left_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC) test_reverse(call_and_zero!, RT, (left_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC) + A = instantiate_matrix(T, sz) + test_forward(left_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (left_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm) end end end @@ -52,7 +55,7 @@ end """ test_enzyme_right_orth(T, sz; rng, atol, rtol) -Test the Enzyme reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`) +Test the Enzyme forward- and reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`) algorithms, and their in-place variants. """ function test_enzyme_right_orth( @@ -78,6 +81,9 @@ function test_enzyme_right_orth( CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) test_reverse(right_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ) test_reverse(call_and_zero!, RT, (right_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ) + A = instantiate_matrix(T, sz) + test_forward(right_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (right_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm) end end end diff --git a/test/testsuite/mooncake/orthnull.jl b/test/testsuite/mooncake/orthnull.jl index e89700868..afbac9d99 100644 --- a/test/testsuite/mooncake/orthnull.jl +++ b/test/testsuite/mooncake/orthnull.jl @@ -17,7 +17,7 @@ end """ test_mooncake_left_orth(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`) +Test the Mooncake forward- and reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`) algorithms, and their in-place variants. """ function test_mooncake_left_orth( @@ -51,11 +51,11 @@ function test_mooncake_left_orth( Mooncake.TestUtils.test_rule( rng, left_orth, A, alg; - mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol + output_tangent, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, left_orth!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol + output_tangent, is_primitive = false, atol, rtol ) end end @@ -65,7 +65,7 @@ end """ test_mooncake_right_orth(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`) +Test the Mooncake forward- and reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`) algorithms, and their in-place variants. """ function test_mooncake_right_orth( @@ -99,11 +99,11 @@ function test_mooncake_right_orth( Mooncake.TestUtils.test_rule( rng, right_orth, A, alg; - mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol + output_tangent, is_primitive = false, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, right_orth!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol + output_tangent, is_primitive = false, atol, rtol ) end end