[CK_Tile] Support for various group sizes Preshuffle quant for 2d block scale gemm (#3445)

* formatted

* formatted

* formatting

* formatting

* formatting

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Split cpp file to reduce building time
- Support multiple GemmConfig

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update Readme

* enable prefill shapes

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Add support for rowcol and tensor GEMM operations

* [CK TILE GEMM] Refactor block_scale_gemm examples

- Update README

* adding preshuffle quant as new parameter and its associated new files

* remove debugging statements

* adding test

* enable preshuffle quant with permuteN

* updating readme and correcponding gemmconfigs

* updating cmake file

* fixing CI failures for grouped quant gemm

* debugging permuteN

* debugging

* debugging PermuteN

* initial commit

* resolving merge conflicts

* adding test cases

* initial commit with prints

* debugging

* fine-grained working

* debugging medium grained

* fixing the tile window

* formatting

* enabling prefill shapes

* working prefill shapes

* formatted

* clean up

* code cleanup

* bug fix after merging with develop

* clean up after merging with develop

* added comments for the tile window and tile distribution encoding

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Agarwal <khuagarw@ctr2-alola-login-03.amd.com>
This commit is contained in:
Khushbu Agarwal
2026-01-06 12:46:59 -08:00
committed by GitHub
parent 76696ace44
commit aaa35f0bbf
8 changed files with 428 additions and 669 deletions

View File

@@ -280,12 +280,13 @@ struct QuantGemmKernel
// Helper: Create Pre-shuffled Quantization Tensor Descriptor
// ===================================================================
template <index_t KPerBlockBQ,
index_t NPerBlockBQ,
index_t NPerBlock,
index_t WarpTileN,
index_t GetVectorSizeBQ,
typename BQDataType_>
CK_TILE_DEVICE static auto
MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B)
MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B)
{
// Step 1: Calculate base BQ tensor dimensions
// ----------------------------------------------------------
@@ -304,8 +305,9 @@ struct QuantGemmKernel
// ----------------------------------------------------------
// Pad the X dimension to be a multiple of block_tile_size to ensure
// each thread block can process complete tiles without edge cases
const auto block_tile_size = NPerBlock * KPerBlockBQ;
const auto bq_pad0_desc = transform_tensor_descriptor(
const auto block_tile_size = NPerBlockBQ * KPerBlockBQ;
const auto bq_pad0_desc = transform_tensor_descriptor(
bq_desc,
make_tuple(make_pass_through_transform(bq_y),
make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))),
@@ -318,7 +320,7 @@ struct QuantGemmKernel
// This separates the work into tiles that can be processed by
// individual warps/waves
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1];
const auto wave_tile_size = WarpTileN * KPerBlockBQ;
const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ;
const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size);
const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
@@ -813,12 +815,18 @@ struct QuantGemmKernel
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"PreshuffleQuant with BQuantGrouped currently only supports "
"ColumnMajor BQ layout");
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
GemmPipeline::NPerBlockBQ,
GemmPipeline::NPerBlock,
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B);
GemmPipeline::GetVectorSizeBQ()>(
bq_ptr,
ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN),
QuantGroupSize::kN,
kargs.QK_B);
}
else
{
@@ -879,13 +887,38 @@ struct QuantGemmKernel
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
constexpr auto block_n =
TilePartitioner::NPerBlock /
QuantGroupSize::kN; // Number of N-dimension quantization groups per block
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(
I1); // Number of N-dimension elements per warp
constexpr auto warp_per_group =
(QuantGroupSize::kN <
warp_n) // Determine how many warps share the same scale in N-dimension
? (warp_n / QuantGroupSize::kN)
: (QuantGroupSize::kN / warp_n);
constexpr auto bqk_per_block =
TilePartitioner::KPerBlock /
QuantGroupSize::kK; // Number of K-dimension quantization groups per block
constexpr auto
tile_window_width = // The pre-shuffled layout flattens warp_n ×
// bqk_per_block scales per row, Padded up to warp_size
// to ensure coalesced memory access.
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_n / warp_n;
auto block_n_idx = i_n / block_n;
// Adapts based on fine vs coarse quantization granularity:
// - Fine-grained (QuantGroupSize::kN < warp_n):
// Multiple quant groups per warp → fewer rows needed per block.
// height = block_n / warp_per_group
//
// - Coarse-grained (QuantGroupSize::kN >= warp_n):
// Each row represents one quant group.
// height = block_n
constexpr auto tile_window_height =
(QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
auto block_n_idx =
i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a
// block index.
return make_tile_window(
bq_tensor_view,
@@ -1125,596 +1158,6 @@ struct QuantGemmKernel
return true;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
const QuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
const auto& aq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
const auto aq_desc =
make_naive_tensor_descriptor(make_tuple(aq_y, aq_x),
make_tuple(aq_x, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
const auto aq_pad0_desc = transform_tensor_descriptor(
aq_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
const auto wave_tile_size =
GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
const auto wave_tile_count_x =
ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
aq_pad0_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto aq_pad1_desc = transform_tensor_descriptor(
aq_unmerge_pad0_desc,
make_tuple(
make_pass_through_transform(aq_y),
make_pass_through_transform(wave_tile_count_x),
make_right_pad_transform(
wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
const auto pad_wave_size =
ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
const auto aq_merge_pad1_desc = transform_tensor_descriptor(
aq_pad1_desc,
make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
make_pass_through_transform(pad_wave_size)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
}
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!PreshuffleQuant)
{
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
else // Column major AQ
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions
make_tuple(kargs.stride_AQ, 1), // Same stride pattern
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, 0), // broadcasting over n
number<1>{},
number<1>{});
}
else
{
return nullptr; // TODO: use some other "empty" type for this
}
}();
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
else
{
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
{
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
const index_t K0 = splitk_batch_offset.splitted_k / K1;
constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<VectorSizeB>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
if constexpr(PreshuffleB)
{
index_t kFlatK = GemmPipeline::flatKPerWarp *
(splitk_batch_offset.splitted_k /
GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
{
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
else
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}
}();
const auto& bq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(0, 1), // broadcasting over m
number<1>{},
number<1>{});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"PreshuffleQuant with BQuantGrouped currently only supports "
"ColumnMajor BQ layout");
return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
GemmPipeline::NPerBlock,
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B);
}
else
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
// For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN]
// Dimensions: [K/QuantGroupK, N/QuantGroupN]
// Strides: [N/QuantGroupN, 1]
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
// For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK]
// Dimensions: [N/QuantGroupN, K/QuantGroupK]
// Strides: [K/QuantGroupK, 1]
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
return nullptr; // TODO: use some other "empty" type for this
}
}();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
return make_tuple(
a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
}();
// no padding
const auto& aq_pad_view = [&]() { return views.at(I1); }();
const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I2);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / 2>{}),
sequence<false, GemmPipeline::kPadK>{});
else
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
}();
// no padding
const auto& bq_pad_view = [&]() { return views.at(I3); }();
// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I4);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(c_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
if constexpr(PreshuffleB)
{
return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
}
else
{
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
}
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& bq_pad_view = views.at(I3);
const auto& c_pad_view = views.at(I4);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& aq_block_window = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_m / warp_m;
auto block_m_idx = i_m / block_m;
return make_tile_window(
aq_pad_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * tile_window_height, 0});
}
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto block_m = TilePartitioner::MPerBlock;
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(aq_pad_view,
make_tuple(number<block_m>{}, number<aqk_per_block>{}),
{i_m, 0});
}
else // Column major AQ
{
return make_tile_window(aq_pad_view,
make_tuple(number<aqk_per_block>{}, number<block_m>{}),
{0, i_m});
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_pad_view,
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
{i_m, 0});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(aq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return nullptr; // TODO: use some other "empty" type?
}
}();
const auto& b_block_window = [&]() {
if constexpr(PreshuffleB)
{
return make_tile_window(
b_pad_view,
make_tuple(number<GemmPipeline::flatNPerWarp>{},
number<GemmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0});
}
else
{
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / 2>{}),
{i_n, 0});
else
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
}
}
}();
const auto& bq_block_window = [&]() {
if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(bq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_n / warp_n;
auto block_n_idx = i_n / block_n;
return make_tile_window(
bq_pad_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_n_idx * tile_window_height, 0});
}
else
{
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
{0, i_n / QuantGroupSize::kN});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
else
{
return nullptr; // TODO: use some other "empty" type here
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(
a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*