From 0310296646a31954cce46a65917355600b105604 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 10 Jun 2026 14:04:20 +0200 Subject: [PATCH 1/3] A couple missing zero-derivs and an inplace rrule for svd_trunc --- ext/TensorKitMooncakeExt/factorizations.jl | 40 ++++++++++++++++++-- ext/TensorKitMooncakeExt/tensoroperations.jl | 17 +++++++++ ext/TensorKitMooncakeExt/utility.jl | 6 +++ 3 files changed, 60 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitMooncakeExt/factorizations.jl b/ext/TensorKitMooncakeExt/factorizations.jl index 3bb1b3ae3..a658156f4 100644 --- a/ext/TensorKitMooncakeExt/factorizations.jl +++ b/ext/TensorKitMooncakeExt/factorizations.jl @@ -58,6 +58,40 @@ function Mooncake.rrule!!( return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_pullback end -@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm} -Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) = - Mooncake.rrule!!(Mooncake.zero_fcodual(svd_trunc), A_dA, alg_dalg) +@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc!), AbstractTensorMap, Any, MatrixAlgebraKit.TruncatedAlgorithm} +function Mooncake.rrule!!( + ::CoDual{typeof(svd_trunc!)}, + A_dA::CoDual{<:AbstractTensorMap}, + USVᴴ_dUSVᴴ::CoDual, + alg_dalg::CoDual{<:MatrixAlgebraKit.TruncatedAlgorithm} + ) + A, dA = arrayify(A_dA) + Ac = copy(A) + alg = primal(alg_dalg) + + USVᴴ = primal(USVᴴ_dUSVᴴ) + dUSVᴴ = tangent(USVᴴ_dUSVᴴ) + U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) + S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) + Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + USVᴴc = copy.(USVᴴ) + + USVᴴ = svd_compact!(A, USVᴴ, alg.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind) + + USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ)) + dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(tangent(USVᴴtrunc_dUSVᴴtrunc)))) + 
function svd_trunc_pullback((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) + abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) || + @warn "Gradient for `svd_trunc` ignores non-zero tangents for truncation error" + MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) + # restore state + copy!(A, Ac) + copy!.(USVᴴ, USVᴴc) + return NoRData(), NoRData(), NoRData(), NoRData() + end + + return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_pullback +end diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 5f47a5260..755799370 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -250,3 +250,20 @@ function trace_permute_pullback_ΔA!( ) return NoRData() end + +@is_primitive( + DefaultCtx, + Tuple{ + typeof(TensorKit.scalar), + AbstractTensorMap, + } +) +function Mooncake.rrule!!(::CoDual{typeof(TensorKit.scalar)}, t_dt::CoDual{<:AbstractTensorMap}) + t, dt = arrayify(t_dt) + val = scalar(t) + function scalar_pullback(Δval) + first(blocks(dt))[2][1] += Δval + return NoRData(), NoRData() + end + return Mooncake.zero_fcodual(val), scalar_pullback +end diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index ceb32d867..b54fda7cf 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -65,6 +65,12 @@ Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent @zero_derivative DefaultCtx Tuple{typeof(TensorKit.sectorstructure), Any} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.degeneracystructure), Any} +@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), AbstractTensorMap} +@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), AbstractTensorMap, Int, Bool} + +@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract_structure), AbstractTensorMap, Index2Tuple, Bool, AbstractTensorMap, Index2Tuple, Bool, Index2Tuple} + 
+@zero_derivative DefaultCtx Tuple{typeof(TensorKit.has_shared_permute), AbstractTensorMap, Index2Tuple} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.permute), HomSpace, Index2Tuple} From bb9c0bb910d84e1a0ec248f9739b07b6a2500f08 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 10 Jun 2026 18:00:07 +0200 Subject: [PATCH 2/3] Fixes and add test --- ext/TensorKitMooncakeExt/factorizations.jl | 7 +++-- test/mooncake/factorizations.jl | 35 ++++++++++++++-------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/ext/TensorKitMooncakeExt/factorizations.jl b/ext/TensorKitMooncakeExt/factorizations.jl index a658156f4..d422e8f4a 100644 --- a/ext/TensorKitMooncakeExt/factorizations.jl +++ b/ext/TensorKitMooncakeExt/factorizations.jl @@ -66,7 +66,7 @@ function Mooncake.rrule!!( alg_dalg::CoDual{<:MatrixAlgebraKit.TruncatedAlgorithm} ) A, dA = arrayify(A_dA) - Ac = copy(A) + Ac = deepcopy(A) alg = primal(alg_dalg) USVᴴ = primal(USVᴴ_dUSVᴴ) @@ -86,10 +86,13 @@ function Mooncake.rrule!!( function svd_trunc_pullback((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) || @warn "Gradient for `svd_trunc` ignores non-zero tangents for truncation error" - MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) + MatrixAlgebraKit.svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind) # restore state copy!(A, Ac) copy!.(USVᴴ, USVᴴc) + MatrixAlgebraKit.zero!(dU) + MatrixAlgebraKit.zero!(dS) + MatrixAlgebraKit.zero!(dVᴴ) return NoRData(), NoRData(), NoRData(), NoRData() end diff --git a/test/mooncake/factorizations.jl b/test/mooncake/factorizations.jl index 8955c4ecf..c7607328a 100644 --- a/test/mooncake/factorizations.jl +++ b/test/mooncake/factorizations.jl @@ -8,6 +8,11 @@ using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence! 
using Mooncake using Random +function call_and_zero!(f!, A, alg) + F′ = f!(A, alg) + MatrixAlgebraKit.zero!(A) + return F′ +end mode = Mooncake.ReverseMode rng = Random.default_rng() @@ -18,7 +23,6 @@ eltypes = (Float64, ComplexF64) @timedtestset "Mooncake - Factorizations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes atol = default_tol(T) rtol = default_tol(T) - @timedtestset "QR" begin A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) @@ -29,8 +33,7 @@ eltypes = (Float64, ComplexF64) ΔQR = Mooncake.randn_tangent(rng, QR) remove_qr_gauge_dependence!(ΔQR..., A, QR...) Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false) - # TODO: - # Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false) + #Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false) A = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← (V[4] ⊗ V[5])') @@ -41,8 +44,7 @@ eltypes = (Float64, ComplexF64) ΔQR = Mooncake.randn_tangent(rng, QR) remove_qr_gauge_dependence!(ΔQR..., A, QR...) Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false) - # TODO: - # Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false) + #Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false) end @timedtestset "LQ" begin @@ -50,25 +52,23 @@ eltypes = (Float64, ComplexF64) Mooncake.TestUtils.test_rule(rng, lq_compact, A; atol, rtol, mode, is_primitive = false) - # qr_full/qr_null requires being careful with gauges + # lq_full/lq_null requires being careful with gauges LQ = lq_full(A) ΔLQ = Mooncake.randn_tangent(rng, LQ) remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) 
Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false) - # TODO: - # Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false) + #Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false) A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') Mooncake.TestUtils.test_rule(rng, lq_compact, A; atol, rtol, mode, is_primitive = false) - # qr_full/qr_null requires being careful with gauges + # lq_full/lq_null requires being careful with gauges LQ = lq_full(A) ΔLQ = Mooncake.randn_tangent(rng, LQ) remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false) - # TODO: - # Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false) + #Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false) end @timedtestset "Eigenvalue decomposition" begin @@ -88,7 +88,7 @@ eltypes = (Float64, ComplexF64) @timedtestset "Singular value decomposition" begin for t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')) - USVᴴ = svd_compact(t) + #=USVᴴ = svd_compact(t) ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ) remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) Mooncake.TestUtils.test_rule(rng, svd_compact, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false) @@ -97,7 +97,7 @@ eltypes = (Float64, ComplexF64) ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ) remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) Mooncake.TestUtils.test_rule(rng, svd_full, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false) - + =# V_trunc = spacetype(t)(c => min(size(b)...) 
÷ 2 for (c, b) in blocks(t)) trunc = truncspace(V_trunc) alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc) @@ -105,6 +105,15 @@ eltypes = (Float64, ComplexF64) ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc))) remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...) Mooncake.TestUtils.test_rule(rng, svd_trunc, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode) + + V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t)) + trunc = truncspace(V_trunc) + USVᴴ = svd_compact(t) + alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc) + USVᴴtrunc = svd_trunc(t, alg) + ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc))) + remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...) + Mooncake.TestUtils.test_rule(rng, call_and_zero!, svd_trunc!, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode, is_primitive = false) end end end From a9c6770b76cadd7d644df2b9e4de3b5f568880ff Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Jun 2026 11:00:27 +0200 Subject: [PATCH 3/3] Mark init output as zero derivative --- ext/TensorKitMooncakeExt/factorizations.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ext/TensorKitMooncakeExt/factorizations.jl b/ext/TensorKitMooncakeExt/factorizations.jl index d422e8f4a..b18a54f52 100644 --- a/ext/TensorKitMooncakeExt/factorizations.jl +++ b/ext/TensorKitMooncakeExt/factorizations.jl @@ -1,3 +1,6 @@ +# needed for the ising bimodule case +@zero_derivative DefaultCtx Tuple{typeof(MatrixAlgebraKit.initialize_output), Any, AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm} + for f in (:svd_compact, :svd_full) f_pullback = Symbol(f, :_pullback) @eval begin