[CK_TILE] B matrix 2D block scale gemm (#3074)

* Refactor quant group size to be configurable for M/N/K, not just K

* add some asserts for configurations not implemented

* start setting of group size for N dimension

* enable 2d for reference quant gemm

* WIP: trying to figure out tile dstr and/or indexing for scale matrix

* WIP

* Fix handling of n dim blocks in tile windows etc

* remove commented code and enable all tests again

* fix formatting

* Add more specialized tile distributions

* Enable NWarps replication for bquant tile dstr

* fix formatting

* fix format

* Fix some issues from the merge

* fix formatting

* one more fix to tile dstr, and revert debug initialization

* Remove commented code

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* simplify conditions that are needed for tile distributions

* only enable the working group sizes in tests

* fix formatting

* Update tile distribution for 2D bquant

* add some documentation and 2d block scale example

* fix formatting

* Add in Changlog and restructure the quant 2d example

* fix CMake

* support the change for blockscale 2d

* fix the test file

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Sami Remes
2025-11-03 00:49:20 +00:00
committed by GitHub
parent 73f637894d
commit 16e85cf179
24 changed files with 476 additions and 363 deletions

View File

@@ -93,6 +93,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
GroupedGemKernelParam::N_Tile,
@@ -135,7 +137,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
AccDataType,
GemmShape,
GemmUniversalTraits,
128>, // QuantGroupSize
QuantGroupSize>,
ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
@@ -258,6 +260,8 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
std::vector<std::unique_ptr<ck_tile::DeviceMem>> aq_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> bq_dev_buf;
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
a_m_k_dev_buf.reserve(group_count);
b_k_n_dev_buf.reserve(group_count);
c_m_n_dev_buf.reserve(group_count);
@@ -495,7 +499,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
BDataType,
AccDataType,
CDataType,
128,
QuantGroupSize,
false>(
a_m_k_tensors[i], bq_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
}