Skip to content

Forward rules for TensorOperations calls#437

Merged
kshyatt merged 5 commits into
mainfrom
ksh/to
Jun 9, 2026
Merged

Forward rules for TensorOperations calls#437
kshyatt merged 5 commits into
mainfrom
ksh/to

Conversation

@kshyatt

@kshyatt kshyatt commented May 19, 2026

Copy link
Copy Markdown
Member

No description provided.

@github-actions

github-actions Bot commented May 19, 2026

Copy link
Copy Markdown
Contributor

Your PR no longer requires formatting changes. Thank you for your contribution!

Comment thread ext/TensorKitMooncakeExt/tensoroperations.jl Outdated
@kshyatt kshyatt force-pushed the ksh/to branch 2 times, most recently from a2b2ec4 to 121b86e Compare June 5, 2026 18:06
@codecov

codecov Bot commented Jun 5, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 96.77419% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
ext/TensorKitMooncakeExt/tensoroperations.jl 96.77% 1 Missing ⚠️
Files with missing lines Coverage Δ
ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl 100.00% <ø> (ø)
ext/TensorKitMooncakeExt/tensoroperations.jl 98.01% <96.77%> (-0.56%) ⬇️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

return C_ΔC, blas_contract_pullback
end

function Mooncake.frule!!(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am less familiar with Mooncake's frule!! support, but I have one comment about this implementation that mostly has to do with efficiency:
By intercepting the rule at this level, we will have to permute the input arrays multiple times, once for each blas_contract call. It seems like it would be better to try and overload one level down, since we could just get the tensoradd and mul calls separately?

I know this is also not entirely implemented for the rrule, and there are definitely efficiency questions for that too.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's a fair point. I did this here mostly to match the rrule, but maybe we should instead add a rule for tensoradd since we anyway have mul!. We could also merge this for now but open an issue to look at handling the lower level stuff for both fwd and rvs modes?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would be the multiple permutation of A appearing here? I thought I understood this remark in the context of the pullback, but here I don't really see how this would arise.

@lkdvos lkdvos Jun 9, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • on line 94, 97 and 98 we have blas_contract!(..., A, pA, ...), which will each effectively call permute(A, pA).
  • on line 94, 96 and 98 the same argument is true for blas_contract!(..., B, pB, ...)

Also, the AB that is first added to C, and then to DeltaC could in principle be reused, similar to the rrule case. (line 94 and line 98)

The problem is that if we simply don't overload this, this will automatically be resolved. Schematically (ignoring alpha and beta for simplicity):

Ap = permute(A, pA)
dA = permute(dA, pA) # frule for permute

Bp = permute(B, pB)
dB = permute(dB, pB) # frule for permute

AB = Ap * Bp
dAB = Ap * dB + dA * Ap # frule for *

C = permute(AB, pAB)
dC = permute(dAB, pAB) # frule for permute

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, certainly, I would have hoped that with Moonzyme, we needed to write much fewer rules, so if we can do without certain rules, I think that would be great, also from a maintainer's perspective.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was sort of hoping to do all this in a "tick-tock" fashion -- first, write rules relatively 1:1 with what already exists for ChainRules, so we can verify correctness and have "something to go off of". Then find lower level commonalities, write rules for those, and lean more on the fancy compilers to support us. All this is to say I'll open an issue and link to these comments.

Comment thread ext/TensorKitMooncakeExt/tensoroperations.jl Outdated
Comment thread ext/TensorKitMooncakeExt/tensoroperations.jl Outdated

@Jutho Jutho left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two stylistic comments, but otherwise looks good to me.

@Jutho Jutho left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up to minor stylistic comments, looks good to me.

@kshyatt kshyatt enabled auto-merge (squash) June 9, 2026 11:32
@kshyatt kshyatt merged commit 1de432d into main Jun 9, 2026
43 checks passed
@kshyatt kshyatt deleted the ksh/to branch June 9, 2026 13:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants