
Refactor(linear): split LinearBackward kernel into 3 independent kernels#142

Merged
kilinchange merged 8 commits into master from split_linear_backward on May 12, 2026
Conversation

chen2021673 (Contributor) commented Apr 10, 2026

Summary

Architecture refactoring of Linear/Matmul/Outer kernels.

The core idea is separation of concerns — moving the decision of whether a gradient should be computed from the kernel layer up to the autograd layer, making kernels pure compute functions. At the same time, unified GEMM/SGEMV primitives are abstracted at the bottom layer to eliminate duplicated cuBLAS boilerplate.
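
As a hedged illustration of this pattern (the placeholder Tensor type and the free-function signatures below are assumptions, not the project's exact Dispatcher API), the autograd-layer backward for Linear might look like this:

```cpp
#include <array>
#include <memory>
#include <vector>

// Placeholder for the project's Tensor class (infini_train/include/tensor.h).
struct Tensor;

// Pure compute kernels: each produces exactly one gradient and takes no
// autograd-related flags. Declarations only; the real kernels are registered
// with the project's Dispatcher.
std::shared_ptr<Tensor> LinearBackwardInput(const Tensor &grad_output, const Tensor &weight);
std::shared_ptr<Tensor> LinearBackwardWeight(const Tensor &grad_output, const Tensor &input);
std::shared_ptr<Tensor> LinearBackwardBias(const Tensor &grad_output);

// Autograd-layer backward: the needs_input_grad decision lives here, so a
// kernel runs only when its gradient is actually required.
std::vector<std::shared_ptr<Tensor>>
LinearAutogradBackward(const Tensor &grad_output, const Tensor &input, const Tensor &weight,
                       const std::array<bool, 3> &needs_input_grad) {
    std::vector<std::shared_ptr<Tensor>> grads(3, nullptr);
    if (needs_input_grad[0]) { grads[0] = LinearBackwardInput(grad_output, weight); }
    if (needs_input_grad[1]) { grads[1] = LinearBackwardWeight(grad_output, input); }
    if (needs_input_grad[2]) { grads[2] = LinearBackwardBias(grad_output); }
    return grads;
}
```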

Changes

  • Autograd layer: LinearBackward and MatmulBackward are each decomposed into multiple independent Dispatcher calls. The needs_input_grad checks happen at the autograd layer, invoking only the kernels actually needed.
  • Kernel layer: The monolithic LinearBackward is split into LinearBackwardInput / LinearBackwardWeight / LinearBackwardBias; MatmulBackward is split into MatmulBackwardInput / MatmulBackwardOther, with naming aligned to MatmulForward(input, other).
  • File split: Matmul kernels are extracted from linear.cc / linear.cu into dedicated cpu/matmul.cc and cuda/matmul.cu, giving each file a single responsibility.
  • GEMM primitive: New gemm.cuh / gemm.cu define the GemmParams struct and GemmCuda(), providing a unified wrapper over cublasGemmEx and cublasGemmStridedBatchedEx branching logic. GetCublasHandle() / GetCudaStream() are centrally defined and shared across linear.cu, matmul.cu, and outer.cu, eliminating duplicate definitions.
  • SGEMV primitive: The new SgemvParams struct and SgemvCuda() wrap the cublasSgemv call. LinearForward and LinearBackwardInput in linear.cu take the SGEMV path when bs == 1 and the dtype is fp32 (more efficient for matrix-vector shapes); bf16 falls back to GemmCuda since cublasSgemv does not support it. The fp32 backward path in outer.cu is migrated to SgemvCuda as well, eliminating inline cublasSgemv calls. A sketch of this path selection follows the list.
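
A minimal sketch of that path selection, assuming placeholder GemmParams / SgemvParams definitions and a simple DataType enum (the real structs live in the project's gemm.cuh and may differ):

```cpp
#include <cstdint>

// Placeholder types; the real definitions live in the project's headers and
// their members may differ.
enum class DataType : std::uint8_t { kFLOAT32, kBFLOAT16 };
struct Device;
struct GemmParams {};   // full GEMM problem description (m, n, k, batch_count, pointers, ...)
struct SgemvParams {};  // matrix-vector problem description (m, n, pointers, ...)

void GemmCuda(const Device *device, const GemmParams &params);   // defined in gemm.cu
void SgemvCuda(const Device *device, const SgemvParams &params); // defined in gemm.cu

// LinearForward / LinearBackwardInput take the SGEMV path only for the
// bs == 1, fp32 matrix-vector case; batched or bf16 work goes through the
// unified GEMM primitive, since cublasSgemv has no bf16 support.
void LinearForwardDispatch(const Device *device, std::int64_t batch_size, DataType dtype,
                           const GemmParams &gemm, const SgemvParams &sgemv) {
    if (batch_size == 1 && dtype == DataType::kFLOAT32) {
        SgemvCuda(device, sgemv);
    } else {
        GemmCuda(device, gemm);
    }
}
```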

chen2021673 force-pushed the split_linear_backward branch 3 times, most recently from 283d083 to 23d301b on April 15, 2026 at 01:58
chen2021673 requested a review from kilinchange on April 15, 2026 at 02:08
Move grad_flags logic from kernel to autograd layer. The
monolithic LinearBackward kernel is replaced by LinearBackwardInput,
LinearBackwardWeight, and LinearBackwardBias — each a pure compute
operation with no autograd-related parameters.
Move needs_input_grad logic from kernel to autograd layer. The monolithic MatmulBackward kernel
is replaced by MatmulBackwardInput1 and MatmulBackwardInput2.
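
For reference, both new matmul backward kernels reduce to one GEMM each: for C = A·B, the gradients are dA = dC·Bᵀ and dB = Aᵀ·dC. A hedged sketch with an illustrative MatmulKernel helper (not the project API):

```cpp
#include <memory>

struct Tensor;  // placeholder for the project's Tensor class

// Illustrative GEMM helper: returns op(a) * op(b), where each operand may be transposed.
std::shared_ptr<Tensor> MatmulKernel(const Tensor &a, bool transpose_a,
                                     const Tensor &b, bool transpose_b);

// grad_input = grad_output * other^T
std::shared_ptr<Tensor> MatmulBackwardInput1(const Tensor &grad_output, const Tensor &other) {
    return MatmulKernel(grad_output, /*transpose_a=*/false, other, /*transpose_b=*/true);
}

// grad_other = input^T * grad_output
std::shared_ptr<Tensor> MatmulBackwardInput2(const Tensor &input, const Tensor &grad_output) {
    return MatmulKernel(input, /*transpose_a=*/true, grad_output, /*transpose_b=*/false);
}
```
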
…ls; rename MatmulBackwardInput1/2

- Add gemm.cuh / gemm.cu: GemmParams struct + GemmCuda() dispatch (cublasGemmEx or
  cublasGemmStridedBatchedEx based on batch_count; see the sketch after this list),
  plus GetCublasHandle() / GetCudaStream() shared across all GEMM-using kernels
- Split matmul kernels (CPU + CUDA) out of linear.cc / linear.cu into dedicated
  matmul.cc / matmul.cu; linear.* now only contains the four Linear kernels
- Rename MatmulBackwardInput1 → MatmulBackwardInput, MatmulBackwardInput2 → MatmulBackwardOther
  for semantic clarity matching MatmulForward(input, other) parameter names
- Rewrite outer.cu to use GemmCuda() (OuterForward + bf16 backward paths);
  keep cublasSgemv for the fp32 backward path (more efficient, bf16 unsupported)
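
A sketch of what the batch_count dispatch could look like; the GemmParams fields and the handle argument are assumptions, while the cublasGemmEx / cublasGemmStridedBatchedEx calls follow the actual cuBLAS API:

```cpp
#include <cublas_v2.h>

// Hypothetical problem description; the real GemmParams in gemm.cuh may use
// different names or defaults.
struct GemmParams {
    cublasOperation_t transa = CUBLAS_OP_N;
    cublasOperation_t transb = CUBLAS_OP_N;
    int m = 0, n = 0, k = 0;
    const void *alpha = nullptr;
    const void *a = nullptr; cudaDataType a_type = CUDA_R_32F; int lda = 0; long long stride_a = 0;
    const void *b = nullptr; cudaDataType b_type = CUDA_R_32F; int ldb = 0; long long stride_b = 0;
    const void *beta = nullptr;
    void *c = nullptr; cudaDataType c_type = CUDA_R_32F; int ldc = 0; long long stride_c = 0;
    int batch_count = 1;
    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
};

// Unified wrapper: a single batch goes through cublasGemmEx, anything else
// through cublasGemmStridedBatchedEx.
cublasStatus_t GemmCuda(cublasHandle_t handle, const GemmParams &p) {
    if (p.batch_count == 1) {
        return cublasGemmEx(handle, p.transa, p.transb, p.m, p.n, p.k, p.alpha,
                            p.a, p.a_type, p.lda, p.b, p.b_type, p.ldb, p.beta,
                            p.c, p.c_type, p.ldc, p.compute_type, CUBLAS_GEMM_DEFAULT);
    }
    return cublasGemmStridedBatchedEx(handle, p.transa, p.transb, p.m, p.n, p.k, p.alpha,
                                      p.a, p.a_type, p.lda, p.stride_a,
                                      p.b, p.b_type, p.ldb, p.stride_b, p.beta,
                                      p.c, p.c_type, p.ldc, p.stride_c,
                                      p.batch_count, p.compute_type, CUBLAS_GEMM_DEFAULT);
}
```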
chen2021673 force-pushed the split_linear_backward branch 2 times, most recently from ae80cec to 88579ba on April 28, 2026 at 09:06
chen2021673 force-pushed the split_linear_backward branch from 88579ba to 252e6cd on April 28, 2026 at 09:21
Chamberlain0w0 (Contributor) commented:

Also, the CPU changes are fairly extensive; I can't spot any problems, but please also verify that numerical accuracy is unaffected.

Comment thread infini_train/src/kernels/cuda/linear.cu Outdated
Comment thread infini_train/src/autograd/matmul.cc
Comment thread infini_train/src/autograd/matmul.cc Outdated
Comment thread infini_train/include/common/cuda/gemm.cuh Outdated
…s to designated initializers

- Save input1_dims_/input2_dims_ in Matmul::SetupContext to avoid Dims()
  calls on potentially-null saved tensors in Backward
- Get device from grad_output instead of input1 in Matmul::Backward
- Add CHECK guards before dereferencing nullable saved tensors
- Convert all GemmParams/SgemvParams construction in linear.cu, matmul.cu,
  outer.cu to C++20 designated initializer form (example below)
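
For illustration, using the hypothetical GemmParams layout from the earlier GemmCuda sketch (field names, values, and the transpose choice are placeholders, not the project's actual linear-layer mapping; C++20 designated initializers must follow declaration order and may skip fields):

```cpp
// Placeholder scalars and device pointers; in the real kernels these come
// from the input/weight/output tensors. Assumes <cublas_v2.h> and the
// GemmParams struct from the earlier sketch are in scope.
const float alpha = 1.0f;
const float beta = 0.0f;
const float *weight = nullptr;  // [out_features, in_features]
const float *input = nullptr;   // [batch, in_features]
float *output = nullptr;        // [batch, out_features]

const GemmParams params{
    .transa = CUBLAS_OP_T,
    .transb = CUBLAS_OP_N,
    .m = 768,   // out_features
    .n = 32,    // batch
    .k = 1024,  // in_features
    .alpha = &alpha,
    .a = weight,
    .a_type = CUDA_R_32F,
    .lda = 1024,
    .b = input,
    .b_type = CUDA_R_32F,
    .ldb = 1024,
    .beta = &beta,
    .c = output,
    .c_type = CUDA_R_32F,
    .ldc = 768,
    .batch_count = 1,
    .compute_type = CUBLAS_COMPUTE_32F,
};
```

Compared with positional construction, mis-ordered named fields fail to compile rather than silently swapping operands.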
kilinchange requested a review from Chamberlain0w0 on May 7, 2026 at 02:17
…evice param

GemmParams and SgemvParams are pure problem descriptions and should not
carry runtime state. Move handle acquisition into GemmCuda/SgemvCuda via
a device parameter, inlining the dynamic_cast directly. Remove the public
GetCublasHandle/GetCudaStream helpers from gemm.cuh.
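
A sketch of what the resulting signature could look like; the CudaDevice class and its accessors are assumptions about the project's device abstraction, and only the cuBLAS calls follow the real API:

```cpp
#include <cublas_v2.h>

// Hypothetical device abstraction; the project's real Device/CudaDevice classes
// may expose the cuBLAS handle and CUDA stream differently.
class Device {
public:
    virtual ~Device() = default;
};

class CudaDevice : public Device {
public:
    cublasHandle_t cublas_handle() const { return handle_; }
    cudaStream_t stream() const { return stream_; }
private:
    cublasHandle_t handle_ = nullptr;
    cudaStream_t stream_ = nullptr;
};

struct GemmParams;  // pure problem description, carrying no runtime state

// GemmCuda obtains the handle from the device argument, with the dynamic_cast
// inlined at the point of use instead of a public GetCublasHandle() helper.
cublasStatus_t GemmCuda(const Device *device, [[maybe_unused]] const GemmParams &params) {
    const auto *cuda_device = dynamic_cast<const CudaDevice *>(device);
    if (cuda_device == nullptr) {
        return CUBLAS_STATUS_INVALID_VALUE;
    }
    cublasHandle_t handle = cuda_device->cublas_handle();
    cublasSetStream(handle, cuda_device->stream());  // bind the device's stream for this call
    // ... cublasGemmEx / cublasGemmStridedBatchedEx dispatch as before ...
    return CUBLAS_STATUS_SUCCESS;
}
```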
chen2021673 (Contributor, Author) commented:

> Also, the CPU changes are fairly extensive; I can't spot any problems, but please also verify that numerical accuracy is unaffected.

Verified; the results align exactly. [screenshot]

Comment thread infini_train/src/kernels/cuda/common/gemm.cuh
Comment thread infini_train/src/kernels/cuda/gemm.cu Outdated
Comment thread infini_train/src/kernels/cuda/linear.cu
Comment thread infini_train/src/kernels/cuda/linear.cu
Comment thread infini_train/src/kernels/cuda/linear.cu
kilinchange (Collaborator) commented:

Please attach a screenshot showing the tests pass.

include/ is for public-facing interfaces only; gemm primitives are
internal, so relocate them under src/. Update all include paths.
Also rename ctype -> compute_type and add a FIXME on the bf16 output dtype
promotion hack in the linear backward passes.
chen2021673 (Contributor, Author) commented May 12, 2026:

> Please attach a screenshot showing the tests pass.

baseline: /data/shared/InfiniTrain-dev/logs/202511_a800/20260508/fix/negative_compile_test/logs

Accuracy comparison results: [screenshot]

Performance comparison results: 6 test cases regressed by more than 20%. [screenshot]

#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"

Collaborator commented:

This line should not be deleted; the group below is for in-project header files.

Comment thread infini_train/src/kernels/cuda/matmul.cu
Comment thread infini_train/src/kernels/cuda/outer.cu
kilinchange merged commit 7124391 into master on May 12, 2026
2 checks passed
kilinchange deleted the split_linear_backward branch on May 12, 2026 at 09:58