feat: add split_k support for block scale gemm bquant mode. (#3653)

* WIP: add splitk to bquant

* feat: add support for bf8i4 and fp8i4 by calculating correct stride for packed data types

* chore: remove temporary test script

* fix: incorrect tile window length for split bq tensor window

* chore: improve comments

* test: add unit tests to cover bquant splitk functionality

* fix: conflict resolution by renaming variables
Author: Aviral Goel
Date: 2026-02-03 02:41:53 +04:00 (committed by GitHub)
Parent: 301eb5cf08
Commit: 3e77721755
11 changed files with 273 additions and 208 deletions

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
-using GemmConfig = GemmConfigQuantPrefill<T>;
+using GemmConfig = GemmConfigQuantDecode<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
-using GemmConfig = GemmConfigQuantPrefill<T>;
+using GemmConfig = GemmConfigQuantDecode<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
-using GemmConfig = GemmConfigQuantPrefill<T>;
+using GemmConfig = GemmConfigQuantDecode<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
-using GemmConfig = GemmConfigQuantPrefill<T>;
+using GemmConfig = GemmConfigQuantDecode<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \

View File

@@ -215,11 +215,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
-if(args.k_batch != 1)
-{
-throw std::runtime_error("split-k is not supported yet!");
-}
+// Split-K validation is handled by Kernel::IsSupportedArgument
+// Split-K is only supported for BQuantGrouped without preshuffle
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
@@ -661,182 +658,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
}
}
-else if(init_method == 3)
-{
-if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
-{
-ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
-ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
-ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
-}
-else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
-{
-ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
-ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
-ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
-ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
-}
-else
-{
-ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
-ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(2.0f)}(*aq_tensor_ptr);
-ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
-if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
-{
-ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
-}
-}
-}
-else if(init_method == 4)
-{
-if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
-{
-if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
-{
-ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
-b_k_n);
-ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*bq_tensor_ptr);
-}
-else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
-{
-ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
-ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
-*bq_tensor_ptr);
-}
-else
-{
-ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
-ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*bq_tensor_ptr);
-}
-ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
-}
-else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
-{
-if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
-{
-ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
-a_m_k);
-}
-else
-{
-ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
-}
-ck_tile::FillUniformDistribution<AQDataType>{2.0f, 2.0f, fill_seed(gen)}(
-*aq_tensor_ptr);
-ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
-}
-else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
-{
-if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
-std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
-{
-ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
-ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
-}
-else
-{
-ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
-ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
-}
-ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*aq_tensor_ptr);
-ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*bq_tensor_ptr);
-}
-else
-{
-ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
-ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
-ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*aq_tensor_ptr);
-ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*bq_tensor_ptr);
-}
-}
-else if(init_method == 5)
-{
-if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
-{
-if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
-{
-ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
-b_k_n);
-ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*bq_tensor_ptr);
-}
-else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
-{
-ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
-ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
-*bq_tensor_ptr);
-}
-else
-{
-ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
-ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*bq_tensor_ptr);
-}
-ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
-}
-else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
-{
-if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
-{
-ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
-a_m_k);
-}
-else
-{
-ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f, fill_seed(gen)}(a_m_k);
-}
-// Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...)
-for(ck_tile::index_t row = 0;
-row < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(0));
-++row)
-{
-for(ck_tile::index_t col = 0;
-col < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(1));
-++col)
-{
-(*aq_tensor_ptr)(row, col) = static_cast<AQDataType>(col + 1);
-}
-}
-// std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl;
-ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f, fill_seed(gen)}(b_k_n);
-}
-else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
-{
-if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
-std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
-{
-ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
-ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
-}
-else
-{
-ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
-ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
-}
-ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*aq_tensor_ptr);
-ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*bq_tensor_ptr);
-}
-else
-{
-ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
-ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
-ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*aq_tensor_ptr);
-ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
-*bq_tensor_ptr);
-}
-}
else
{
a_m_k.SetZero();

View File

@@ -380,9 +380,18 @@ struct QuantGemmKernel
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
-constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
-const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
-const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
+constexpr auto K1 =
+GemmPipeline::BlockGemmShape::WarpTile::at(I2); // smallest unit of K work per block
+const index_t K_t = amd_wave_read_first_lane(
+kargs.k_batch * K1); // amount of K elements consumed if every split-K batch
+// performs exactly one "unit" (K1)
+const index_t KRead = amd_wave_read_first_lane(
+(kargs.K + K_t - 1) / K_t * K1); // total K elements to be read in this batch
+// offset not necessarily = KRead, because B can have packed elements (e.g. fp8i4)
+constexpr index_t BPackedSize =
+ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
+const index_t b_k_offset_elements =
+amd_wave_read_first_lane(k_id * KRead / BPackedSize);
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
@@ -395,11 +404,11 @@ struct QuantGemmKernel
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
-b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B);
+b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements * kargs.stride_B);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
-b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
+b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements);
}
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
@@ -410,10 +419,47 @@ struct QuantGemmKernel
{
splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
}
+// Compute BQ offset for BQuantGrouped mode (non-preshuffle only)
+// Note: With the alignment validation in IsSupportedArgument, KRead is always
+// a multiple of BQuantGroupSize::kK, so bq_k_split_offset will be correctly aligned.
+if constexpr(kQuantType == QuantType::BQuantGrouped && !BPreshuffleQuant)
+{
+using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
+// Compute the K offset for this batch (in terms of K elements)
+const index_t k_offset = amd_wave_read_first_lane(k_id * KRead);
+// Convert K offset to BQ group offset (logical offset in K/kK dimension)
+bq_group_offset = amd_wave_read_first_lane(k_offset / BQuantGroupSize::kK);
+// BQ tensor layout:
+// RowMajor: [K/kK, N/kN] with stride [N/kN, 1]
+// ColumnMajor: [N/kN, K/kK] with stride [K/kK, 1]
+if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BQLayout>)
+{
+// For RowMajor BQ, K is the row dimension
+// offset = bq_group_offset * stride_BQ
+const index_t stride_bq =
+amd_wave_read_first_lane(integer_divide_ceil(kargs.N, BQuantGroupSize::kN));
+bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset * stride_bq);
+}
+else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BQLayout>)
+{
+// For ColumnMajor BQ, K is the column dimension
+// offset = bq_group_offset
+bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset);
+}
+}
+else
+{
+bq_group_offset = 0;
+bq_k_split_offset = 0;
+}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
+index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension)
+index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride)
index_t splitted_k;
};
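
To make the offset arithmetic above concrete, here is a minimal standalone host-side sketch of the same computation. The concrete values (K1 = 32 for the WarpTile K unit, BPackedSize = 2 as for pk_int4, BQuantGroupSize::kK = 128) are illustrative assumptions, not taken from any particular pipeline configuration:

// Host-side sketch of the SplitKBatchOffset arithmetic (illustrative values).
#include <cstdio>

int main()
{
    const int K = 1024, k_batch = 2;
    const int K1 = 32;          // assumed WarpTile K (smallest unit of K work)
    const int BPackedSize = 2;  // e.g. two int4 values per storage unit
    const int group_kK = 128;   // assumed BQuantGroupSize::kK

    const int K_t   = k_batch * K1;              // K consumed if every batch did one K1 unit
    const int KRead = (K + K_t - 1) / K_t * K1;  // K elements each split-K batch reads: 512

    for(int k_id = 0; k_id < k_batch; ++k_id)
    {
        // B advances in storage units, hence the division by BPackedSize.
        const int b_k_offset_elements = k_id * KRead / BPackedSize; // 0, then 256
        // BQ advances by whole quantization groups along K.
        const int bq_group_offset = k_id * KRead / group_kK;        // 0, then 4
        const int splitted_k =
            (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
        std::printf("k_id=%d b_off=%d bq_groups=%d splitted_k=%d\n",
                    k_id, b_k_offset_elements, bq_group_offset, splitted_k);
    }
    return 0;
}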
@@ -805,10 +851,13 @@ struct QuantGemmKernel
CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr,
const QuantGemmKernelArgs& kargs,
+const index_t bq_group_offset,
const index_t i_m,
const index_t i_n)
{
// Step 1: Create tensor view for BQ
+// Note: For split-K, the bq_ptr is already offset by bq_k_split_offset (pointer offset).
+// The dimension should use the remaining K-groups from this offset position.
const auto& bq_tensor_view = [&]() {
if constexpr(kQuantType == QuantType::RowColQuant)
{
@@ -850,11 +899,12 @@ struct QuantGemmKernel
"ABQuantGrouped requires ColumnMajor BQ layout");
}
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
-make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK),
+make_tuple(kargs.QK_B - bq_group_offset,
integer_divide_ceil(kargs.N, BQuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
@@ -865,8 +915,8 @@ struct QuantGemmKernel
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
-integer_divide_ceil(kargs.K, BQuantGroupSize::kK)),
-make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1),
+kargs.QK_B - bq_group_offset),
+make_tuple(kargs.QK_B, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
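
A quick sanity check on the view extents above, with illustrative numbers: for K = 2048 and BQuantGroupSize::kK = 128 the BQ tensor holds QK_B = 16 K-groups; with k_batch = 2 each batch reads KRead = 1024 K elements, so the second batch starts at bq_group_offset = 1024 / 128 = 8 and its view covers the remaining QK_B - bq_group_offset = 8 groups. The strides (N-groups per row for RowMajor, QK_B for ColumnMajor) stay those of the full tensor, since only the window shrinks, not the underlying memory layout.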
@@ -1047,13 +1097,61 @@ struct QuantGemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
{
+// Split-K is supported for BQuantGrouped mode without preshuffle
if(kargs.k_batch != 1)
{
-if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+constexpr bool is_bquant_non_preshuffle =
+(kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant;
+if constexpr(!is_bquant_non_preshuffle)
{
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met for Kbatch >1 ! "
"Split-K only supported for BQuantGrouped without preshuffle.");
}
return false;
}
+else
+{
+using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
+constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
+const index_t K_t = kargs.k_batch * K1;
+const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
+constexpr index_t BPackedSize =
+ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
+// Constraint 1: KRead must align with B packing requirements.
+// For packed data types, multiple K elements are stored in each storage unit.
+// Split-K advances the B pointer by (KRead / BPackedSize) storage units per batch.
+// If KRead is not divisible by BPackedSize, this division produces a fractional
+// offset, making it impossible to start reading from a valid storage unit boundary.
+if(KRead % BPackedSize != 0)
+{
+if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+{
+CK_TILE_ERROR("KRead must be a multiple of B packed size for split-K!");
+}
+return false;
+}
+// Constraint 2: KRead must align with quantization group boundaries.
+// Each split-K batch reads KRead consecutive K elements. If KRead is not
+// a multiple of BQuantGroupSize::kK, the batch will span partial quantization
+// groups, requiring split access to a quantization scale. This violates the
+// atomic processing requirement where each batch must work with complete groups.
+if(KRead % BQuantGroupSize::kK != 0)
+{
+if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+{
+CK_TILE_ERROR("Split-K batch size must be aligned with quantization group "
+"size! KRead=" +
+std::to_string(KRead) +
+" is not divisible by BQuantGroupSize::kK=" +
+std::to_string(BQuantGroupSize::kK));
+}
+return false;
+}
+}
-return false;
}
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
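
The two constraints combine into a simple host-side admissibility test. A minimal sketch of the same checks, with assumed values for K1, BPackedSize, and the group size (for illustration only):

// Sketch of the split-K admissibility checks above (illustrative values).
#include <cstdio>

static bool splitk_supported(int K, int k_batch, int K1, int BPackedSize, int group_kK)
{
    const int K_t   = k_batch * K1;
    const int KRead = (K + K_t - 1) / K_t * K1;
    if(KRead % BPackedSize != 0) return false; // constraint 1: B packing alignment
    if(KRead % group_kK != 0)    return false; // constraint 2: whole quant groups per batch
    return true;
}

int main()
{
    // K=1024, k_batch=2 -> KRead=512: divisible by 2 and by 128 -> supported.
    std::printf("%d\n", splitk_supported(1024, 2, 32, 2, 128)); // prints 1
    // K=1152, k_batch=2 -> KRead=576: 576 % 128 != 0 -> a batch would split a group.
    std::printf("%d\n", splitk_supported(1152, 2, 32, 2, 128)); // prints 0
    return 0;
}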
@@ -1215,7 +1313,10 @@ struct QuantGemmKernel
const auto& b_block_window =
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
-const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
+// Note: Pass bq_group_offset so the tensor view dimension reflects
+// the remaining K-groups from the split-K offset position.
+const auto& bq_block_window = MakeBQBlockWindow(
+bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
@@ -1343,8 +1444,9 @@ struct QuantGemmKernel
const BDataType* b_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
-const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
-CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
+const BQDataType* bq_ptr =
+static_cast<const BQDataType*>(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset;
+CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];

View File

@@ -387,8 +387,8 @@ struct QuantGroupedGemmKernel
Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
-const auto& bq_block_window =
-Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
+const auto& bq_block_window = Base::MakeBQBlockWindow(
+bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
@@ -453,8 +453,8 @@ struct QuantGroupedGemmKernel
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& aq_block_window =
Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
-const auto& bq_block_window =
-Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
+const auto& bq_block_window = Base::MakeBQBlockWindow(
+bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
// Get hot-loop and tail configuration
const index_t num_loop = __builtin_amdgcn_readfirstlane(

View File

@@ -128,6 +128,17 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
)
target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
+# BQuant split-K tests (no preshuffle)
+add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode
+test_gemm_quant_bquant_splitk_decode.cpp
+)
+target_compile_options(test_tile_gemm_quant_bquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
+add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill
+test_gemm_quant_bquant_splitk_prefill.cpp
+)
+target_compile_options(test_tile_gemm_quant_bquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
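
Assuming a standard CMake/ctest setup for these targets, both new suites can be run in isolation with a name filter, e.g. ctest -R test_tile_gemm_quant_bquant_splitk, which matches the decode and prefill executables registered above.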
# BQuant tests (with PreshuffleB) - split into 5 files
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d
test_gemm_quant_bquant_preshuffle_decode_1d.cpp

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for BQuant split-K tests - Decode shape, GroupSize 128
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuantSplitKDecodeTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>
>;
// clang-format on
// Test suite for BQuant split-K Decode
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKDecodeTypes);
// BQuant split-K tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test)
{
// K=1024 for split_k=2: 1024/2=512=4×128 ✓
this->run_test_with_validation(32, 128, 1024, 2);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test)
{
// K=3072 for split_k=3: 3072/3=1024=8×128 ✓
this->run_test_with_validation(32, 128, 3072, 3);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test)
{
// K=2048 for split_k=4: 2048/4=512=4×128 ✓
this->run_test_with_validation(32, 128, 2048, 4);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test)
{
// K=2560 for split_k=5: 2560/5=512=4×128 ✓
// Also K must be divisible by K_Tile(256)*split_k(5)=1280
this->run_test_with_validation(32, 128, 2560, 5);
}
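
The K values above are chosen so that each batch's share of K (KRead) is a whole multiple of the 128-wide quantization group, which is exactly constraint 2 enforced by IsSupportedArgument; a shape that split a group across batches would be rejected before launch rather than silently computing with partial scales.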

View File

@@ -0,0 +1,64 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// Type combinations for BQuant split-K tests - Prefill shape, GroupSize 128
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, QuantGroupSize>
// clang-format off
using BQuantSplitKPrefillTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>
>;
// clang-format on
// Test suite for BQuant split-K Prefill
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKPrefillTypes);
// BQuant split-K tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test)
{
// K=1024 for split_k=2: 1024/2=512=4×128 ✓
// K must be divisible by K_Tile(128)*split_k(2)=256
this->run_test_with_validation(128, 128, 1024, 2);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test)
{
// K=3072 for split_k=3: 3072/3=1024=8×128 ✓
// K must be divisible by K_Tile(128)*split_k(3)=384
this->run_test_with_validation(128, 128, 3072, 3);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test)
{
// K=2048 for split_k=4: 2048/4=512=4×128 ✓
// K must be divisible by K_Tile(128)*split_k(4)=512
this->run_test_with_validation(128, 128, 2048, 4);
}
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test)
{
// K=1920 for split_k=5: 1920/5=384=3×128 ✓
// K must be divisible by K_Tile(128)*split_k(5)=640
this->run_test_with_validation(128, 128, 1920, 5);
}

View File

@@ -655,7 +655,10 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
-void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
+void run_test_with_validation(ck_tile::index_t M,
+ck_tile::index_t N,
+ck_tile::index_t K,
+ck_tile::index_t k_batch = 1)
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B =
@@ -698,6 +701,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
sizeof(QDataType));
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
+// Zero C buffer - required for split-K atomic_add accumulation
+c_m_n_dev_buf.SetZero();
// Copy to device
a_m_k_dev_buf.ToDevice(a_m_k.data());
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
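
The zeroing above matters because with k_batch > 1 the kernel accumulates each batch's partial product into C via atomic adds, i.e. C = sum over batches b of A_b * B_b taken across the k_batch slices of K; any stale contents of the buffer would be folded into that sum.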
@@ -746,12 +752,12 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
nullptr, // aq_ptr (not used for BQuant)
bq_bqk_bqn_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
-1, // k_batch
+k_batch, // k_batch (split-K)
M,
N,
K, // M, N, K
0, // QK_A (not used for BQuant)
-BQK, // QK_B - TODO: we can remove BQK and BQN from args later?
+BQK, // QK_B
stride_A,
stride_B,
stride_C,
@@ -796,7 +802,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
-K, 1, max_accumulated_value);
+K, k_batch, max_accumulated_value);
// Validate results
bool pass = ck_tile::check_err(c_m_n_dev_result,
@@ -806,7 +812,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
rtol_atol.at(ck_tile::number<1>{}));
EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N
<< ", K=" << K;
<< ", K=" << K << ", k_batch=" << k_batch;
if(!pass)
{