From 7b81ec4133b70de897164b81a161c40f9c0e4797 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 19 May 2026 16:53:01 +0200 Subject: [PATCH 1/5] Forward rules for TensorOperations calls --- ext/TensorKitMooncakeExt/tensoroperations.jl | 61 +++++++++++++++++++- test/mooncake/tensoroperations.jl | 14 ++--- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 0627a7b2c..9d57bd04b 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -8,7 +8,6 @@ # this permutation is done multiple times. @is_primitive( DefaultCtx, - ReverseMode, Tuple{ typeof(TensorKit.blas_contract!), AbstractTensorMap, @@ -70,6 +69,35 @@ function Mooncake.rrule!!( return C_ΔC, blas_contract_pullback end +function Mooncake.frule!!( + ::Dual{typeof(TensorKit.blas_contract!)}, + C_ΔC::Dual{<:AbstractTensorMap}, + A_ΔA::Dual{<:AbstractTensorMap}, pA_ΔpA::Dual{<:Index2Tuple}, + B_ΔB::Dual{<:AbstractTensorMap}, pB_ΔpB::Dual{<:Index2Tuple}, + pAB_ΔpAB::Dual{<:Index2Tuple}, + α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}, + backend_Δbackend::Dual, allocator_Δallocator::Dual + ) + # prepare arguments + (C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB)) + pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB)) + α, Δα = Mooncake.extract(α_Δα) + β, Δβ = Mooncake.extract(β_Δβ) + backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) + # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α + scale!(ΔC, β) + if !isa(Δβ, Mooncake.NoTangent) + add!(ΔC, C, Δβ) + end + if !isa(Δα, Mooncake.NoTangent) + TensorKit.blas_contract!(ΔC, A, pA, B, pB, pAB, Δα, One(), backend, allocator) + end + TensorKit.blas_contract!(ΔC, ΔA, pA, B, pB, pAB, α, One(), backend, allocator) + TensorKit.blas_contract!(ΔC, A, pA, ΔB, pB, pAB, α, One(), backend, allocator) + TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator) + return C_ΔC +end + function blas_contract_pullback_ΔA!( ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) @@ -124,7 +152,6 @@ end # ------------ @is_primitive( DefaultCtx, - ReverseMode, Tuple{ typeof(TensorKit.trace_permute!), AbstractTensorMap, @@ -177,6 +204,36 @@ function Mooncake.rrule!!( return C_ΔC, trace_permute_pullback end +function Mooncake.frule!!( + ::Dual{typeof(TensorKit.trace_permute!)}, + C_ΔC::Dual{<:AbstractTensorMap}, + A_ΔA::Dual{<:AbstractTensorMap}, p_Δp::Dual{<:Index2Tuple}, q_Δq::Dual{<:Index2Tuple}, + α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}, + backend_Δbackend::Dual + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + q = primal(q_Δq) + α, Δα = Mooncake.extract(α_Δα) + β, Δβ = Mooncake.extract(β_Δβ) + backend = primal(backend_Δbackend) + + # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC + # dC1 = dβ * C + β * dC + scale!(ΔC, β) + if !isa(Δβ, Mooncake.NoTangent) + add!(ΔC, C, Δβ) + end + if !isa(Δα, Mooncake.NoTangent) + TensorKit.trace_permute!(ΔC, A, p, q, Δα, One(), backend) + end + TensorKit.trace_permute!(ΔC, ΔA, p, q, α, One(), backend) + TensorKit.trace_permute!(C, A, p, q, α, β, backend) + return C_ΔC +end + function trace_permute_pullback_ΔA!( ΔA, ΔC, A, p, q, α, backend ) diff --git a/test/mooncake/tensoroperations.jl b/test/mooncake/tensoroperations.jl index b97b90c2f..cc99c00f1 100644 --- a/test/mooncake/tensoroperations.jl +++ b/test/mooncake/tensoroperations.jl @@ -5,8 +5,6 @@ using VectorInterface: One, Zero using Mooncake using Random - -mode = Mooncake.ReverseMode rng = Random.default_rng() spacelist = ad_spacelist(fast_tests) @@ -53,32 +51,32 @@ eltypes = (Float64, ComplexF64) rng, TensorKit.blas_contract!, C, A, pA, B, pB, pAB, One(), Zero(), TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode + atol, rtol ) Mooncake.TestUtils.test_rule( rng, TensorKit.blas_contract!, C, A, pA, B, pB, pAB, α, β, TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode + atol, rtol ) if !(T <: Real) Mooncake.TestUtils.test_rule( rng, TensorKit.blas_contract!, C, A, pA, B, pB, pAB, real(α), real(β), TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode + atol, rtol ) Mooncake.TestUtils.test_rule( rng, TensorKit.blas_contract!, C, real(A), pA, B, pB, pAB, real(α), real(β), TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode + atol, rtol ) Mooncake.TestUtils.test_rule( rng, TensorKit.blas_contract!, C, A, pA, real(B), pB, pAB, real(α), real(β), TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode + atol, rtol ) end end @@ -102,7 +100,7 @@ eltypes = (Float64, ComplexF64) C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) Mooncake.TestUtils.test_rule( rng, TensorKit.trace_permute!, C, A, p, q, α, β, TensorOperations.DefaultBackend(); - atol, rtol, mode + atol, rtol ) end end From 745484d93d006f94da243b5e9dccb2d42544f865 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 19 May 2026 17:38:14 +0200 Subject: [PATCH 2/5] Format --- ext/TensorKitMooncakeExt/tensoroperations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 9d57bd04b..8cc9f06ec 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -219,7 +219,7 @@ function Mooncake.frule!!( α, Δα = Mooncake.extract(α_Δα) β, Δβ = Mooncake.extract(β_Δβ) backend = primal(backend_Δbackend) - + # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC # dC1 = dβ * C + β * dC scale!(ΔC, β) From 6ffd7b980dd8e183bbaac58b84ae28a540d7973e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 29 May 2026 13:53:02 +0200 Subject: [PATCH 3/5] Update ext/TensorKitMooncakeExt/tensoroperations.jl Co-authored-by: Lukas Devos --- ext/TensorKitMooncakeExt/tensoroperations.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 8cc9f06ec..c580056df 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -85,9 +85,10 @@ function Mooncake.frule!!( β, Δβ = Mooncake.extract(β_Δβ) backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α - scale!(ΔC, β) - if !isa(Δβ, Mooncake.NoTangent) - add!(ΔC, C, Δβ) + if isa(Δβ, Mooncake.NoTangent) + scale!(ΔC, β) + else + add!(ΔC, C, Δβ, β) end if !isa(Δα, Mooncake.NoTangent) TensorKit.blas_contract!(ΔC, A, pA, B, pB, pAB, Δα, One(), backend, allocator) From 1b2dd7f08ebe5da9ad67851c6d207c445183ffa8 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 29 May 2026 13:57:08 +0200 Subject: [PATCH 4/5] Apply suggestion to trace_permute also --- ext/TensorKitMooncakeExt/tensoroperations.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index c580056df..0bfe16643 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -223,9 +223,10 @@ function Mooncake.frule!!( # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC # dC1 = dβ * C + β * dC - scale!(ΔC, β) - if !isa(Δβ, Mooncake.NoTangent) - add!(ΔC, C, Δβ) + if isa(Δβ, Mooncake.NoTangent) + scale!(ΔC, β) + else + add!(ΔC, C, Δβ, β) end if !isa(Δα, Mooncake.NoTangent) TensorKit.trace_permute!(ΔC, A, p, q, Δα, One(), backend) From 69e3e9896417c9e1b9acfc385750b9a122db431e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 9 Jun 2026 13:31:37 +0200 Subject: [PATCH 5/5] Address style comments --- ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl | 2 +- ext/TensorKitMooncakeExt/tensoroperations.jl | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index d436173e2..7c0492239 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -3,7 +3,7 @@ module TensorKitMooncakeExt using Mooncake using Mooncake: @zero_derivative, @is_primitive, DefaultCtx, MinimalCtx, ReverseMode, NoFData, NoRData, NoTangent, - CoDual, Dual, arrayify, primal, tangent, zero_fcodual + CoDual, Dual, arrayify, primal, tangent, zero_fcodual, extract using TensorKit import TensorKit as TK using VectorInterface diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 0bfe16643..5f47a5260 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -81,16 +81,16 @@ function Mooncake.frule!!( # prepare arguments (C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB)) pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB)) - α, Δα = Mooncake.extract(α_Δα) - β, Δβ = Mooncake.extract(β_Δβ) + α, Δα = extract(α_Δα) + β, Δβ = extract(β_Δβ) backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α - if isa(Δβ, Mooncake.NoTangent) + if isa(Δβ, NoTangent) scale!(ΔC, β) else add!(ΔC, C, Δβ, β) end - if !isa(Δα, Mooncake.NoTangent) + if !isa(Δα, NoTangent) TensorKit.blas_contract!(ΔC, A, pA, B, pB, pAB, Δα, One(), backend, allocator) end TensorKit.blas_contract!(ΔC, ΔA, pA, B, pB, pAB, α, One(), backend, allocator) @@ -217,18 +217,18 @@ function Mooncake.frule!!( A, ΔA = arrayify(A_ΔA) p = primal(p_Δp) q = primal(q_Δq) - α, Δα = Mooncake.extract(α_Δα) - β, Δβ = Mooncake.extract(β_Δβ) + α, Δα = extract(α_Δα) + β, Δβ = extract(β_Δβ) backend = primal(backend_Δbackend) # dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC # dC1 = dβ * C + β * dC - if isa(Δβ, Mooncake.NoTangent) + if isa(Δβ, NoTangent) scale!(ΔC, β) else add!(ΔC, C, Δβ, β) end - if !isa(Δα, Mooncake.NoTangent) + if !isa(Δα, NoTangent) TensorKit.trace_permute!(ΔC, A, p, q, Δα, One(), backend) end TensorKit.trace_permute!(ΔC, ΔA, p, q, α, One(), backend)