feat: add split_k support for block scale gemm bquant mode. (#3653)

* WIP: add splitk to bquant

* feat: add support for bf8i4 and fp8i4 by calculating correct stride for packed data types

* chore: remove temporary test script

* fix: incorrect tile window length for splitted bq tensor window

* chore: improve comments

* test: add unit tests to cover bquant splitk functionality

* fix: conflict resolution by renaming variables
This commit is contained in:
Aviral Goel
2026-02-03 02:41:53 +04:00
committed by GitHub
parent 301eb5cf08
commit 3e77721755
11 changed files with 273 additions and 208 deletions

View File

@@ -380,9 +380,18 @@ struct QuantGemmKernel
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
constexpr auto K1 =
GemmPipeline::BlockGemmShape::WarpTile::at(I2); // smallest unit of K work per block
const index_t K_t = amd_wave_read_first_lane(
kargs.k_batch * K1); // amount of K elements consumed if every split-K batch
// performs exactly one "unit" (K1)
const index_t KRead = amd_wave_read_first_lane(
(kargs.K + K_t - 1) / K_t * K1); // total k elements to be read in this batch
// offset not necessarily = KRead, because B can have packed elements (e.g. fp8i4)
constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
const index_t b_k_offset_elements =
amd_wave_read_first_lane(k_id * KRead / BPackedSize);
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
@@ -395,11 +404,11 @@ struct QuantGemmKernel
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B);
b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements * kargs.stride_B);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements);
}
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
@@ -410,10 +419,47 @@ struct QuantGemmKernel
{
splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
}
// Compute BQ offset for BQuantGrouped mode (non-preshuffle only)
// Note: With the alignment validation in IsSupportedArgument, KRead is always
// a multiple of BQuantGroupSize::kK, so bq_k_split_offset will be correctly aligned.
if constexpr(kQuantType == QuantType::BQuantGrouped && !BPreshuffleQuant)
{
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
// Compute the K offset for this batch (in terms of K elements)
const index_t k_offset = amd_wave_read_first_lane(k_id * KRead);
// Convert K offset to BQ group offset (logical offset in K/kK dimension)
bq_group_offset = amd_wave_read_first_lane(k_offset / BQuantGroupSize::kK);
// BQ tensor layout:
// RowMajor: [K/kK, N/kN] with stride [N/kN, 1]
// ColumnMajor: [N/kN, K/kK] with stride [K/kK, 1]
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BQLayout>)
{
// For RowMajor BQ, K is the row dimension
// offset = bq_group_offset * stride_BQ
const index_t stride_bq =
amd_wave_read_first_lane(integer_divide_ceil(kargs.N, BQuantGroupSize::kN));
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset * stride_bq);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BQLayout>)
{
// For ColumnMajor BQ, K is the column dimension
// offset = bq_group_offset
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset);
}
}
else
{
bq_group_offset = 0;
bq_k_split_offset = 0;
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension)
index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride)
index_t splitted_k;
};
@@ -805,10 +851,13 @@ struct QuantGemmKernel
CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr,
const QuantGemmKernelArgs& kargs,
const index_t bq_group_offset,
const index_t i_m,
const index_t i_n)
{
// Step 1: Create tensor view for BQ
// Note: For split-K, the bq_ptr is already offset by bq_k_split_offset (pointer offset).
// The dimension should use the remaining K-groups from this offset position.
const auto& bq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::RowColQuant)
{
@@ -850,11 +899,12 @@ struct QuantGemmKernel
"ABQuantGrouped requires ColumnMajor BQ layout");
}
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK),
make_tuple(kargs.QK_B - bq_group_offset,
integer_divide_ceil(kargs.N, BQuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
@@ -865,8 +915,8 @@ struct QuantGemmKernel
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
integer_divide_ceil(kargs.K, BQuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1),
kargs.QK_B - bq_group_offset),
make_tuple(kargs.QK_B, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
@@ -1047,13 +1097,61 @@ struct QuantGemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
{
// Split-K is supported for BQuantGrouped mode without preshuffle
if(kargs.k_batch != 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
constexpr bool is_bquant_non_preshuffle =
(kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant;
if constexpr(!is_bquant_non_preshuffle)
{
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met for Kbatch >1 ! "
"Split-K only supported for BQuantGrouped without preshuffle.");
}
return false;
}
else
{
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
// Constraint 1: KRead must align with B packing requirements.
// For packed data types, multiple K elements are stored in each storage unit.
// Split-K advances the B pointer by (KRead / BPackedSize) storage units per batch.
// If KRead is not divisible by BPackedSize, this division produces a fractional
// offset, making it impossible to start reading from a valid storage unit boundary.
if(KRead % BPackedSize != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("KRead must be a multiple of B packed size for split-K!");
}
return false;
}
// Constraint 2: KRead must align with quantization group boundaries.
// Each split-K batch reads KRead consecutive K elements. If KRead is not
// a multiple of BQuantGroupSize::kK, the batch will span partial quantization
// groups, requiring split access to a quantization scale. This violates the
// atomic processing requirement where each batch must work with complete groups.
if(KRead % BQuantGroupSize::kK != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Split-K batch size must be aligned with quantization group "
"size! KRead=" +
std::to_string(KRead) +
" is not divisible by BQuantGroupSize::kK=" +
std::to_string(BQuantGroupSize::kK));
}
return false;
}
}
return false;
}
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
@@ -1215,7 +1313,10 @@ struct QuantGemmKernel
const auto& b_block_window =
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
// Note: Pass bq_group_offset so the tensor view dimension reflects
// the remaining K-groups from the split-K offset position.
const auto& bq_block_window = MakeBQBlockWindow(
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
@@ -1343,8 +1444,9 @@ struct QuantGemmKernel
const BDataType* b_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
const BQDataType* bq_ptr =
static_cast<const BQDataType*>(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset;
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];

View File

@@ -387,8 +387,8 @@ struct QuantGroupedGemmKernel
Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& bq_block_window =
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
const auto& bq_block_window = Base::MakeBQBlockWindow(
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
@@ -453,8 +453,8 @@ struct QuantGroupedGemmKernel
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& aq_block_window =
Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
const auto& bq_block_window =
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
const auto& bq_block_window = Base::MakeBQBlockWindow(
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
// Get hot-loop and tail configuration
const index_t num_loop = __builtin_amdgcn_readfirstlane(