ck_tile kernel for gemm with groupwise quantized B tensor. (#2663)

* This change introduces new pipelines with Intrawave scheduler and block gemm primitives that loads the scale tensor to registers to perform dequantization post MFMA on C tensor in registers.

Scale tensor data, BQ is spliced across threads in registers and not stored in LDS.

Current support is for the following combinations, but it should be fairly straightforward to extend support to more formats.

fp8, fp8 -> f32
bf8, bf8 -> f32
fp8, i4 -> f32
bf8, i4 -> f32
Group size can go down to as low as K length of underlying WarpGemm primitive.

* Solve merge conflict

* [CK TILE] Update CHANGELOG.md

---------

Co-authored-by: Vijay Krishnamoorthy <vjkrish@fb.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Cong Ma <congma13@amd.com>
This commit is contained in:
Vijay Krish
2025-08-28 23:43:02 -07:00
committed by GitHub
parent 428090f749
commit 4208e28988
20 changed files with 2471 additions and 26 deletions

View File

@@ -250,8 +250,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
float_to_e2m1(type_convert<float>(x[1]), scale));
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
@@ -259,8 +258,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
float_to_e2m1(type_convert<float>(x[1]), scale));
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)