[CK_Tile] Adding support for preshuffleQuant in AB quant Block Scale Gemm (#3629)

* initial commit

* preshuffleQuant support for ABQuant

* fix mxfp4 to use correct QuantGroupSize

* addressing review comments and separated PreshuffleQuant for A and B

* updated grouped gemm example for updated traits definition

* fix for CI failure

* updated grouped_gemm_abquant test for updated traits definition

* updated grouped_gemm_abquant test for updated traits definition
This commit is contained in:
Khushbu Agarwal
2026-01-28 19:45:09 -08:00
committed by GitHub
parent e3556fed04
commit 9b168082b7
33 changed files with 490 additions and 367 deletions

View File

@@ -67,15 +67,27 @@ struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
};
template <typename, typename = void>
struct is_quantpreshuffle_enabled
struct is_Aquantpreshuffle_enabled
{
static constexpr bool value = false;
};
template <typename T>
struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
struct is_Aquantpreshuffle_enabled<T, std::void_t<decltype(T::APreshuffleQuant)>>
{
static constexpr bool value = T::PreshuffleQuant;
static constexpr bool value = T::APreshuffleQuant;
};
// Trait: detects whether a pipeline type T opts into B-side preshuffled
// quantization (presumably preshuffled BQ scale layout — confirm against the
// pipeline definitions, which are outside this view).
// Primary template: chosen when T::BPreshuffleQuant is not a valid
// expression; reports false (feature disabled).
template <typename, typename = void>
struct is_Bquantpreshuffle_enabled
{
static constexpr bool value = false;
};
// std::void_t SFINAE specialization: chosen when T declares a
// BPreshuffleQuant member; forwards the flag's value.
template <typename T>
struct is_Bquantpreshuffle_enabled<T, std::void_t<decltype(T::BPreshuffleQuant)>>
{
static constexpr bool value = T::BPreshuffleQuant;
};
template <typename, typename = void>
@@ -206,8 +218,10 @@ struct QuantGemmKernel
typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool PreshuffleQuant =
detail::is_quantpreshuffle_enabled<GemmPipeline_>::value;
static constexpr bool APreshuffleQuant =
detail::is_Aquantpreshuffle_enabled<GemmPipeline_>::value;
static constexpr bool BPreshuffleQuant =
detail::is_Bquantpreshuffle_enabled<GemmPipeline_>::value;
static constexpr bool PreshuffleB = detail::is_preshuffleB_enabled<GemmPipeline_>::value;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
@@ -476,7 +490,7 @@ struct QuantGemmKernel
{
// Step 1: Create tensor view for AQ
const auto& aq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
if constexpr(kQuantType == QuantType::AQuantGrouped && APreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
@@ -533,7 +547,7 @@ struct QuantGemmKernel
}
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!PreshuffleQuant)
!APreshuffleQuant)
{
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
@@ -571,13 +585,13 @@ struct QuantGemmKernel
// Step 2: Create tile window (no padding for AQ)
const auto& aq_block_window = [&]() {
if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant)
if constexpr(kQuantType == QuantType::AQuantGrouped && APreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0);
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK;
constexpr auto tile_window_width =
ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
constexpr auto tile_window_height = block_m / warp_m;
@@ -587,11 +601,19 @@ struct QuantGemmKernel
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * tile_window_height, 0});
}
else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!APreshuffleQuant)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK;
constexpr auto block_m = TilePartitioner::MPerBlock;
if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>,
"ABQuantGrouped requires RowMajor AQ layout");
}
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(aq_tensor_view,
@@ -605,17 +627,6 @@ struct QuantGemmKernel
{0, i_m});
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_tensor_view,
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
{i_m, 0});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(aq_tensor_view,
@@ -808,14 +819,15 @@ struct QuantGemmKernel
number<1>{},
number<1>{});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
else if constexpr(kQuantType == QuantType::BQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped)
{
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"PreshuffleQuant with BQuantGrouped currently only supports "
"ColumnMajor BQ layout");
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
@@ -824,48 +836,42 @@ struct QuantGemmKernel
TilePartitioner::BlockGemmShape::WarpTile::at(I1),
GemmPipeline::GetVectorSizeBQ()>(
bq_ptr,
ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN),
QuantGroupSize::kN,
ck_tile::integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
BQuantGroupSize::kN,
kargs.QK_B);
}
else
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"ABQuantGrouped requires ColumnMajor BQ layout");
}
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK),
integer_divide_ceil(kargs.N, BQuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
integer_divide_ceil(kargs.K, BQuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
return nullptr;
@@ -881,28 +887,29 @@ struct QuantGemmKernel
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
else if constexpr(kQuantType == QuantType::BQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(PreshuffleQuant)
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
if constexpr(BPreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
// Number of N-dimension quantization groups per block
constexpr auto block_n = (QuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / QuantGroupSize::kN
: QuantGroupSize::kN / TilePartitioner::NPerBlock;
constexpr auto block_n = (BQuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / BQuantGroupSize::kN
: BQuantGroupSize::kN / TilePartitioner::NPerBlock;
// Number of N-dimension elements per warp
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
// Determine how many warps share the same scale in N-dimension
constexpr auto warp_per_group = (QuantGroupSize::kN < warp_n)
? (warp_n / QuantGroupSize::kN)
: (QuantGroupSize::kN / warp_n);
constexpr auto warp_per_group = (BQuantGroupSize::kN < warp_n)
? (warp_n / BQuantGroupSize::kN)
: (BQuantGroupSize::kN / warp_n);
// Number of K-dimension quantization groups per block
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / BQuantGroupSize::kK;
// The pre-shuffled layout flattens warp_n ×
// bqk_per_block scales per row, Padded up to warp_size
@@ -911,25 +918,25 @@ struct QuantGemmKernel
ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size());
// Adapts based on fine vs coarse quantization granularity:
// - Fine-grained (QuantGroupSize::kN < warp_n):
// - Fine-grained (BQuantGroupSize::kN < warp_n):
// Multiple quant groups per warp → fewer rows needed per block.
// height = block_n / warp_per_group
//
// - Coarse-grained (QuantGroupSize::kN >= warp_n):
// - Coarse-grained (BQuantGroupSize::kN >= warp_n):
// Each row represents one quant group.
// height = block_n
constexpr auto tile_window_height =
(QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
(BQuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n;
auto block_n_idx = i_n / TilePartitioner::NPerBlock;
// For decode shapes GN: 128, Blocks needs to repeat 0,0,1,1,2,2 ...
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
if(BQuantGroupSize::kN > TilePartitioner::NPerBlock)
{
block_n_idx = block_n_idx >> 1;
}
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
if(BQuantGroupSize::kN > TilePartitioner::NPerBlock)
{
return make_tile_window(
bq_tensor_view,
@@ -946,17 +953,22 @@ struct QuantGemmKernel
}
else
{
if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
"ABQuantGrouped requires RowMajor AQ layout");
}
constexpr auto tensor_dim =
(QuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / QuantGroupSize::kN
(BQuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / BQuantGroupSize::kN
: 1;
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
bq_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
make_tuple(number<TilePartitioner::KPerBlock / BQuantGroupSize::kK>{},
number<tensor_dim>{}),
{0, i_n / QuantGroupSize::kN});
{0, i_n / BQuantGroupSize::kN});
}
else
{
@@ -964,21 +976,11 @@ struct QuantGemmKernel
return make_tile_window(
bq_tensor_view,
make_tuple(number<tensor_dim>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
number<TilePartitioner::KPerBlock / BQuantGroupSize::kK>{}),
{i_n / BQuantGroupSize::kN, 0});
}
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
return make_tile_window(
bq_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
else
{
return nullptr;
@@ -1223,7 +1225,7 @@ struct QuantGemmKernel
if constexpr(kQuantType == QuantType::AQuantGrouped)
{
index_t m = 0;
if constexpr(PreshuffleQuant)
if constexpr(APreshuffleQuant)
{
m = kargs.M;
}
@@ -1233,7 +1235,7 @@ struct QuantGemmKernel
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
index_t n = 0;
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
n = kargs.N;
}
@@ -1244,9 +1246,9 @@ struct QuantGemmKernel
{
index_t m = 0;
index_t n = 0;
if constexpr(PreshuffleQuant)
if constexpr(BPreshuffleQuant)
{
m = kargs.M;
// m = kargs.M;
n = kargs.N;
}
return GemmPipeline{}.template operator()(a_block_window,