[CK_TILE] B matrix 2D block scale gemm (#3074)

* Refactor quant group size to be configurable for M/N/K, not just K

* add some asserts for configurations not implemented

* start setting of group size for N dimension

* enable 2d for reference quant gemm

* WIP: trying to figure out tile dstr and/or indexing for scale matrix

* WIP

* Fix handling of n dim blocks in tile windows etc

* remove commented code and enable all tests again

* fix formatting

* Add more specialized tile distributions

* Enable NWarps replication for bquant tile dstr

* fix formatting

* fix format

* Fix some issues from the merge

* fix formatting

* one more fix to tile dstr, and revert debug initialization

* Remove commented code

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* simplify conditions that are needed for tile distributions

* only enable the working group sizes in tests

* fix formatting

* Update tile distribution for 2D bquant

* add some documentation and 2d block scale example

* fix formatting

* Add in Changlog and restructure the quant 2d example

* fix CMake

* support the change for blockscale 2d

* fix the test file

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Sami Remes
2025-11-03 00:49:20 +00:00
committed by GitHub
parent 73f637894d
commit 16e85cf179
24 changed files with 476 additions and 363 deletions

View File

@@ -16,7 +16,7 @@ template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
uint32_t QuantGroupSize,
typename QuantGroupSize,
bool aquant,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
@@ -80,12 +80,11 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
v_block_acc += v_a * v_b;
// Apply group dequant scale
if((k + 1) % QuantGroupSize == 0)
if((k + 1) % QuantGroupSize::kK == 0)
{
float scale = 0.f;
index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
index_t inner_dim = (aquant) ? k / QuantGroupSize : n;
index_t outer_dim = (aquant) ? (m / QuantGroupSize::kM) : (k / QuantGroupSize::kK);
index_t inner_dim = (aquant) ? (k / QuantGroupSize::kK) : (n / QuantGroupSize::kN);
if constexpr(std::is_same_v<QDataType, float>)
{
scale = q(outer_dim, inner_dim);