mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
feat: add split_k support for block scale gemm bquant mode. (#3653)
* WIP: add splitk to bquant * feat: add support for bf8i4 and fp8i4 by calculating correct stride for packed data types * chore: remove temporary test script * fix: incorrect tile window length for splitted bq tensor window * chore: improve comments * test: add unit tests to cover bquant splitk functionality * fix: conflict resolution by renaming variables
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
|
||||
@@ -215,11 +215,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
|
||||
// Split-K validation is handled by Kernel::IsSupportedArgument
|
||||
// Split-K is only supported for BQuantGrouped without preshuffle
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
@@ -661,182 +658,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(init_method == 3)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(2.0f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(init_method == 4)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else if(init_method == 5)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
// Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...)
|
||||
for(ck_tile::index_t row = 0;
|
||||
row < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(0));
|
||||
++row)
|
||||
{
|
||||
for(ck_tile::index_t col = 0;
|
||||
col < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(1));
|
||||
++col)
|
||||
{
|
||||
(*aq_tensor_ptr)(row, col) = static_cast<AQDataType>(col + 1);
|
||||
}
|
||||
}
|
||||
// std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl;
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
|
||||
@@ -380,9 +380,18 @@ struct QuantGemmKernel
|
||||
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
|
||||
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
|
||||
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
|
||||
constexpr auto K1 =
|
||||
GemmPipeline::BlockGemmShape::WarpTile::at(I2); // smallest unit of K work per block
|
||||
const index_t K_t = amd_wave_read_first_lane(
|
||||
kargs.k_batch * K1); // amount of K elements consumed if every split-K batch
|
||||
// performs exactly one "unit" (K1)
|
||||
const index_t KRead = amd_wave_read_first_lane(
|
||||
(kargs.K + K_t - 1) / K_t * K1); // total k elements to be read in this batch
|
||||
// offset not necessarily = KRead, because B can have packed elements (e.g. fp8i4)
|
||||
constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
const index_t b_k_offset_elements =
|
||||
amd_wave_read_first_lane(k_id * KRead / BPackedSize);
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
@@ -395,11 +404,11 @@ struct QuantGemmKernel
|
||||
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B);
|
||||
b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements * kargs.stride_B);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = amd_wave_read_first_lane(k_id * KRead);
|
||||
b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements);
|
||||
}
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
@@ -410,10 +419,47 @@ struct QuantGemmKernel
|
||||
{
|
||||
splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
|
||||
}
|
||||
|
||||
// Compute BQ offset for BQuantGrouped mode (non-preshuffle only)
|
||||
// Note: With the alignment validation in IsSupportedArgument, KRead is always
|
||||
// a multiple of BQuantGroupSize::kK, so bq_k_split_offset will be correctly aligned.
|
||||
if constexpr(kQuantType == QuantType::BQuantGrouped && !BPreshuffleQuant)
|
||||
{
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
// Compute the K offset for this batch (in terms of K elements)
|
||||
const index_t k_offset = amd_wave_read_first_lane(k_id * KRead);
|
||||
// Convert K offset to BQ group offset (logical offset in K/kK dimension)
|
||||
bq_group_offset = amd_wave_read_first_lane(k_offset / BQuantGroupSize::kK);
|
||||
|
||||
// BQ tensor layout:
|
||||
// RowMajor: [K/kK, N/kN] with stride [N/kN, 1]
|
||||
// ColumnMajor: [N/kN, K/kK] with stride [K/kK, 1]
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BQLayout>)
|
||||
{
|
||||
// For RowMajor BQ, K is the row dimension
|
||||
// offset = bq_group_offset * stride_BQ
|
||||
const index_t stride_bq =
|
||||
amd_wave_read_first_lane(integer_divide_ceil(kargs.N, BQuantGroupSize::kN));
|
||||
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset * stride_bq);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BQLayout>)
|
||||
{
|
||||
// For ColumnMajor BQ, K is the column dimension
|
||||
// offset = bq_group_offset
|
||||
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
bq_group_offset = 0;
|
||||
bq_k_split_offset = 0;
|
||||
}
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension)
|
||||
index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride)
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
@@ -805,10 +851,13 @@ struct QuantGemmKernel
|
||||
|
||||
CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr,
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const index_t bq_group_offset,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Step 1: Create tensor view for BQ
|
||||
// Note: For split-K, the bq_ptr is already offset by bq_k_split_offset (pointer offset).
|
||||
// The dimension should use the remaining K-groups from this offset position.
|
||||
const auto& bq_tensor_view = [&]() {
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
@@ -850,11 +899,12 @@ struct QuantGemmKernel
|
||||
"ABQuantGrouped requires ColumnMajor BQ layout");
|
||||
}
|
||||
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK),
|
||||
make_tuple(kargs.QK_B - bq_group_offset,
|
||||
integer_divide_ceil(kargs.N, BQuantGroupSize::kN)),
|
||||
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
@@ -865,8 +915,8 @@ struct QuantGemmKernel
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN),
|
||||
integer_divide_ceil(kargs.K, BQuantGroupSize::kK)),
|
||||
make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1),
|
||||
kargs.QK_B - bq_group_offset),
|
||||
make_tuple(kargs.QK_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
@@ -1047,13 +1097,61 @@ struct QuantGemmKernel
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
|
||||
{
|
||||
// Split-K is supported for BQuantGrouped mode without preshuffle
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
constexpr bool is_bquant_non_preshuffle =
|
||||
(kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant;
|
||||
if constexpr(!is_bquant_non_preshuffle)
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met for Kbatch >1 ! "
|
||||
"Split-K only supported for BQuantGrouped without preshuffle.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2);
|
||||
const index_t K_t = kargs.k_batch * K1;
|
||||
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
|
||||
constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
// Constraint 1: KRead must align with B packing requirements.
|
||||
// For packed data types, multiple K elements are stored in each storage unit.
|
||||
// Split-K advances the B pointer by (KRead / BPackedSize) storage units per batch.
|
||||
// If KRead is not divisible by BPackedSize, this division produces a fractional
|
||||
// offset, making it impossible to start reading from a valid storage unit boundary.
|
||||
if(KRead % BPackedSize != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("KRead must be a multiple of B packed size for split-K!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Constraint 2: KRead must align with quantization group boundaries.
|
||||
// Each split-K batch reads KRead consecutive K elements. If KRead is not
|
||||
// a multiple of BQuantGroupSize::kK, the batch will span partial quantization
|
||||
// groups, requiring split access to a quantization scale. This violates the
|
||||
// atomic processing requirement where each batch must work with complete groups.
|
||||
if(KRead % BQuantGroupSize::kK != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Split-K batch size must be aligned with quantization group "
|
||||
"size! KRead=" +
|
||||
std::to_string(KRead) +
|
||||
" is not divisible by BQuantGroupSize::kK=" +
|
||||
std::to_string(BQuantGroupSize::kK));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -1215,7 +1313,10 @@ struct QuantGemmKernel
|
||||
const auto& b_block_window =
|
||||
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
// Note: Pass bq_group_offset so the tensor view dimension reflects
|
||||
// the remaining K-groups from the split-K offset position.
|
||||
const auto& bq_block_window = MakeBQBlockWindow(
|
||||
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
@@ -1343,8 +1444,9 @@ struct QuantGemmKernel
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
|
||||
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
const BQDataType* bq_ptr =
|
||||
static_cast<const BQDataType*>(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
@@ -387,8 +387,8 @@ struct QuantGroupedGemmKernel
|
||||
Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& b_block_window =
|
||||
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& bq_block_window =
|
||||
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window = Base::MakeBQBlockWindow(
|
||||
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
@@ -453,8 +453,8 @@ struct QuantGroupedGemmKernel
|
||||
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& aq_block_window =
|
||||
Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window =
|
||||
Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n);
|
||||
const auto& bq_block_window = Base::MakeBQBlockWindow(
|
||||
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
|
||||
|
||||
// Get hot-loop and tail configuration
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
|
||||
@@ -128,6 +128,17 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant split-K tests (no preshuffle)
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode
|
||||
test_gemm_quant_bquant_splitk_decode.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill
|
||||
test_gemm_quant_bquant_splitk_prefill.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_bquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant tests (with PreshuffleB) - split into 5 files
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d
|
||||
test_gemm_quant_bquant_preshuffle_decode_1d.cpp
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant split-K tests - Decode shape, GroupSize 128
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuantSplitKDecodeTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigDecode, GroupSize128>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant split-K Decode
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKDecodeTypes);
|
||||
|
||||
// BQuant split-K tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test)
|
||||
{
|
||||
// K=1024 for split_k=2: 1024/2=512=4×128 ✓
|
||||
this->run_test_with_validation(32, 128, 1024, 2);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test)
|
||||
{
|
||||
// K=3072 for split_k=3: 3072/3=1024=8×128 ✓
|
||||
this->run_test_with_validation(32, 128, 3072, 3);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test)
|
||||
{
|
||||
// K=2048 for split_k=4: 2048/4=512=4×128 ✓
|
||||
this->run_test_with_validation(32, 128, 2048, 4);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test)
|
||||
{
|
||||
// K=2560 for split_k=5: 2560/5=512=4×128 ✓
|
||||
// Also K must be divisible by K_Tile(256)*split_k(5)=1280
|
||||
this->run_test_with_validation(32, 128, 2560, 5);
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using GroupSize128 = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// Type combinations for BQuant split-K tests - Prefill shape, GroupSize 128
|
||||
// Tuple format: <ALayout, BLayout, CLayout, BQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, QuantGroupSize>
|
||||
// clang-format off
|
||||
using BQuantSplitKPrefillTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPrefill, GroupSize128>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for BQuant split-K Prefill
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKPrefillTypes);
|
||||
|
||||
// BQuant split-K tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test)
|
||||
{
|
||||
// K=1024 for split_k=2: 1024/2=512=4×128 ✓
|
||||
// K must be divisible by K_Tile(128)*split_k(2)=256
|
||||
this->run_test_with_validation(128, 128, 1024, 2);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test)
|
||||
{
|
||||
// K=3072 for split_k=3: 3072/3=1024=8×128 ✓
|
||||
// K must be divisible by K_Tile(128)*split_k(3)=384
|
||||
this->run_test_with_validation(128, 128, 3072, 3);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test)
|
||||
{
|
||||
// K=2048 for split_k=4: 2048/4=512=4×128 ✓
|
||||
// K must be divisible by K_Tile(128)*split_k(4)=512
|
||||
this->run_test_with_validation(128, 128, 2048, 4);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test)
|
||||
{
|
||||
// K=1920 for split_k=5: 1920/5=384=3×128 ✓
|
||||
// K must be divisible by K_Tile(128)*split_k(5)=640
|
||||
this->run_test_with_validation(128, 128, 1920, 5);
|
||||
}
|
||||
@@ -655,7 +655,10 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
void run_test_with_validation(ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t k_batch = 1)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B =
|
||||
@@ -698,6 +701,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
sizeof(QDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Zero C buffer - required for split-K atomic_add accumulation
|
||||
c_m_n_dev_buf.SetZero();
|
||||
|
||||
// Copy to device
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
@@ -746,12 +752,12 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
nullptr, // aq_ptr (not used for BQuant)
|
||||
bq_bqk_bqn_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
|
||||
1, // k_batch
|
||||
k_batch, // k_batch (split-K)
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
0, // QK_A (not used for BQuant)
|
||||
BQK, // QK_B - TODO: we can remove BQK and BQN from args later?
|
||||
BQK, // QK_B
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
@@ -796,7 +802,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
K, k_batch, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
@@ -806,7 +812,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
<< ", K=" << K << ", k_batch=" << k_batch;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user