Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Expand All @@ -31,6 +33,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
TensorKitAMDGPUExt = "AMDGPU"
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
TensorKitChainRulesCoreExt = "ChainRulesCore"
TensorKitEnzymeExt = "Enzyme"
TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils"
TensorKitFiniteDifferencesExt = "FiniteDifferences"
TensorKitMooncakeExt = "Mooncake"

Expand All @@ -43,19 +47,21 @@ AMDGPU = "2"
CUDA = "6"
ChainRulesCore = "1"
Dictionaries = "0.4"
Enzyme = "0.13.146"
EnzymeTestUtils = "0.2.7"
FiniteDifferences = "0.12"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.7"
MatrixAlgebraKit = "0.6.8"
Mooncake = "0.5.27"
OhMyThreads = "0.8.0"
Printf = "1"
Random = "1"
ScopedValues = "1.3.0"
Strided = "2"
TensorKitSectors = "0.3.7"
TensorOperations = "5.5"
TensorOperations = "5.5.2"
TupleTools = "1.5"
VectorInterface = "0.4.8, 0.5, 0.6"
VectorInterface = "0.4.8, 0.5"
cuTENSOR = "6"
julia = "1.10"
16 changes: 16 additions & 0 deletions ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module TensorKitEnzymeExt

using Enzyme
using TensorKit
import TensorKit as TK
using VectorInterface
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
import TensorOperations as TO
using MatrixAlgebraKit
using TupleTools
using Random: AbstractRNG

include("utility.jl")
include("linalg.jl")

end
262 changes: 262 additions & 0 deletions ext/TensorKitEnzymeExt/linalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
# Shared
# ------
# Can Enzyme do this itself? Apparently not...
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(mul!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
B::Annotation{<:AbstractTensorMap},
α::Annotation,
β::Annotation,
) where {RT}
cacheC = !isa(β, Const) && copy(C.val)
cacheA = !isa(B, Const) && EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing
cacheB = !isa(A, Const) && EnzymeRules.overwritten(config)[4] ? copy(B.val) : nothing

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since my Enzyme knowledge has fallen back to zero, I was comparing this with some mul! rule in Enzyme.jl ( https://github.com/EnzymeAD/Enzyme.jl/blob/6d9c0cb7fa1ab4a4ce347ba506ea9715761365a8/src/internal_rules/linalg.jl#L304 ), and while there are clearly some similarities, there are also some differences, e.g. in which overwritten(config) positions are checked to decide on cacheA and cacheB (they use 5 and 6, compared to 3 and 4 here). Is there a good explanation for that?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably that they are correct and I made a mistake. This part is very difficult to test because it only arises in a longer set of operations. I wish they had written their mul! rule to be a little more generic

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://enzymead.github.io/Enzyme.jl/dev/generated/custom_rule/#Defining-a-reverse-mode-rule what's confusing is it doesn't match what is shown here. I will ask the devs.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think your counting is consistent with the doc string of overwritten.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I posted about it on Slack to hopefully get this figured out but now I'm so confused about who is right 😹

AB = if !isa(α, Const)
AB = A.val * B.val
add!(C.val, AB, α.val, β.val)
AB
else
mul!(C.val, A.val, B.val, α.val, β.val)
nothing
end
primal = EnzymeRules.needs_primal(config) ? C.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
cache = (cacheC, cacheA, cacheB, AB)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(mul!)},
::Type{RT},
cache,
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
B::Annotation{<:AbstractTensorMap},
α::Annotation{<:Number},
β::Annotation{<:Number},
) where {RT}
if RT <: Const
Δα = isa(α, Const) ? nothing : zero(α.val)
Δβ = isa(β, Const) ? nothing : zero(β.val)
return (nothing, nothing, nothing, Δα, Δβ)
end
cacheC, cacheA, cacheB, AB = cache
Cval = something(cacheC, C.val)
Aval = something(cacheA, A.val)
Bval = something(cacheB, B.val)

