Refactor(linear): split LinearBackward kernel into 3 independent kernels #142
Merged
force-pushed from 283d083 to 23d301b
Move grad_flags logic from kernel to autograd layer. The monolithic LinearBackward kernel is replaced by LinearBackwardInput, LinearBackwardWeight, and LinearBackwardBias — each a pure compute operation with no autograd-related parameters.
Move needs_input_grad logic from kernel to autograd layer. The monolithic MatmulBackward kernel is replaced by MatmulBackwardInput1 and MatmulBackwardInput2.
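As a minimal sketch of what this separation could look like on the Linear side (the `Backward` signature, the `needs_input_grad_` flags, and the member names are illustrative assumptions, not the project's actual autograd API):

```cpp
// Sketch only: the autograd layer owns the "should this gradient be computed?"
// decision; each split kernel is a pure compute function with no grad flags.
std::vector<std::shared_ptr<Tensor>>
Linear::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
    const auto &grad_output = grad_outputs[0];
    std::shared_ptr<Tensor> grad_input, grad_weight, grad_bias;

    if (needs_input_grad_[0]) {                  // hypothetical flag recorded at setup time
        grad_input = LinearBackwardInput(grad_output, weight_);
    }
    if (needs_input_grad_[1]) {
        grad_weight = LinearBackwardWeight(grad_output, input_);
    }
    if (has_bias_ && needs_input_grad_[2]) {
        grad_bias = LinearBackwardBias(grad_output);
    }
    return {grad_input, grad_weight, grad_bias};
}
```

With this shape, skipping a gradient is purely the autograd layer's decision; the kernels never see any grad flags.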
…ls; rename MatmulBackwardInput1/2

- Add gemm.cuh / gemm.cu: GemmParams struct + GemmCuda() dispatch (cublasGemmEx or cublasGemmStridedBatchedEx based on batch_count), plus GetCublasHandle() / GetCudaStream() shared across all GEMM-using kernels
- Split the matmul kernels (CPU + CUDA) out of linear.cc / linear.cu into dedicated matmul.cc / matmul.cu; linear.* now contains only the four Linear kernels
- Rename MatmulBackwardInput1 → MatmulBackwardInput, MatmulBackwardInput2 → MatmulBackwardOther for semantic clarity, matching the MatmulForward(input, other) parameter names
- Rewrite outer.cu to use GemmCuda() (OuterForward + bf16 backward paths); keep cublasSgemv for the fp32 backward path (more efficient, and bf16 is unsupported there)
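A sketch of what such a shared dispatch helper could look like (`GemmParams`, `GemmCuda`, and the two cuBLAS entry points are named in the commit; the field layout, the fixed `CUBLAS_COMPUTE_32F` compute type, and the explicit handle parameter are illustrative assumptions):

```cpp
#include <cublas_v2.h>

// Illustrative problem description for a (possibly strided-batched) GEMM.
struct GemmParams {
    int m, n, k;
    int lda, ldb, ldc;
    long long stride_a, stride_b, stride_c;  // only used when batch_count > 1
    int batch_count;                          // 1 => plain GEMM
    bool transa, transb;
    cudaDataType_t dtype;                     // e.g. CUDA_R_32F or CUDA_R_16BF
};

// Single entry point: picks cublasGemmEx or cublasGemmStridedBatchedEx.
inline void GemmCuda(cublasHandle_t handle, const GemmParams &p,
                     const void *a, const void *b, void *c) {
    const float alpha = 1.0f, beta = 0.0f;    // assumes fp32 accumulation
    const auto op_a = p.transa ? CUBLAS_OP_T : CUBLAS_OP_N;
    const auto op_b = p.transb ? CUBLAS_OP_T : CUBLAS_OP_N;
    if (p.batch_count > 1) {
        cublasGemmStridedBatchedEx(handle, op_a, op_b, p.m, p.n, p.k, &alpha,
                                   a, p.dtype, p.lda, p.stride_a,
                                   b, p.dtype, p.ldb, p.stride_b, &beta,
                                   c, p.dtype, p.ldc, p.stride_c, p.batch_count,
                                   CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
    } else {
        cublasGemmEx(handle, op_a, op_b, p.m, p.n, p.k, &alpha,
                     a, p.dtype, p.lda, b, p.dtype, p.ldb, &beta,
                     c, p.dtype, p.ldc, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
    }
}
```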
force-pushed from ae80cec to 88579ba
…es in linear kernels
force-pushed from 88579ba to 252e6cd
Contributor
Also, the CPU-side changes are fairly extensive. I can't spot any problems with them, but it would be good to also verify that numerical accuracy is unaffected.
Chamberlain0w0
requested changes
Apr 28, 2026
…s to designated initializers

- Save input1_dims_/input2_dims_ in Matmul::SetupContext to avoid Dims() calls on potentially-null saved tensors in Backward
- Get the device from grad_output instead of input1 in Matmul::Backward
- Add CHECK guards before dereferencing nullable saved tensors
- Convert all GemmParams/SgemvParams construction in linear.cu, matmul.cu, outer.cu to C++20 designated-initializer form
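Roughly, the SetupContext/Backward pattern described here could look like the following sketch (everything except `input1_dims_`/`input2_dims_`, `Matmul::SetupContext`, and the use of `CHECK` is an assumed name for illustration):

```cpp
// Sketch: cache shapes at setup time so Backward never calls Dims() on a
// saved tensor that may have been dropped (nullptr), and guard every
// dereference of a nullable saved tensor with CHECK.
void Matmul::SetupContext(const std::vector<std::shared_ptr<Tensor>> &inputs,
                          const std::vector<std::shared_ptr<Tensor>> &outputs) {
    input1_dims_ = inputs[0]->Dims();
    input2_dims_ = inputs[1]->Dims();
    // Only keep the tensors actually needed by the requested gradients
    // (hypothetical member names).
    saved_input1_ = needs_input_grad_[1] ? inputs[0] : nullptr;
    saved_input2_ = needs_input_grad_[0] ? inputs[1] : nullptr;
}

std::vector<std::shared_ptr<Tensor>>
Matmul::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
    const auto &grad_output = grad_outputs[0];
    const auto device = grad_output->GetDevice();  // device comes from grad_output, not input1
    std::shared_ptr<Tensor> grad_input, grad_other;
    if (needs_input_grad_[0]) {
        CHECK(saved_input2_ != nullptr);
        grad_input = MatmulBackwardInput(grad_output, saved_input2_);
    }
    if (needs_input_grad_[1]) {
        CHECK(saved_input1_ != nullptr);
        grad_other = MatmulBackwardOther(grad_output, saved_input1_);
    }
    return {grad_input, grad_other};
}
```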
…evice param

GemmParams and SgemvParams are pure problem descriptions and should not carry runtime state. Move handle acquisition into GemmCuda/SgemvCuda via a device parameter and inline the dynamic_cast directly. Remove the public GetCublasHandle/GetCudaStream helpers from gemm.cuh.
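After this change, a call site might look roughly like this (the field names and the pointer variables are assumptions; the designated-initializer style and the device parameter are what the commit describes):

```cpp
// Sketch: the params struct stays a pure problem description; the cuBLAS
// handle and stream are derived inside GemmCuda from the device argument.
const GemmParams params{
    .m = m, .n = n, .k = k,
    .batch_count = batch_count,
    .transa = false,
    .transb = true,
    .dtype = CUDA_R_16BF,
};
GemmCuda(device, params, a_ptr, b_ptr, c_ptr);  // handle acquired from `device` internally
```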
Chamberlain0w0
approved these changes
May 7, 2026
Contributor
Author
kilinchange
requested changes
May 7, 2026
Collaborator
Please attach a screenshot showing that the tests pass.
include/ is for public-facing interfaces only; gemm primitives are internal, so relocate them under src/. Update all include paths. Also rename ctype -> compute_type and add FIXME on bf16 output dtype promotion hack in linear backward passes.
Contributor
Author
kilinchange
requested changes
May 12, 2026
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"
Collaborator
This line shouldn't be deleted, should it? The group below it is for project-internal header files.
kilinchange
approved these changes
May 12, 2026



Summary
Architecture refactoring of Linear/Matmul/Outer kernels.
The core idea is separation of concerns: the decision of whether a gradient should be computed moves from the kernel layer up to the autograd layer, making the kernels pure compute functions. At the same time, unified GEMM/SGEMV primitives are factored out into a shared low-level layer to eliminate duplicated cuBLAS boilerplate.
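For reference, assuming a PyTorch-style linear layer $y = xW^\top + b$ (the exact layout convention is an assumption), the three split kernels each correspond to one of the standard gradients:

$$
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\,W,\qquad
\frac{\partial L}{\partial W} = \Big(\frac{\partial L}{\partial y}\Big)^{\top} x,\qquad
\frac{\partial L}{\partial b} = \sum_i \Big(\frac{\partial L}{\partial y}\Big)_i
$$

LinearBackwardInput, LinearBackwardWeight, and LinearBackwardBias compute these three terms independently of one another, which is what makes it trivial for the autograd layer to skip any subset of them.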
Changes