[CK_Tile] Support for various group sizes Preshuffle quant for 2d block scale gemm (#3445)

* formatted

* formatted

* formatting

* formatting

* formatting

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* enable prefill shapes

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Add support for rowcol and tensor GEMM operations

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update README

* adding preshuffle quant as new parameter and its associated new files

* remove debugging statements

* adding test

* enable preshuffle quant with permuteN

* updating readme and correcponding gemmconfigs

* updating cmake file

* fixing CI failures for grouped quant gemm

* debugging permuteN

* debugging

* debugging PermuteN

* initial commit

* resolving merge conflicts

* adding test cases

* initial commit with prints

* debugging

* fine-grained working

* debugging medium grained

* fixing the tile window

* formatting

* enabling prefill shapes

* working prefill shapes

* formatted

* clean up

* code cleanup

* bug fix after merging with develop

* clean up after merging with develop

* added comments for the tile window and tile distribution encoding

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Agarwal <khuagarw@ctr2-alola-login-03.amd.com>
This commit is contained in:
Khushbu Agarwal
2026-01-06 12:46:59 -08:00
committed by GitHub
parent 76696ace44
commit aaa35f0bbf
8 changed files with 428 additions and 669 deletions

View File

@@ -322,6 +322,7 @@ struct BQuantBlockUniversalGemmAsBsCr
constexpr index_t reg_offset = nIter;
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;

View File

@@ -280,12 +280,13 @@ struct QuantGemmKernel
// Helper: Create Pre-shuffled Quantization Tensor Descriptor
// ===================================================================
template <index_t KPerBlockBQ,
index_t NPerBlockBQ,
index_t NPerBlock,
index_t WarpTileN,
index_t GetVectorSizeBQ,
typename BQDataType_>
CK_TILE_DEVICE static auto
MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B)
MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B)
{
// Step 1: Calculate base BQ tensor dimensions
// ----------------------------------------------------------
@@ -304,8 +305,9 @@ struct QuantGemmKernel
// ----------------------------------------------------------
// Pad the X dimension to be a multiple of block_tile_size to ensure
// each thread block can process complete tiles without edge cases
const auto block_tile_size = NPerBlock * KPerBlockBQ;
const auto bq_pad0_desc = transform_tensor_descriptor(
const auto block_tile_size = NPerBlockBQ * KPerBlockBQ;
const auto bq_pad0_desc = transform_tensor_descriptor(
bq_desc,
make_tuple(make_pass_through_transform(bq_y),
make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))),
@@ -318,7 +320,7 @@ struct QuantGemmKernel
// This separates the work into tiles that can be processed by
// individual warps/waves
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1];
const auto wave_tile_size = WarpTileN * KPerBlockBQ;
const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ;
const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size);
const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
@@ -813,12 +815,18 @@ struct QuantGemmKernel
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"PreshuffleQuant with BQuantGrouped currently only supports "
"ColumnMajor BQ layout");
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
GemmPipeline::NPerBlockBQ,
GemmPipeline::NPerBlock,
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B);
GemmPipeline::GetVectorSizeBQ()>(
bq_ptr,
ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN),
QuantGroupSize::kN,
kargs.QK_B);
}
else
{
@@ -879,13 +887,38 @@ struct QuantGemmKernel
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
constexpr auto block_n =
TilePartitioner::NPerBlock /
QuantGroupSize::kN; // Number of N-dimension quantization groups per block
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(
I1); // Number of N-dimension elements per warp
constexpr auto warp_per_group =
(QuantGroupSize::kN <
warp_n) // Determine how many warps share the same scale in N-dimension
? (warp_n / QuantGroupSize::kN)
: (QuantGroupSize::kN / warp_n);
constexpr auto bqk_per_block =
TilePartitioner::KPerBlock /
QuantGroupSize::kK; // Number of K-dimension quantization groups per block
constexpr auto
tile_window_width = // The pre-shuffled layout flattens warp_n ×
// bqk_per_block scales per row, Padded up to warp_size
// to ensure coalesced memory access.
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_n / warp_n;
auto block_n_idx = i_n / block_n;
// Adapts based on fine vs coarse quantization granularity:
// - Fine-grained (QuantGroupSize::kN < warp_n):
// Multiple quant groups per warp → fewer rows needed per block.
// height = block_n / warp_per_group
//
// - Coarse-grained (QuantGroupSize::kN >= warp_n):
// Each row represents one quant group.
// height = block_n
constexpr auto tile_window_height =
(QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
auto block_n_idx =
i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a
// block index.
return make_tile_window(
bq_tensor_view,
@@ -1125,596 +1158,6 @@ struct QuantGemmKernel
return true;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
const QuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
const auto& aq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
const auto aq_desc =
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
make_tuple(aq_x, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
const auto aq_pad0_desc = transform_tensor_descriptor(
aq_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
const auto wave_tile_size =
GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
const auto wave_tile_count_x =
ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
aq_pad0_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto aq_pad1_desc = transform_tensor_descriptor(
aq_unmerge_pad0_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_pass_through_transform(wave_tile_count_x),
make_right_pad_transform(
wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
const auto pad_wave_size =
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
aq_pad1_desc,
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
make_pass_through_transform(pad_wave_size)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
}
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!PreshuffleQuant)
{
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
else // Column major AQ
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions
make_tuple(kargs.stride_AQ, 1), // Same stride pattern
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, 0), // broadcasting over n
number<1>{},
number<1>{});
}
else
{
return nullptr; // TODO: use some other "empty" type for this
}
}();
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
if constexpr(PreshuffleB)
{
index_t kFlatK = GemmPipeline::flatKPerWarp *
(splitk_batch_offset.splitted_k /
GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
{
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
else
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}
}();
const auto& bq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(0, 1), // broadcasting over m
number<1>{},
number<1>{});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"PreshuffleQuant with BQuantGrouped currently only supports "
"ColumnMajor BQ layout");
return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
GemmPipeline::NPerBlock,
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B);
}
else
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
// For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN]
// Dimensions: [K/QuantGroupK, N/QuantGroupN]
// Strides: [N/QuantGroupN, 1]
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
// For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK]
// Dimensions: [N/QuantGroupN, K/QuantGroupK]
// Strides: [K/QuantGroupK, 1]
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
return nullptr; // TODO: use some other "empty" type for this
}
}();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
return make_tuple(
a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
}();
// no padding
const auto& aq_pad_view = [&]() { return views.at(I1); }();
const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I2);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / 2>{}),
sequence<false, GemmPipeline::kPadK>{});
else
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
}();
// no padding
const auto& bq_pad_view = [&]() { return views.at(I3); }();
// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I4);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
if constexpr(PreshuffleB)
{
return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
}
else
{
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
}
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& bq_pad_view = views.at(I3);
const auto& c_pad_view = views.at(I4);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& aq_block_window = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_m / warp_m;
auto block_m_idx = i_m / block_m;
return make_tile_window(
aq_pad_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * tile_window_height, 0});
}
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto block_m = TilePartitioner::MPerBlock;
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(aq_pad_view,
make_tuple(number<block_m>{}, number<aqk_per_block>{}),
{i_m, 0});
}
else // Column major AQ
{
return make_tile_window(aq_pad_view,
make_tuple(number<aqk_per_block>{}, number<block_m>{}),
{0, i_m});
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_pad_view,
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
{i_m, 0});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(aq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return nullptr; // TODO: use some other "empty" type?
}
}();
const auto& b_block_window = [&]() {
if constexpr(PreshuffleB)
{
return make_tile_window(
b_pad_view,
make_tuple(number<GemmPipeline::flatNPerWarp>{},
number<GemmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
}
else
{
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / 2>{}),
{i_n, 0});
else
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
}
}();
const auto& bq_block_window = [&]() {
if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(bq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_n / warp_n;
auto block_n_idx = i_n / block_n;
return make_tile_window(
bq_pad_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_n_idx * tile_window_height, 0});
}
else
{
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
{0, i_n / QuantGroupSize::kN});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
else
{
return nullptr; // TODO: use some other "empty" type here
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(
a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*

View File

@@ -48,7 +48,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
@@ -68,7 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
BlockSize,
NPerBlock / WarpGemm::kN,
ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
VecLoadSize,
Problem::BQuantGroupSize::kN,
Problem::BQuantGroupSize::kK,
BQLayout,
PreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
@@ -83,6 +83,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
KPerBlockBQ, // Logical K dimension
NPerBlockBQ, // Logical N dimension
Problem::BQuantGroupSize::kN,
Problem::BQuantGroupSize::kK,
BQLayout>;
return TileEncodingPattern::make_2d_static_tile_distribution();

View File

@@ -65,8 +65,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t NPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -300,9 +302,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
const BQDramTileWindowStep bq_dram_tile_window_step =
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0)
(PreshuffleQuant)
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0)
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);

