[CK_TILE] B matrix 2D block scale gemm (#3074)

* Refactor quant group size to be configurable for M/N/K, not just K

* add some asserts for configurations not implemented

* start setting of group size for N dimension

* enable 2d for reference quant gemm

* WIP: trying to figure out tile dstr and/or indexing for scale matrix

* WIP

* Fix handling of n dim blocks in tile windows etc

* remove commented code and enable all tests again

* fix formatting

* Add more specialized tile distributions

* Enable NWarps replication for bquant tile dstr

* fix formatting

* fix format

* Fix some issues from the merge

* fix formatting

* one more fix to tile dstr, and revert debug initialization

* Remove commented code

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* simplify conditions that are needed for tile distributions

* only enable the working group sizes in tests

* fix formatting

* Update tile distribution for 2D bquant

* add some documentation and 2d block scale example

* fix formatting

* Add in Changlog and restructure the quant 2d example

* fix CMake

* support the change for blockscale 2d

* fix the test file

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Sami Remes
2025-11-03 00:49:20 +00:00
committed by GitHub
parent 73f637894d
commit 16e85cf179
24 changed files with 476 additions and 363 deletions

View File

@@ -685,9 +685,10 @@ struct QuantGemmKernel
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.QK_B, kargs.N),
make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(1, kargs.stride_BQ),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
@@ -831,10 +832,10 @@ struct QuantGemmKernel
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block =
TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_m / warp_m;
@@ -847,11 +848,12 @@ struct QuantGemmKernel
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_pad_view,
make_tuple(number<block_m>{}, number<block_k / GemmPipeline::QuantGroupSize>{}),
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
{i_m, 0});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
@@ -907,11 +909,12 @@ struct QuantGemmKernel
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
{0, i_n / QuantGroupSize::kN});
}
else
{