Your PR no longer requires formatting changes. Thank you for your contribution!
a2b2ec4 to 121b86e
    return C_ΔC, blas_contract_pullback
end

function Mooncake.frule!!(
I am less familiar with Mooncake's frule!! support, but I have one comment about this implementation, mostly to do with efficiency:
By intercepting the rule at this level, we have to permute the input arrays multiple times, once for each blas_contract! call. It seems better to overload one level down, so that we get the tensoradd and mul calls separately?
I know this is not entirely implemented for the rrule either, and there are definitely efficiency questions there too.
Yeah, it's a fair point. I did this here mostly to match the rrule, but maybe we should instead add a rule for tensoradd, since we have mul! anyway. We could also merge this for now and open an issue to look at handling the lower-level operations for both forward and reverse modes?
Where would the multiple permutations of A appear here? I thought I understood this remark in the context of the pullback, but here I don't really see how it would arise.
- On lines 94, 97 and 98 we have blas_contract!(..., A, pA, ...), each of which will effectively call permute(A, pA).
- On lines 94, 96 and 98 the same is true for blas_contract!(..., B, pB, ...).
Also, the AB that is first added to C and then to DeltaC could in principle be reused, similar to the rrule case (lines 94 and 98).
The thing is that if we simply don't overload this, the issue is resolved automatically. Schematically (ignoring alpha and beta for simplicity):
Ap = permute(A, pA)
dA = permute(dA, pA) # frule for permute
Bp = permute(B, pB)
dB = permute(dB, pB) # frule for permute
AB = Ap * Bp
dAB = Ap * dB + dA * Bp # frule for *
C = permute(AB, pAB)
dC = permute(dAB, pAB) # frule for permute
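The schematic above can be turned into a small self-contained check. This is only a sketch under simplifying assumptions: plain matrices and Base's permutedims stand in for the tensor permutations and the BLAS contraction, alpha and beta are again ignored, and contract_with_tangent and primal are hypothetical helper names that are not part of the PR or of any package.

```julia
using LinearAlgebra  # provides isapprox for arrays

# Hypothetical helper, not part of the PR: plain matrices and Base's permutedims
# stand in for the tensor permutations and BLAS contraction discussed above.
function contract_with_tangent(A, dA, pA, B, dB, pB, pAB)
    Ap, dAp = permutedims(A, pA), permutedims(dA, pA)  # A and dA are each permuted once
    Bp, dBp = permutedims(B, pB), permutedims(dB, pB)  # same for B and dB
    AB = Ap * Bp                                       # primal product, computed once
    dAB = Ap * dBp + dAp * Bp                          # product rule for *
    return permutedims(AB, pAB), permutedims(dAB, pAB)
end

# Reference primal, used only for the finite-difference check below.
primal(A, pA, B, pB, pAB) = permutedims(permutedims(A, pA) * permutedims(B, pB), pAB)

A, dA = randn(3, 4), randn(3, 4)
B, dB = randn(5, 3), randn(5, 3)
pA, pB, pAB = (2, 1), (2, 1), (2, 1)

C, dC = contract_with_tangent(A, dA, pA, B, dB, pB, pAB)

# Central difference along (dA, dB); the map is bilinear, so second-order terms cancel.
ε = 1e-6
fd = (primal(A .+ ε .* dA, pA, B .+ ε .* dB, pB, pAB) .-
      primal(A .- ε .* dA, pA, B .- ε .* dB, pB, pAB)) ./ (2ε)

@assert C ≈ primal(A, pA, B, pB, pAB)
@assert isapprox(dC, fd; rtol = 1e-6)
```

Each input is permuted exactly once per primal and once per tangent here, which is the saving that composing lower-level rules would give over repeated blas_contract! calls.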
Well, certainly, I would have hoped that with Moonzyme we would need to write far fewer rules, so if we can do without certain rules, I think that would be great, also from a maintainer's perspective.
I was sort of hoping to do all this in a "tick-tock" fashion -- first, write rules relatively 1:1 with what already exists for ChainRules, so we can verify correctness and have "something to go off of". Then find lower level commonalities, write rules for those, and lean more on the fancy compilers to support us. All this is to say I'll open an issue and link to these comments.
Co-authored-by: Lukas Devos <ldevos98@gmail.com>
Jutho left a comment
Two stylistic comments, but otherwise looks good to me.
Jutho left a comment
Up to minor stylistic comments, looks good to me.