diff --git a/Project.toml b/Project.toml index 3d0abc9b1..f8a50d723 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" @@ -31,6 +33,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" TensorKitAMDGPUExt = "AMDGPU" TensorKitCUDAExt = ["CUDA", "cuTENSOR"] TensorKitChainRulesCoreExt = "ChainRulesCore" +TensorKitEnzymeExt = "Enzyme" +TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils" TensorKitFiniteDifferencesExt = "FiniteDifferences" TensorKitMooncakeExt = "Mooncake" @@ -43,10 +47,12 @@ AMDGPU = "2" CUDA = "6" ChainRulesCore = "1" Dictionaries = "0.4" +Enzyme = "0.13.146" +EnzymeTestUtils = "0.2.7" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.6.7" +MatrixAlgebraKit = "0.6.8" Mooncake = "0.5.27" OhMyThreads = "0.8.0" Printf = "1" @@ -54,8 +60,8 @@ Random = "1" ScopedValues = "1.3.0" Strided = "2" TensorKitSectors = "0.3.7" -TensorOperations = "5.5" +TensorOperations = "5.5.2" TupleTools = "1.5" -VectorInterface = "0.4.8, 0.5, 0.6" +VectorInterface = "0.4.8, 0.5" cuTENSOR = "6" julia = "1.10" diff --git a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl new file mode 100644 index 000000000..7f448f9e3 --- /dev/null +++ b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl @@ -0,0 +1,16 @@ +module TensorKitEnzymeExt + +using Enzyme +using TensorKit +import TensorKit as TK +using VectorInterface +using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize +import TensorOperations as TO +using MatrixAlgebraKit +using TupleTools +using Random: AbstractRNG + +include("utility.jl") +include("linalg.jl") + +end diff --git a/ext/TensorKitEnzymeExt/linalg.jl b/ext/TensorKitEnzymeExt/linalg.jl new file mode 100644 index 000000000..2e61c3bca --- /dev/null +++ b/ext/TensorKitEnzymeExt/linalg.jl @@ -0,0 +1,262 @@ +# Shared +# ------ +# Can Enzyme do this itself? Apparently not... +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(mul!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + α::Annotation, + β::Annotation, + ) where {RT} + cacheC = !isa(β, Const) && copy(C.val) + cacheA = !isa(B, Const) && EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing + cacheB = !isa(A, Const) && EnzymeRules.overwritten(config)[4] ? copy(B.val) : nothing + AB = if !isa(α, Const) + AB = A.val * B.val + add!(C.val, AB, α.val, β.val) + AB + else + mul!(C.val, A.val, B.val, α.val, β.val) + nothing + end + primal = EnzymeRules.needs_primal(config) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing + cache = (cacheC, cacheA, cacheB, AB) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(mul!)}, + ::Type{RT}, + cache, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + if RT <: Const + Δα = isa(α, Const) ? nothing : zero(α.val) + Δβ = isa(β, Const) ? nothing : zero(β.val) + return (nothing, nothing, nothing, Δα, Δβ) + end + cacheC, cacheA, cacheB, AB = cache + Cval = something(cacheC, C.val) + Aval = something(cacheA, A.val) + Bval = something(cacheB, B.val) + + !isa(A, Const) && !isa(C, Const) && project_mul!(A.dval, C.dval, Bval', conj(α.val)) + !isa(B, Const) && !isa(C, Const) && project_mul!(B.dval, Aval', C.dval, conj(α.val)) + Δαr = pullback_dα(α, C, AB) + Δβr = pullback_dβ(β, C, Cval) + !isa(C, Const) && pullback_dC!(C.dval, β.val) + + return (nothing, nothing, nothing, Δαr, Δβr) +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(mul!)}, + ::Type{RT}, + C::Annotation{<:AbstractTensorMap}, + A::Annotation{<:AbstractTensorMap}, + B::Annotation{<:AbstractTensorMap}, + α::Annotation{<:Number}, + β::Annotation{<:Number}, + ) where {RT} + # ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α + if !isa(C, Const) + scale!(C.dval, β.val) + !isa(β, Const) && add!(C.dval, C.val, β.dval) + !isa(α, Const) && project_mul!(C.dval, A.val, B.val, α.dval) + !isa(A, Const) && project_mul!(C.dval, A.dval, B.val, α.val) + !isa(B, Const) && project_mul!(C.dval, A.val, B.dval, α.val) + end + mul!(C.val, A.val, B.val, α.val, β.val) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return C + elseif EnzymeRules.needs_primal(config) + return C.val + elseif EnzymeRules.needs_shadow(config) + return C.dval + else + return nothing + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(tr)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + ) where {RT} + ret = func.val(A.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing + cache = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(tr)}, + dret::Active, + cache, + A::Annotation{<:AbstractTensorMap}, + ) + Aval = something(cache, A.val) + Δtrace = dret.val + if !isa(A, Const) + for (_, b) in blocks(A.dval) + TensorKit.diagview(b) .+= Δtrace + end + end + return (nothing,) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(tr)}, + ::Type{<:Const}, + cache, + A::Annotation{<:AbstractTensorMap}, + ) + return (nothing,) +end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + ::Type{RT}, + func::Const{typeof(tr)}, + A::Annotation{<:AbstractTensorMap}, + ) where {RT} + y = EnzymeRules.needs_primal(config) ? tr(A.val) : nothing + Δy = if EnzymeRules.needs_shadow(config) && !isa(A, Const) + tr(A.dval) + elseif EnzymeRules.needs_shadow(config) + zero(eltype(A.dval)) + else + nothing + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(y, Δy) + elseif EnzymeRules.needs_primal(config) + return y + elseif EnzymeRules.needs_shadow(config) + return Δy + else + return nothing + end +end +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(norm)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Real}, + ) where {RT} + p.val == 2 || error("currently only implemented for p = 2") + ret = func.val(A.val, p.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing + cacheA = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing + cache = (ret, cacheA) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(norm)}, + dret::Active, + cache, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Real}, + ) + n, cacheA = cache + Δn = dret.val + p.val == 2 || error("currently only implemented for p = 2") + Aval = something(cacheA, A.val) + if !isa(A, Const) + x = (Δn' + Δn) / 2 / hypot(n, eps(one(n))) + add!(A.dval, A.val, x) + end + return (nothing, nothing) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(norm)}, + ::Type{<:Const}, + cache, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Real}, + ) + return (nothing, nothing) +end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(norm)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + p::Const{<:Real}, + ) where {RT} + y = norm(A.val, p.val) + Δy = if EnzymeRules.needs_shadow(config) && !isa(A, Const) + real(dot(A.val, A.dval)) * pinv(y) + elseif EnzymeRules.needs_shadow(config) + zero(eltype(A.dval)) + else + nothing + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(y, Δy) + elseif EnzymeRules.needs_primal(config) + return y + elseif EnzymeRules.needs_shadow(config) + return Δy + else + return nothing + end +end +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inv)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + ) where {RT} + ret = inv(A.val) + primal = EnzymeRules.needs_primal(config) ? ret : nothing + shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing + cache = (ret, shadow) + return EnzymeRules.AugmentedReturn(primal, shadow, cache) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(inv)}, + ::Type{RT}, + cache, + A::Annotation{<:AbstractTensorMap}, + ) where {RT} + Ainv, ΔAinv = cache + !isa(A, Const) && mul!(A.dval, Ainv' * ΔAinv, Ainv', -1, One()) + return (nothing,) +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(inv)}, + ::Type{RT}, + A::Annotation{<:AbstractTensorMap}, + ) where {RT} + Ainv = inv(A.val) + ΔAinv = !isa(A, Const) ? scale!(Ainv * A.dval * Ainv, -1) : make_zero(Ainv) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return Duplicated(Ainv, ΔAinv) + elseif EnzymeRules.needs_primal(config) + return Ainv + elseif EnzymeRules.needs_shadow(config) + return ΔAinv + else + return nothing + end +end diff --git a/ext/TensorKitEnzymeExt/utility.jl b/ext/TensorKitEnzymeExt/utility.jl new file mode 100644 index 000000000..03ade424a --- /dev/null +++ b/ext/TensorKitEnzymeExt/utility.jl @@ -0,0 +1,80 @@ +# Projection +# ---------- +pullback_dα(α::Const, C::Const, A) = nothing +pullback_dα(α::Const, C::Annotation, A) = nothing +pullback_dα(α::Annotation, C::Const, A) = zero(α.val) +pullback_dα(α::Annotation, C::Annotation, A) = project_scalar(α.val, inner(A, C.dval)) + +pullback_dβ(β::Const, C::Const, Ccache) = nothing +pullback_dβ(β::Const, C::Annotation, Ccache) = nothing +pullback_dβ(β::Annotation, C::Const, Ccache) = zero(β.val) +pullback_dβ(β::Annotation, C::Annotation, Ccache) = project_scalar(β.val, inner(Ccache, C.dval)) + +pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β)) + +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) + +# in-place multiplication and accumulation which might project to (real) +# TODO: this could probably be done without allocating +function project_mul!(C, A, B, α) + TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(mul!(zerovector(C, TC), A, B, α))) + else + mul!(C, A, B, α, One()) + end +end +function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator) + TA = TensorKit.promote_permute(A) + TB = TensorKit.promote_permute(B) + TC = TO.promote_contract(TA, TB, scalartype(α)) + + return if scalartype(C) <: Real && !(TC <: Real) + add!(C, real(TO.tensorcontract!(zerovector(C, TC), A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend, allocator))) + else + TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, One(), backend, allocator) + end +end + +# IndexTuple utility +# ------------------ +trivtuple(N) = ntuple(identity, N) + +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) +end +Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) +end + +# Ignore derivatives +# ------------------ + +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.FusionTree}) = true +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.GenericTreeTransformer}) = true +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.VectorSpace}) = true + +@inline EnzymeRules.inactive(::typeof(TensorKit.sectorstructure), ::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.degeneracystructure), ::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.flip), s::HomSpace, i::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.permute), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.braid), s::HomSpace, i::Index2Tuple, ::IndexTuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.compose), s1::HomSpace, s2::HomSpace) = nothing +@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorcontract), c::HomSpace, p::Index2Tuple, α::Bool, b::HomSpace, q::Index2Tuple, β::Bool, pq::Index2Tuple) = nothing diff --git a/ext/TensorKitEnzymeTestUtilsExt.jl b/ext/TensorKitEnzymeTestUtilsExt.jl new file mode 100644 index 000000000..4a1f393b1 --- /dev/null +++ b/ext/TensorKitEnzymeTestUtilsExt.jl @@ -0,0 +1,66 @@ +module TensorKitEnzymeTestUtilsExt + +using TensorKit +using EnzymeTestUtils +using EnzymeTestUtils: Enzyme +import EnzymeTestUtils: to_vec, from_vec, rand_tangent + +function EnzymeTestUtils.to_vec(x::TensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + has_seen = haskey(seen_vecs, x) + is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x)) + if has_seen || is_const + x_vec = Float32[] + else + vec_of_vecs = [b * TensorKit.sqrtdim(c) for (c, b) in blocks(x)] + x_vec, back = to_vec(vec_of_vecs) + seen_vecs[x] = x_vec + end + function TensorMap_from_vec(x_vec_new::AbstractVector, seen_xs::EnzymeTestUtils.AliasDict) + if xor(has_seen, haskey(seen_xs, x)) + throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized.")) + end + has_seen && return seen_xs[x] + is_const && return x + + x_new = similar(x) + xvec_of_vecs = back(x_vec_new) + for (i, (c, b)) in enumerate(blocks(x_new)) + scale!(b, xvec_of_vecs[i], TensorKit.invsqrtdim(c)) + end + if Core.Typeof(x_new) != Core.Typeof(x) + x_new = Core.Typeof(x)(x_new) + end + seen_xs[x] = x_new + return x_new + end + return x_vec, TensorMap_from_vec +end +function EnzymeTestUtils.to_vec(t::TensorKit.AdjointTensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + parent_vec, parent_t = to_vec(parent(t), seen_vecs) + return parent_vec, adjoint ∘ parent_t +end +function EnzymeTestUtils.to_vec(t::TensorKit.DiagonalTensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + parent_vec, parent_t = to_vec(TensorMap(t), seen_vecs) + return parent_vec, TensorKit.DiagonalTensorMap ∘ parent_t +end + +# generate random tangents for testing +function EnzymeTestUtils.rand_tangent(rng, t::TensorMap) + return TensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t)) +end + +function EnzymeTestUtils.rand_tangent(rng, t::TensorKit.AdjointTensorMap) + return adjoint(rand_tangent(rng, parent(t))) +end + +function EnzymeTestUtils.rand_tangent(rng, t::DiagonalTensorMap) + return DiagonalTensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t, 1)) +end + +function EnzymeTestUtils.map_fields_recursive(f::typeof(Base.copyto!), y::TensorKit.SortedVectorDict{K, V}, x::TensorKit.SortedVectorDict{K, V}) where {K, V} + copyto!(y.keys, x.keys) + copyto!(y.values, x.values) + return y +end + +end diff --git a/test/Project.toml b/test/Project.toml index 18af8af80..5252ff1f4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" diff --git a/test/enzyme-linalg/inv.jl b/test/enzyme-linalg/inv.jl new file mode 100644 index 000000000..8e8920f46 --- /dev/null +++ b/test/enzyme-linalg/inv.jl @@ -0,0 +1,28 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +is_ci = get(ENV, "CI", "false") == "true" +TDs = is_ci ? (Duplicated,) : (Const, Duplicated) + +@timedtestset "Enzyme - LinearAlgebra (inv):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + D1 = randn(T, V[1] ← V[1]) + D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) + @testset "inv: TD $TD" for TD in TDs + EnzymeTestUtils.test_reverse(inv, TD, (D1, TD); atol, rtol) + EnzymeTestUtils.test_reverse(inv, TD, (D2, TD); atol, rtol) + EnzymeTestUtils.test_reverse(inv, TD, (D3, TD); atol, rtol) + EnzymeTestUtils.test_forward(inv, TD, (D1, TD); atol, rtol) + EnzymeTestUtils.test_forward(inv, TD, (D2, TD); atol, rtol) + EnzymeTestUtils.test_forward(inv, TD, (D3, TD); atol, rtol) + end + end +end diff --git a/test/enzyme-linalg/mul.jl b/test/enzyme-linalg/mul.jl new file mode 100644 index 000000000..c4918d8b9 --- /dev/null +++ b/test/enzyme-linalg/mul.jl @@ -0,0 +1,34 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +is_ci = get(ENV, "CI", "false") == "true" +rTs = is_ci ? (Active,) : (Const, Active) +fTs = is_ci ? (Duplicated,) : (Const, Duplicated) + +@timedtestset verbose = true "Enzyme - LinearAlgebra (mul):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + C = randn(T, V[1] ⊗ V[2] ← V[5]) + A = randn(T, codomain(C) ← V[3] ⊗ V[4]) + B = randn(T, domain(A) ← domain(C)) + α = randn(T) + β = randn(T) + @testset "mul: TC $TC, TA $TA, TB $TB" for TC in (Duplicated,), TA in (Duplicated,), TB in (Duplicated,) + @testset "Tα $Tα, Tβ $Tβ" for Tα in rTs, Tβ in rTs + EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB), (α, Tα), (β, Tβ); atol, rtol) + end + @testset "Tα $Tα, Tβ $Tβ" for Tα in fTs, Tβ in fTs + EnzymeTestUtils.test_forward(mul!, TC, (C, TC), (A, TA), (B, TB), (α, Tα), (β, Tβ); atol, rtol) + end + EnzymeTestUtils.test_reverse(mul!, TC, (C, TC), (A, TA), (B, TB); atol, rtol) + EnzymeTestUtils.test_forward(mul!, TC, (C, TC), (A, TA), (B, TB); atol, rtol) + end + end +end diff --git a/test/enzyme-linalg/norm.jl b/test/enzyme-linalg/norm.jl new file mode 100644 index 000000000..d12dccc1a --- /dev/null +++ b/test/enzyme-linalg/norm.jl @@ -0,0 +1,27 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +is_ci = get(ENV, "CI", "false") == "true" +rRTs = is_ci ? (Active,) : (Const, Active) +fRTs = is_ci ? (Duplicated,) : (Const, Duplicated) + +@timedtestset verbose = true "Enzyme - LinearAlgebra (norm):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T), TC $TC" for V in spacelist, T in eltypes, TC in (Const, Duplicated) + atol = default_tol(T) + rtol = default_tol(T) + C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') + for RT in rRTs + EnzymeTestUtils.test_reverse(norm, RT, (C, TC), (2, Const); atol, rtol) + EnzymeTestUtils.test_reverse(norm, RT, (C', TC), (2, Const); atol, rtol) + end + for RT in fRTs + EnzymeTestUtils.test_forward(norm, RT, (C, TC), (2, Const); atol, rtol) + EnzymeTestUtils.test_forward(norm, RT, (C', TC), (2, Const); atol, rtol) + end + end +end diff --git a/test/enzyme-linalg/tr.jl b/test/enzyme-linalg/tr.jl new file mode 100644 index 000000000..1ba0c2df7 --- /dev/null +++ b/test/enzyme-linalg/tr.jl @@ -0,0 +1,33 @@ +using Test, TestExtras +using TensorKit +using Enzyme, EnzymeTestUtils +using Random + +is_ci = get(ENV, "CI", "false") == "true" + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +rRTs = is_ci ? (Active,) : (Const, Active) +fRTs = is_ci ? (Duplicated,) : (Const, Duplicated) +TDs = is_ci ? (Duplicated,) : (Const, Duplicated) + +@timedtestset verbose = true "Enzyme - LinearAlgebra (tr):" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + D1 = randn(T, V[1] ← V[1]) + D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) + @testset "tr reverse: RT $RT, TD $TD" for RT in rRTs, TD in TDs + EnzymeTestUtils.test_reverse(tr, RT, (D1, TD); atol, rtol) + EnzymeTestUtils.test_reverse(tr, RT, (D2, TD); atol, rtol) + EnzymeTestUtils.test_reverse(tr, RT, (D3, TD); atol, rtol) + end + @testset "tr forward: RT $RT, TD $TD" for RT in fRTs, TD in TDs + EnzymeTestUtils.test_forward(tr, RT, (D1, TD); atol, rtol) + EnzymeTestUtils.test_forward(tr, RT, (D2, TD); atol, rtol) + EnzymeTestUtils.test_forward(tr, RT, (D3, TD); atol, rtol) + end + end +end