Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,18 @@ function check_and_prepare_qr_cotangents(
ΔR₁₁ = UpperTriangular(view(ΔR, 1:p, 1:p))
ΔR₁₂ = view(ΔR, 1:p, (p + 1):n)
ΔR₂₂ = view(ΔR, (p + 1):minmn, (p + 1):n)
Δgauge_R = norm(view(ΔR₂₂, uppertriangularind(ΔR₂₂)), Inf)
Δgauge_R = max(Δgauge_R, norm(view(ΔR₂₂, diagind(ΔR₂₂)), Inf))
Δgauge = max(Δgauge, Δgauge_R)
if p < minmn # otherwise ΔR₂₂ is empty
# uppertriangularind generates linear indices
# compute the appropriate offset in ΔR so we aren't
# operating on a view-of-view, which doesn't work
# for GPU arrays
I = uppertriangularind(ΔR₂₂)
upper_inds = view(LinearIndices(ΔR), (p + 1):minmn, (p + 1):n)[I]
ΔR₂₂upper = view(ΔR, upper_inds)
Δgauge_R = norm(ΔR₂₂upper, Inf)
Δgauge_R = max(Δgauge_R, norm(view(ΔR₂₂, diagind(ΔR₂₂)), Inf))
Δgauge = max(Δgauge, Δgauge_R)
end
else
ΔR₁₁ = nothing
ΔR₁₂ = nothing
Expand Down Expand Up @@ -75,7 +84,7 @@ function qr_pullback!(


Q₁ = view(Q, :, 1:p)
R₁₁ = UpperTriangular(view(R, 1:p, 1:p))
R₁₁ = UpperTriangular(R[1:p, 1:p])

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a subtle and impactful change. The `UpperTriangular` wrapper is really only necessary to enable the `rdiv!` call below. If GPUs cannot deal with `UpperTriangular` of a view of a GPUArray, then maybe we need to call the corresponding BLAS/LAPACK methods directly, or have some intermediate wrapper like `rdiv_uppertriangular!`.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> If GPUs cannot deal with `UpperTriangular` of a view of a GPUArray

Indeed they can't :(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I even wonder how `rdiv!(::Matrix, ::UpperTriangular)` is evaluated on the GPU, since you need cuSOLVERDx to access TRSM.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

R₁₂ = view(R, 1:p, (p + 1):n)

ΔA₁ = view(ΔA, :, 1:p)
Expand All @@ -101,7 +110,8 @@ function qr_pullback!(
Md = diagview(M)
Md .= real.(Md)
end
ΔA₁ .+= rdiv!(mul!(ΔQ₁, Q₁, M, +1, 1), R₁₁')
mul!(ΔQ₁, Q₁, M, +1, 1)
ΔA₁ .+= rdiv!(ΔQ₁, R₁₁')
return ΔA
end

Expand Down Expand Up @@ -160,7 +170,16 @@ function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebr
end
ΔR₂₂ = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2))
zero!(diagview(ΔR₂₂))
zero!(view(ΔR₂₂, uppertriangularind(ΔR₂₂)))
if r < minmn
# uppertriangularind generates linear indices
# compute the appropriate offset in ΔR so we aren't
# operating on a view-of-view, which doesn't work
# for GPU arrays
offset = LinearIndices(ΔR)[r + 1, r + 1]
upper_inds = uppertriangularind(ΔR₂₂) .+ offset
Comment on lines +178 to +179

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

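This is still using the old (faulty) strategy: `uppertriangularind(ΔR₂₂)` returns linear indices relative to the sub-block `ΔR₂₂`, and adding a single constant `offset` to them does not account for the column stride of the parent `ΔR`, so the shifted indices do not address the upper-triangular entries of the block. It needs to be updated in the same way as in the gauge check above, i.e. by indexing the parent's linear indices with the block-local ones: `view(LinearIndices(ΔR), (r + 1):minmn, (r + 1):size(R, 2))[uppertriangularind(ΔR₂₂)]`. Maybe that is the origin of the remaining failures.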

ΔR₂₂upper = view(ΔR, upper_inds)
zero!(ΔR₂₂upper)
end
return ΔQ, ΔR
end

Expand Down
7 changes: 7 additions & 0 deletions test/mooncake/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,11 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.test_mooncake_qr(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
if T ∈ BLASFloats && CUDA.functional()
TestSuite.test_mooncake_qr(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
#=if m == n
AT = Diagonal{T, CuVector{T}}
TestSuite.test_mooncake_qr(AT, (m, m); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end=# # currently broken
end
end
Loading