View File

@@ -192,6 +192,7 @@ template <typename BlockGemmShape,
index_t KPerTile,
index_t NPerTile,
index_t NPerQ,
index_t KPerQ,
typename BQLayout = tensor_layout::gemm::ColumnMajor,
bool PreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
@@ -208,31 +209,6 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
static_assert(num_warps == MWarps * NWarps * KWarps);
static_assert(KWarps == 1);
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
///
/// This function determines the optimal thread distribution pattern for loading and applying
/// quantization scales to the B matrix based on the quantization group size (NPerQ) relative
/// to warp dimensions.
///
/// Three distinct distribution patterns are handled:
///
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
/// - Multiple quantization groups exist within a single warp's N-dimension
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
///
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
/// - Each warp handles exactly one quantization scale
/// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
///
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
/// - Quantization group spans multiple warps
/// - All warps share the same scale value
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
///
/// @return A static tile distribution encoding for the BQ scale tensor
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
// Preshuffle only supported for ColumnMajor currently
@@ -241,22 +217,136 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
if constexpr(PreshuffleQuant)
{
// ColumnMajor only for preshuffle
constexpr index_t X1 = warp_size;
constexpr index_t X0 = NPerTile / warp_size;
constexpr index_t Y1 = NWarps;
constexpr index_t Y0 = KPerTile / Y1;
// =============================================================================
// PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION
// =============================================================================
// For pre-shuffled quantization, the BQ scale tensor has been reorganized
// (pre-shuffled) to optimize memory access patterns during dequantization.
//
// Tile Dimensions:
// - K-axis (Y in encoding): Corresponds to the K-dimension iteration
// - N-axis (X in encoding): Flattened scale index combining N and K groups
//
// The encoding distributes work across threads such that each thread loads
// the correct pre-shuffled scale for its corresponding B-matrix elements.
// =============================================================================
if constexpr(NPerQ <= WarpGemm::kN)
{
// =========================================================================
// CASE 1: Fine-grained Quantization (NPerQ <= WarpGemm::kN)
// =========================================================================
// Multiple quantization scales exist within a single warp's N-dimension.
// Each warp processes multiple scales: WarpGemm::kN / NPerQ scales per warp.
//
// Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256
// → 2 scales per warp in N, 2 K-groups per block
constexpr auto N1 = BlockGemmShape::kK /
KPerQ; // Number of K-dimension quantization groups per block,
// Each K-group of KPerQ elements shares the same scale.
constexpr auto N0 =
WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ
// <= WarpGemm::kN, each warp handles multiple scales.
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension
constexpr auto NR0 =
warp_size /
(N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization
constexpr auto K1 = NWarps; // Number of warps distributed along this dimension
constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile
constexpr auto KR = 1; // No replication in K-dimension
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps>,
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
tuple<sequence<0, 1>, sequence<2>>,
tuple<sequence<0, 1>, sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{});
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
tuple<sequence<0, 1>, sequence<0, 2, 0, 2, 0>>,
tuple<sequence<0, 1>, sequence<1, 0, 2, 1, 3>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else if constexpr(NPerQ < WarpGemm::kN * NWarps)
{
// =========================================================================
// CASE 2: Medium-grained Quantization (WarpGemm::kN < NPerQ < WarpGemm::kN *
// NWarps)
// =========================================================================
// Each warp handles exactly one quantization scale in N-dimension.
// Some warps share the same scale (KR > 1 creates warp grouping).
//
// Example: NPerQ=32, WarpGemm::kN=16, NWarps=4
// → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups)
constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale
constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales)
constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ)
constexpr auto NR0 =
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
tuple<sequence<0, 1, 0>, sequence<0, 2, 0, 2>>,
tuple<sequence<0, 1, 3>, sequence<1, 0, 2, 1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
// =========================================================================
// CASE 3: Coarse-grained Quantization (NPerQ >= WarpGemm::kN * NWarps)
// =========================================================================
// The quantization group spans ALL warps in N-dimension.
// All warps share the same scale value for their N-tiles.
//
// Example: NPerQ=128, WarpGemm::kN=16, NWarps=4
// → 128 >= 16*4=64, so all 4 warps use the same scale
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
constexpr auto N0 = 1; // Minimal (1) since scale is shared across N
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = 32; // Fixed broadcast size
constexpr auto NR0 =
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
tuple<sequence<0, 0>, sequence<0, 2, 0, 2>>,
tuple<sequence<0, 1>, sequence<2, 0, 3, 1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
}
else
{
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
///
/// This function determines the optimal thread distribution pattern for loading and
/// applying quantization scales to the B matrix based on the quantization group size
/// (NPerQ) relative to warp dimensions.
///
/// Three distinct distribution patterns are handled:
///
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
/// - Multiple quantization groups exist within a single warp's N-dimension
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale
/// broadcast
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
///
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
/// - Each warp handles exactly one quantization scale
/// - Scales are distributed across warps with replication factor XR = NPerQ /
/// WarpGemm::kN
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
///
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
/// - Quantization group spans multiple warps
/// - All warps share the same scale value
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
///
/// @return A static tile distribution encoding for the BQ scale tensor
if constexpr(NPerQ < WarpGemm::kN)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp

View File

@@ -71,6 +71,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
static constexpr index_t NPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
static constexpr index_t QScalesPerBlockRow =
@@ -352,8 +354,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
}
else
@@ -427,8 +431,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
}
else
@@ -462,8 +468,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
}
else