Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions test/testsuite/enzyme/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end
"""
test_enzyme_left_orth(T, sz; rng, atol, rtol)

Test the Enzyme reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`)
Test the Enzyme forward- and reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`)
algorithms, and their in-place variants.
"""
function test_enzyme_left_orth(
Expand All @@ -44,6 +44,9 @@ function test_enzyme_left_orth(
VC, ΔVC = ad_left_orth_setup(A)
test_reverse(left_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC)
test_reverse(call_and_zero!, RT, (left_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔVC)
A = instantiate_matrix(T, sz)
test_forward(left_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(call_and_zero!, RT, (left_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm)
end
end
end
Expand All @@ -52,7 +55,7 @@ end
"""
test_enzyme_right_orth(T, sz; rng, atol, rtol)

Test the Enzyme reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`)
Test the Enzyme forward- and reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`)
algorithms, and their in-place variants.
"""
function test_enzyme_right_orth(
Expand All @@ -78,6 +81,9 @@ function test_enzyme_right_orth(
CVᴴ, ΔCVᴴ = ad_right_orth_setup(A)
test_reverse(right_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ)
test_reverse(call_and_zero!, RT, (right_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm, output_tangent = ΔCVᴴ)
A = instantiate_matrix(T, sz)
test_forward(right_orth, RT, (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(call_and_zero!, RT, (right_orth!, Const), (A, TA), (alg, Const); atol, rtol, fdm)
end
end
end
Expand Down
12 changes: 6 additions & 6 deletions test/testsuite/mooncake/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end
"""
test_mooncake_left_orth(T, sz; rng, atol, rtol)

Test the Mooncake reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`)
Test the Mooncake forward- and reverse-mode AD rules for `left_orth` with QR and polar (when `m >= n`)
algorithms, and their in-place variants.
"""
function test_mooncake_left_orth(
Expand Down Expand Up @@ -51,11 +51,11 @@ function test_mooncake_left_orth(

Mooncake.TestUtils.test_rule(
rng, left_orth, A, alg;
mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol
output_tangent, is_primitive = false, atol, rtol
)
Mooncake.TestUtils.test_rule(
rng, call_and_zero!, left_orth!, A, alg;
mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol
output_tangent, is_primitive = false, atol, rtol
)
end
end
Expand All @@ -65,7 +65,7 @@ end
"""
test_mooncake_right_orth(T, sz; rng, atol, rtol)

Test the Mooncake reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`)
Test the Mooncake forward- and reverse-mode AD rules for `right_orth` with LQ and polar (when `m <= n`)
algorithms, and their in-place variants.
"""
function test_mooncake_right_orth(
Expand Down Expand Up @@ -99,11 +99,11 @@ function test_mooncake_right_orth(

Mooncake.TestUtils.test_rule(
rng, right_orth, A, alg;
mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol
output_tangent, is_primitive = false, atol, rtol
)
Mooncake.TestUtils.test_rule(
rng, call_and_zero!, right_orth!, A, alg;
mode = Mooncake.ReverseMode, output_tangent, is_primitive = false, atol, rtol
output_tangent, is_primitive = false, atol, rtol
)
end
end
Expand Down
Loading