From 83f94ce673a57ba1c5385303423455229d99fbec Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Sat, 6 Jun 2026 00:58:58 -0700 Subject: [PATCH 1/2] [cuda] int4 W4A8 matvec: vectorized activation load (16B-aligned Q8Block) The decode-only int4_plain_mm matvec was bound by activation load-instruction throughput, not DRAM bandwidth (already ~64% peak) or latency. Each inner iteration issued ~15 loads per 16-byte weight chunk: 8 scalar int32 activation loads + the same per-block scale d reloaded 4x. Align Q8Block to 16 bytes (sizeof 36->48) so each block's qs_even/qs_odd 16B halves are 16B-aligned, then load a whole activation block with two vectorized uint4 loads + one d load (~4x fewer activation loads). dp4a math and accumulation order are bit-identical; the int8 activation values and scale are unchanged. gemma4_31b decode (long-ctx harness, stacked on optimize_1): decode 43.98 -> 46.79 tok/s (+6.4%) prefill 1193 -> 1186 (noise; int4_plain_mm is decode-only) nsys: int4 matvec avg 38.4 -> 34.75 us (-9.5%); quant kernel unchanged. Unit tests test_aoti_torch_cuda_int4_plain_mm: 6/6 pass (M=1/8, gs=16/32/128). --- backends/cuda/runtime/shims/int4_plain_mm.cuh | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/backends/cuda/runtime/shims/int4_plain_mm.cuh b/backends/cuda/runtime/shims/int4_plain_mm.cuh index 31214bc0bf6..db54da91687 100644 --- a/backends/cuda/runtime/shims/int4_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int4_plain_mm.cuh @@ -55,7 +55,11 @@ __host__ __forceinline__ int32_t log2_pow2(int32_t v) { // blocks) // --------------------------------------------------------------------------- -struct Q8Block { +// alignas(16) pads sizeof(Q8Block) to 48 so each block (and its qs_even/qs_odd +// 16-byte halves) is 16-byte aligned. This lets the matvec load a whole block's +// int8 activations with two vectorized uint4 loads instead of eight scalar +// int32 loads, cutting activation load instructions ~4x. 
+struct alignas(16) Q8Block { int8_t qs_even[Q8_BLOCK_SIZE / 2]; int8_t qs_odd[Q8_BLOCK_SIZE / 2]; float d; // scale @@ -149,6 +153,18 @@ __global__ void __launch_bounds__(MV_THREADS) int32_t k_base = i * 32; uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; + // One uint4 (32 weights) maps to exactly one Q8 activation block (32 + // activations), i.e. q8_block_idx == i. Load the whole block with two + // vectorized uint4 loads (+ one scale load) instead of eight scalar int32 + // loads. ae.{x,y,z,w} == qs_even[0:4],[4:8],[8:12],[12:16] == a_even for + // w=0..3 (same for ao/qs_odd) -> bit-identical to the scalar path. + const Q8Block* qb = &q8_row[i]; + uint4 ae = *reinterpret_cast<const uint4*>(qb->qs_even); + uint4 ao = *reinterpret_cast<const uint4*>(qb->qs_odd); + float a_scale = qb->d; + const uint32_t a_even[4] = {ae.x, ae.y, ae.z, ae.w}; + const uint32_t a_odd[4] = {ao.x, ao.y, ao.z, ao.w}; + #pragma unroll for (int32_t w = 0; w < 4; w++) { uint32_t packed = words[w]; @@ -164,22 +180,11 @@ __global__ void __launch_bounds__(MV_THREADS) int32_t vi_lo = packed & 0x0F0F0F0F; int32_t vi_hi = (packed >> 4) & 0x0F0F0F0F; - int32_t q8_block_idx = k_word / Q8_BLOCK_SIZE; - int32_t q8_half_offset = (k_word % Q8_BLOCK_SIZE) / 2; - const Q8Block* qb = &q8_row[q8_block_idx]; - - int32_t a_even = - *reinterpret_cast<const int32_t*>(qb->qs_even + q8_half_offset); - int32_t a_odd = - *reinterpret_cast<const int32_t*>(qb->qs_odd + q8_half_offset); - - int32_t dp = __dp4a(vi_lo, a_even, 0); - dp = __dp4a(vi_hi, a_odd, dp); - - float a_scale = qb->d; + int32_t dp = __dp4a(vi_lo, static_cast<int32_t>(a_even[w]), 0); + dp = __dp4a(vi_hi, static_cast<int32_t>(a_odd[w]), dp); - int32_t a_sum8 = __dp4a(0x01010101, a_even, 0); - a_sum8 = __dp4a(0x01010101, a_odd, a_sum8); + int32_t a_sum8 = __dp4a(0x01010101, static_cast<int32_t>(a_even[w]), 0); + a_sum8 = __dp4a(0x01010101, static_cast<int32_t>(a_odd[w]), a_sum8); sum += ws * a_scale * (static_cast<float>(dp) - wz * static_cast<float>(a_sum8)); From 457a316ba9600247f128c501b17d5ee11d2e4244 Mon Sep 17 00:00:00 2001 
From: gasoonjia Date: Tue, 9 Jun 2026 09:42:09 -0700 Subject: [PATCH 2/2] int8 vec support --- backends/cuda/runtime/shims/int8_plain_mm.cuh | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/backends/cuda/runtime/shims/int8_plain_mm.cuh b/backends/cuda/runtime/shims/int8_plain_mm.cuh index 2c478854644..8458c7680b5 100644 --- a/backends/cuda/runtime/shims/int8_plain_mm.cuh +++ b/backends/cuda/runtime/shims/int8_plain_mm.cuh @@ -58,7 +58,11 @@ __host__ __forceinline__ int32_t log2_pow2_i8(int32_t v) { // blocks, NATURAL order — qs[k] holds the quantized value for element k). // --------------------------------------------------------------------------- -struct Q8BlockNat { +// alignas(16) pads sizeof(Q8BlockNat) 36->48 so each block (and its two 16-byte +// qs halves) is 16-byte aligned. This lets the matvec load 16 int8 activations +// with one vectorized uint4 load instead of four scalar int32 loads, cutting +// activation load instructions ~4x. +struct alignas(16) Q8BlockNat { int8_t qs[Q8_NAT_BLOCK_SIZE]; float d; // scale }; @@ -135,6 +139,17 @@ __global__ void __launch_bounds__(MV8_THREADS) int8_w8a8_matvec_kernel( int32_t k_base = i * 16; uint32_t words[4] = {packed16.x, packed16.y, packed16.z, packed16.w}; + // One uint4 (16 int8 weights) maps to exactly one 16-byte half of a Q8 + // activation block (16 activations): block i>>1, byte offset 0 (i even) or + // 16 (i odd). Load those 16 int8 activations with a single vectorized uint4 + // load (+ one scale load) instead of four scalar int32 loads + four scale + // reloads. av.{x,y,z,w} == qs[off+0:4],[4:8],[8:12],[12:16] == a_word for + // w=0..3 -> bit-identical to the scalar path. + const Q8BlockNat* qb = &q8_row[i >> 1]; + uint4 av = *reinterpret_cast<const uint4*>(qb->qs + ((i & 1) ? 
16 : 0)); + float a_scale = qb->d; + const uint32_t a_words[4] = {av.x, av.y, av.z, av.w}; + #pragma unroll for (int32_t w = 0; w < 4; w++) { int32_t k_word = k_base + w * 4; // 4 int8 weights start here @@ -147,15 +162,10 @@ __global__ void __launch_bounds__(MV8_THREADS) int8_w8a8_matvec_kernel( } int32_t w_word = static_cast<int32_t>(words[w]); - - int32_t q8_block_idx = k_word / Q8_NAT_BLOCK_SIZE; - int32_t q8_offset = k_word % Q8_NAT_BLOCK_SIZE; - const Q8BlockNat* qb = &q8_row[q8_block_idx]; - int32_t a_word = *reinterpret_cast<const int32_t*>(qb->qs + q8_offset); + int32_t a_word = static_cast<int32_t>(a_words[w]); int32_t dp = __dp4a(w_word, a_word, 0); int32_t a_sum = __dp4a(0x01010101, a_word, 0); - float a_scale = qb->d; sum += ws * a_scale * (static_cast<float>(dp) - wz * static_cast<float>(a_sum));