!isa(A, Const) && !isa(C, Const) && project_mul!(A.dval, C.dval, Bval', conj(α.val))
!isa(B, Const) && !isa(C, Const) && project_mul!(B.dval, Aval', C.dval, conj(α.val))
Δαr = pullback_dα(α, C, AB)
Δβr = pullback_dβ(β, C, Cval)
!isa(C, Const) && pullback_dC!(C.dval, β.val)

return (nothing, nothing, nothing, Δαr, Δβr)
end

function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(mul!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
B::Annotation{<:AbstractTensorMap},
α::Annotation{<:Number},
β::Annotation{<:Number},
) where {RT}
# ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α
if !isa(C, Const)
scale!(C.dval, β.val)
!isa(β, Const) && add!(C.dval, C.val, β.dval)
!isa(α, Const) && project_mul!(C.dval, A.val, B.val, α.dval)
!isa(A, Const) && project_mul!(C.dval, A.dval, B.val, α.val)
!isa(B, Const) && project_mul!(C.dval, A.val, B.dval, α.val)
end
mul!(C.val, A.val, B.val, α.val, β.val)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return C
elseif EnzymeRules.needs_primal(config)
return C.val
elseif EnzymeRules.needs_shadow(config)
return C.dval
else
return nothing
end
end

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(tr)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
) where {RT}
ret = func.val(A.val)
primal = EnzymeRules.needs_primal(config) ? ret : nothing
shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing
cache = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(tr)},
dret::Active,
cache,
A::Annotation{<:AbstractTensorMap},
)
Aval = something(cache, A.val)
Δtrace = dret.val
if !isa(A, Const)
for (_, b) in blocks(A.dval)
TensorKit.diagview(b) .+= Δtrace
end
end
return (nothing,)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(tr)},
::Type{<:Const},
cache,
A::Annotation{<:AbstractTensorMap},
)
return (nothing,)
end
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
::Type{RT},
func::Const{typeof(tr)},
A::Annotation{<:AbstractTensorMap},
) where {RT}
y = EnzymeRules.needs_primal(config) ? tr(A.val) : nothing
Δy = if EnzymeRules.needs_shadow(config) && !isa(A, Const)
tr(A.dval)
elseif EnzymeRules.needs_shadow(config)
zero(eltype(A.dval))
else
nothing
end
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(y, Δy)
elseif EnzymeRules.needs_primal(config)
return y
elseif EnzymeRules.needs_shadow(config)
return Δy
else
return nothing
end
end
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(norm)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
p::Const{<:Real},
) where {RT}
p.val == 2 || error("currently only implemented for p = 2")
ret = func.val(A.val, p.val)
primal = EnzymeRules.needs_primal(config) ? ret : nothing
shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing
cacheA = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
cache = (ret, cacheA)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(norm)},
dret::Active,
cache,
A::Annotation{<:AbstractTensorMap},
p::Const{<:Real},
)
n, cacheA = cache
Δn = dret.val
p.val == 2 || error("currently only implemented for p = 2")
Aval = something(cacheA, A.val)
if !isa(A, Const)
x = (Δn' + Δn) / 2 / hypot(n, eps(one(n)))
add!(A.dval, A.val, x)
end
return (nothing, nothing)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(norm)},
::Type{<:Const},
cache,
A::Annotation{<:AbstractTensorMap},
p::Const{<:Real},
)
return (nothing, nothing)
end
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(norm)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
p::Const{<:Real},
) where {RT}
y = norm(A.val, p.val)
Δy = if EnzymeRules.needs_shadow(config) && !isa(A, Const)
real(dot(A.val, A.dval)) * pinv(y)
elseif EnzymeRules.needs_shadow(config)
zero(eltype(A.dval))
else
nothing
end
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(y, Δy)
elseif EnzymeRules.needs_primal(config)
return y
elseif EnzymeRules.needs_shadow(config)
return Δy
else
return nothing
end
end
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(inv)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
) where {RT}
ret = inv(A.val)
primal = EnzymeRules.needs_primal(config) ? ret : nothing
shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing
cache = (ret, shadow)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(inv)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
) where {RT}
Ainv, ΔAinv = cache
!isa(A, Const) && mul!(A.dval, Ainv' * ΔAinv, Ainv', -1, One())
return (nothing,)
end

function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(inv)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
) where {RT}
Ainv = inv(A.val)
ΔAinv = !isa(A, Const) ? scale!(Ainv * A.dval * Ainv, -1) : make_zero(Ainv)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(Ainv, ΔAinv)
elseif EnzymeRules.needs_primal(config)
return Ainv
elseif EnzymeRules.needs_shadow(config)
return ΔAinv
else
return nothing
end
end
Loading
Loading