mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
feat: add split_k support for block scale gemm bquant mode. (#3653)
* WIP: add splitk to bquant * feat: add support for bf8i4 and fp8i4 by calculating correct stride for packed data types * chore: remove temporary test script * fix: incorrect tile window length for splitted bq tensor window * chore: improve comments * test: add unit tests to cover bquant splitk functionality * fix: conflict resolution by renaming variables
This commit is contained in:
@@ -380,9 +380,18 @@ struct QuantGemmKernel
|
||||
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
|
||||
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
|
||||
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
|
||||
constexpr auto K1 =
|
||||
GemmPipeline::BlockGemmShape::WarpTile::at(I2); // smallest unit of K work per block
|
||||
const index_t K_t = amd_wave_read_first_lane(
|
||||
kargs.k_batch * K1); // amount of K elements consumed if every split-K batch
|
||||
// performs exactly one "unit" (K1)
|
||||
const index_t KRead = amd_wave_read_first_lane(
|
||||
(kargs.K + K_t - 1) / K_t * K1); // total k elements to be read in this batch
|
||||
// offset not necessarily = KRead, because B can have packed elements (e.g. fp8i4)
|
||||
constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
const index_t b_k_offset_elements =
|
||||
amd_wave_read_first_lane(k_id * KRead / BPackedSize);
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
@@ -395,11 +404,11 @@ struct QuantGemmKernel
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B);
|
||||
b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements * kargs.stride_B);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
|
||||
b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements);
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
@@ -410,10 +419,47 @@ struct QuantGemmKernel
|
||||
{
|
||||
splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
|
||||
}
|
||||
|
||||
// Compute BQ offset for BQuantGrouped mode (non-preshuffle only)
|
||||
// Note: With the alignment validation in IsSupportedArgument, KRead is always
|
||||
// a multiple of BQuantGroupSize::kK, so bq_k_split_offset will be correctly aligned.
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped && !BPreshuffleQuant)
|
||||
{
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
// Compute the K offset for this batch (in terms of K elements)
|
||||
const index_t k_offset = amd_wave_read_first_lane(k_id * KRead);
|
||||
// Convert K offset to BQ group offset (logical offset in K/kK dimension)
|
||||
bq_group_offset = amd_wave_read_first_lane(k_offset / BQuantGroupSize::kK);
|
||||
|
||||
// BQ tensor layout:
|
||||
// RowMajor: [K/kK, N/kN] with stride [N/kN, 1]
|
||||
// ColumnMajor: [N/kN, K/kK] with stride [K/kK, 1]
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BQLayout>)
|
||||
{
|
||||
// For RowMajor BQ, K is the row dimension
|
||||
// offset = bq_group_offset * stride_BQ
|
||||
const index_t stride_bq =
|
||||
amd_wave_read_first_lane(integer_divide_ceil(kargs.N, BQuantGroupSize::kN));
|
||||
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset * stride_bq);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BQLayout>)
|
||||
{
|
||||
// For ColumnMajor BQ, K is the column dimension
|
||||
// offset = bq_group_offset
|
||||
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
bq_group_offset = 0;
|
||||
bq_k_split_offset = 0;
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension)
|
||||
index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride)
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
@@ -805,10 +851,13 @@ struct QuantGemmKernel
|
||||
|
||||
CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const index_t bq_group_offset,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for BQ
|
||||
// Note: For split-K, the bq_ptr is already offset by bq_k_split_offset (pointer offset).
|
||||
// The dimension should use the remaining K-groups from this offset position.
|
||||
const auto& bq_tensor_view = [&]() {
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
@@ -850,11 +899,12 @@ struct QuantGemmKernel
|
||||
"ABQuantGrouped requires ColumnMajor BQ layout");
|
||||
}
|
||||
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK),
|
||||
make_tuple(kargs.QK_B - bq_group_offset,
|
||||
integer_divide_ceil(kargs.N, BQuantGroupSize::kN)),
|
||||
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
@@ -865,8 +915,8 @@ struct QuantGemmKernel
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
|
||||
integer_divide_ceil(kargs.K, BQuantGroupSize::kK)),
|
||||
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1),
|
||||
kargs.QK_B - bq_group_offset),
|
||||
make_tuple(kargs.QK_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
@@ -1047,13 +1097,61 @@ struct QuantGemmKernel
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
|
||||
{
|
||||
// Split-K is supported for BQuantGrouped mode without preshuffle
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
constexpr bool is_bquant_non_preshuffle =
|
||||
(kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant;
|
||||
if constexpr(!is_bquant_non_preshuffle)
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 ! "
|
||||
"Split-K only supported for BQuantGrouped without preshuffle.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
// Constraint 1: KRead must align with B packing requirements.
|
||||
// For packed data types, multiple K elements are stored in each storage unit.
|
||||
// Split-K advances the B pointer by (KRead / BPackedSize) storage units per batch.
|
||||
// If KRead is not divisible by BPackedSize, this division produces a fractional
|
||||
// offset, making it impossible to start reading from a valid storage unit boundary.
|
||||
if(KRead % BPackedSize != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("KRead must be a multiple of B packed size for split-K!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Constraint 2: KRead must align with quantization group boundaries.
|
||||
// Each split-K batch reads KRead consecutive K elements. If KRead is not
|
||||
// a multiple of BQuantGroupSize::kK, the batch will span partial quantization
|
||||
// groups, requiring split access to a quantization scale. This violates the
|
||||
// atomic processing requirement where each batch must work with complete groups.
|
||||
if(KRead % BQuantGroupSize::kK != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Split-K batch size must be aligned with quantization group "
|
||||
"size! KRead=" +
|
||||
std::to_string(KRead) +
|
||||
" is not divisible by BQuantGroupSize::kK=" +
|
||||
std::to_string(BQuantGroupSize::kK));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -1215,7 +1313,10 @@ struct QuantGemmKernel
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
// Note: Pass bq_group_offset so the tensor view dimension reflects
|
||||
// the remaining K-groups from the split-K offset position.
|
||||
const auto& bq_block_window = MakeBQBlockWindow(
|
||||
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
@@ -1343,8 +1444,9 @@ struct QuantGemmKernel
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
|
||||
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
const BQDataType* bq_ptr =
|
||||
static_cast<const BQDataType*>(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
@@ -387,8 +387,8 @@ struct QuantGroupedGemmKernel
|
||||
Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& bq_block_window =
|
||||
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window = Base::MakeBQBlockWindow(
|
||||
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
@@ -453,8 +453,8 @@ struct QuantGroupedGemmKernel
|
||||
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& aq_block_window =
|
||||
Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window =
|
||||
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window = Base::MakeBQBlockWindow(
|
||||
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
|
||||
|
||||
// Get hot-loop and tail configuration
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
|
||||
Reference in New Issue
Block a user