Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module TensorKitMooncakeExt
using Mooncake
using Mooncake: @zero_derivative, @is_primitive,
DefaultCtx, MinimalCtx, ReverseMode, NoFData, NoRData, NoTangent,
CoDual, Dual, arrayify, primal, tangent, zero_fcodual
CoDual, Dual, arrayify, primal, tangent, zero_fcodual, extract
using TensorKit
import TensorKit as TK
using VectorInterface
Expand Down
63 changes: 61 additions & 2 deletions ext/TensorKitMooncakeExt/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# this permutation is done multiple times.
@is_primitive(
DefaultCtx,
ReverseMode,
Tuple{
typeof(TensorKit.blas_contract!),
AbstractTensorMap,
Expand Down Expand Up @@ -70,6 +69,36 @@ function Mooncake.rrule!!(
return C_ΔC, blas_contract_pullback
end

function Mooncake.frule!!(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am less familiar with Mooncake's frule!! support, but I have one comment about this implementation that mostly has to do with efficiency:
By intercepting the rule at this level, we will have to permute the input arrays multiple times, once for each blas_contract call. It seems like it would be better to try and overload one level down, since we could just get the tensoradd and mul calls separately?

I know this is also not entirely implemented for the rrule, and there are definitely efficiency questions for that too.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's a fair point. I did this here mostly to match the rrule, but maybe we should instead add a rule for tensoradd since we anyway have mul!. We could also merge this for now but open an issue to look at handling the lower level stuff for both fwd and rvs modes?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would be the multiple permutation of A appearing here? I thought I understood this remark in the context of the pullback, but here I don't really see how this would arise.

@lkdvos lkdvos Jun 9, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • on line 94, 97 and 98 we have blas_contract!(..., A, pA, ...), which will each effectively call permute(A, pA).
  • on line 94, 96 and 98 the same argument is true for blas_contract!(..., B, pB, ...)

Also, the AB that is first added to C, and then to DeltaC could in principle be reused, similar to the rrule case. (line 94 and line 98)

The problem is that if we simply don't overload this, this will automatically be resolved. Schematically (ignoring alpha and beta for simplicity):

Ap = permute(A, pA)
dA = permute(dA, pA) # frule for permute

Bp = permute(B, pB)
dB = permute(dB, pB) # frule for permute

AB = Ap * Bp
dAB = Ap * dB + dA * Ap # frule for *

C = permute(AB, pAB)
dC = permute(dAB, pAB) # frule for permute

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, certainly, I would have hoped that with Moonzyme, we needed to write much fewer rules, so if we can do without certain rules, I think that would be great, also from a maintainer's perspective.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was sort of hoping to do all this in a "tick-tock" fashion -- first, write rules relatively 1:1 with what already exists for ChainRules, so we can verify correctness and have "something to go off of". Then find lower level commonalities, write rules for those, and lean more on the fancy compilers to support us. All this is to say I'll open an issue and link to these comments.

::Dual{typeof(TensorKit.blas_contract!)},
C_ΔC::Dual{<:AbstractTensorMap},
A_ΔA::Dual{<:AbstractTensorMap}, pA_ΔpA::Dual{<:Index2Tuple},
B_ΔB::Dual{<:AbstractTensorMap}, pB_ΔpB::Dual{<:Index2Tuple},
pAB_ΔpAB::Dual{<:Index2Tuple},
α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number},
backend_Δbackend::Dual, allocator_Δallocator::Dual
)
# prepare arguments
(C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB))
pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB))
α, Δα = extract(α_Δα)
β, Δβ = extract(β_Δβ)
backend, allocator = primal.((backend_Δbackend, allocator_Δallocator))
# ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α
if isa(Δβ, NoTangent)
scale!(ΔC, β)
else
add!(ΔC, C, Δβ, β)
end
if !isa(Δα, NoTangent)
TensorKit.blas_contract!(ΔC, A, pA, B, pB, pAB, Δα, One(), backend, allocator)
end
TensorKit.blas_contract!(ΔC, ΔA, pA, B, pB, pAB, α, One(), backend, allocator)
TensorKit.blas_contract!(ΔC, A, pA, ΔB, pB, pAB, α, One(), backend, allocator)
TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
return C_ΔC
end

function blas_contract_pullback_ΔA!(
ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator
)
Expand Down Expand Up @@ -124,7 +153,6 @@ end
# ------------
@is_primitive(
DefaultCtx,
ReverseMode,
Tuple{
typeof(TensorKit.trace_permute!),
AbstractTensorMap,
Expand Down Expand Up @@ -177,6 +205,37 @@ function Mooncake.rrule!!(
return C_ΔC, trace_permute_pullback
end

function Mooncake.frule!!(
::Dual{typeof(TensorKit.trace_permute!)},
C_ΔC::Dual{<:AbstractTensorMap},
A_ΔA::Dual{<:AbstractTensorMap}, p_Δp::Dual{<:Index2Tuple}, q_Δq::Dual{<:Index2Tuple},
α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number},
backend_Δbackend::Dual
)
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
p = primal(p_Δp)
q = primal(q_Δq)
α, Δα = extract(α_Δα)
β, Δβ = extract(β_Δβ)
backend = primal(backend_Δbackend)

# dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC
# dC1 = dβ * C + β * dC
if isa(Δβ, NoTangent)
scale!(ΔC, β)
else
add!(ΔC, C, Δβ, β)
end
if !isa(Δα, NoTangent)
TensorKit.trace_permute!(ΔC, A, p, q, Δα, One(), backend)
end
TensorKit.trace_permute!(ΔC, ΔA, p, q, α, One(), backend)
TensorKit.trace_permute!(C, A, p, q, α, β, backend)
return C_ΔC
end

function trace_permute_pullback_ΔA!(
ΔA, ΔC, A, p, q, α, backend
)
Expand Down
14 changes: 6 additions & 8 deletions test/mooncake/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ using VectorInterface: One, Zero
using Mooncake
using Random


mode = Mooncake.ReverseMode
rng = Random.default_rng()

spacelist = ad_spacelist(fast_tests)
Expand Down Expand Up @@ -53,32 +51,32 @@ eltypes = (Float64, ComplexF64)
rng, TensorKit.blas_contract!,
C, A, pA, B, pB, pAB, One(), Zero(),
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
atol, rtol, mode
atol, rtol
)
Mooncake.TestUtils.test_rule(
rng, TensorKit.blas_contract!,
C, A, pA, B, pB, pAB, α, β,
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
atol, rtol, mode
atol, rtol
)
if !(T <: Real)
Mooncake.TestUtils.test_rule(
rng, TensorKit.blas_contract!,
C, A, pA, B, pB, pAB, real(α), real(β),
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
atol, rtol, mode
atol, rtol
)
Mooncake.TestUtils.test_rule(
rng, TensorKit.blas_contract!,
C, real(A), pA, B, pB, pAB, real(α), real(β),
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
atol, rtol, mode
atol, rtol
)
Mooncake.TestUtils.test_rule(
rng, TensorKit.blas_contract!,
C, A, pA, real(B), pB, pAB, real(α), real(β),
TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator();
atol, rtol, mode
atol, rtol
)
end
end
Expand All @@ -102,7 +100,7 @@ eltypes = (Float64, ComplexF64)
C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false)))
Mooncake.TestUtils.test_rule(
rng, TensorKit.trace_permute!, C, A, p, q, α, β, TensorOperations.DefaultBackend();
atol, rtol, mode
atol, rtol
)
end
end
Expand Down
Loading