Support multiple D in quant gemm kernel

A new kernel, QuantGemmMultiDKernel, is added; the existing
QuantGemmKernel behaves as before.
This commit is contained in:
Anton Gorenko
2026-06-16 11:42:27 +05:00
parent 335f80033b
commit 8428732dc2
2 changed files with 318 additions and 108 deletions

View File

@@ -107,32 +107,16 @@ struct is_preshuffleB_enabled<T, std::void_t<decltype(T::PreshuffleB)>>
};
} // namespace detail
struct QuantGemmProblem
template <index_t NumDTensor>
struct QuantGemmMultiDHostArgs
{
CK_TILE_HOST QuantGemmProblem() = default;
CK_TILE_HOST QuantGemmProblem(index_t M_,
index_t N_,
index_t K_,
index_t QK_A_,
index_t QK_B_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_AQ_,
index_t stride_BQ_)
: M(M_),
N(N_),
K(K_),
QK_A(QK_A_),
QK_B(QK_B_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_C(stride_C_),
stride_AQ(stride_AQ_),
stride_BQ(stride_BQ_)
{
}
const void* a_ptr;
const void* b_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
void* c_ptr;
const void* aq_ptr;
const void* bq_ptr;
index_t k_batch;
index_t M;
index_t N;
index_t K;
@@ -140,56 +124,20 @@ struct QuantGemmProblem
index_t QK_B;
index_t stride_A;
index_t stride_B;
const std::array<index_t, NumDTensor> stride_Ds;
index_t stride_C;
index_t stride_AQ;
index_t stride_BQ;
};
struct QuantGemmHostArgs : public QuantGemmProblem
{
CK_TILE_HOST QuantGemmHostArgs() = default;
CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
const void* aq_ptr_,
const void* bq_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t QK_A_,
index_t QK_B_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
index_t stride_AQ_,
index_t stride_BQ_)
: QuantGemmProblem(
M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
a_ptr(a_ptr_),
b_ptr(b_ptr_),
aq_ptr(aq_ptr_),
bq_ptr(bq_ptr_),
c_ptr(c_ptr_),
k_batch(k_batch_)
{
}
const void* a_ptr = nullptr;
const void* b_ptr = nullptr;
const void* aq_ptr = nullptr;
const void* bq_ptr = nullptr;
void* c_ptr = nullptr;
// k_batch must be a positive integer; defaults to 1 (no split-K).
index_t k_batch = 1;
};
struct QuantGemmKernelArgs
template <index_t NumDTensor>
struct QuantGemmMultiDKernelArgs
{
const void* a_ptr;
const void* b_ptr;
const void* aq_ptr;
const void* bq_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
void* c_ptr;
index_t M;
index_t N;
@@ -198,6 +146,7 @@ struct QuantGemmKernelArgs
index_t QK_B;
index_t stride_A;
index_t stride_B;
const std::array<index_t, NumDTensor> stride_Ds;
index_t stride_C;
index_t stride_AQ;
index_t stride_BQ;
@@ -233,7 +182,7 @@ template <typename TilePartitioner_,
typename EpiloguePipeline_,
QuantType QuantType_,
bool RuntimeSplitKTail_ = false>
struct QuantGemmKernel
struct QuantGemmMultiDKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
@@ -241,6 +190,7 @@ struct QuantGemmKernel
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using AQLayout = remove_cvref_t<
typename detail::get_aq_layout_or<GemmPipeline, typename GemmPipeline::ALayout>::type>;
@@ -257,6 +207,7 @@ struct QuantGemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
using AQDataType =
@@ -264,15 +215,23 @@ struct QuantGemmKernel
using BQDataType =
remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
static constexpr auto I0 = number<0>(); // A Tensor
static constexpr auto I1 = number<1>(); // AQ Tensor
static constexpr auto I2 = number<2>(); // B Tensor
static constexpr auto I3 = number<3>(); // BQ Tensor
static constexpr auto I4 = number<4>(); // C Tensor
static_assert(is_detected<is_tuple, DsLayout>::value &&
is_detected<is_tuple, DsDataType>::value &&
DsLayout::size() == DsDataType::size(),
"DsLayout and DsDataType must be tuples and must have the same size.");
static constexpr index_t NumDTensor = DsDataType::size();
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto kQuantType = QuantType_;
static constexpr bool RuntimeSplitKTail = RuntimeSplitKTail_;
using HostArgs = QuantGemmMultiDHostArgs<NumDTensor>;
using KernelArgs = QuantGemmMultiDKernelArgs<NumDTensor>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
@@ -290,25 +249,26 @@ struct QuantGemmKernel
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
CK_TILE_HOST static constexpr QuantGemmKernelArgs
MakeKernelArgs(const QuantGemmHostArgs& hostArgs)
CK_TILE_HOST static constexpr KernelArgs MakeKernelArgs(const HostArgs& hostArgs)
{
return QuantGemmKernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.aq_ptr,
hostArgs.bq_ptr,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.QK_A,
hostArgs.QK_B,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_C,
hostArgs.stride_AQ,
hostArgs.stride_BQ,
hostArgs.k_batch};
return KernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.aq_ptr,
hostArgs.bq_ptr,
hostArgs.ds_ptr,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.QK_A,
hostArgs.QK_B,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_Ds,
hostArgs.stride_C,
hostArgs.stride_AQ,
hostArgs.stride_BQ,
hostArgs.k_batch};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -408,8 +368,8 @@ struct QuantGemmKernel
public:
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
CK_TILE_DEVICE SplitKBatchOffset(const KernelArgs& kargs,
const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 =
GemmPipeline::BlockGemmShape::WarpTile::at(I2); // smallest unit of K work per block
@@ -543,7 +503,7 @@ struct QuantGemmKernel
};
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
const QuantGemmKernelArgs& kargs,
const KernelArgs& kargs,
const index_t k_size,
const index_t i_m)
{
@@ -609,7 +569,7 @@ struct QuantGemmKernel
}
CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr,
const QuantGemmKernelArgs& kargs,
const KernelArgs& kargs,
const index_t i_m,
const index_t i_n,
const index_t aq_group_offset = 0)
@@ -796,7 +756,7 @@ struct QuantGemmKernel
}
CK_TILE_DEVICE static auto MakeBBlockWindow(const BDataType* b_ptr,
const QuantGemmKernelArgs& kargs,
const KernelArgs& kargs,
const index_t k_size,
const index_t i_n)
{
@@ -935,7 +895,7 @@ struct QuantGemmKernel
}
CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr,
const QuantGemmKernelArgs& kargs,
const KernelArgs& kargs,
const index_t bq_group_offset,
const index_t i_m,
const index_t i_n)
@@ -1125,9 +1085,101 @@ struct QuantGemmKernel
return bq_block_window;
}
/// Builds the naive 2-D tensor descriptor for a single D tensor.
///
/// With a RowMajor DLayout the lengths are (M, N); with any other layout they are (N, M).
/// In both cases the strides are (stride, 1), and number<VectorSizeD> / number<1> are
/// forwarded as the trailing arguments of make_naive_tensor_descriptor.
template <typename DLayout, index_t VectorSizeD>
CK_TILE_DEVICE static auto
MakeDTensorDescriptor(const index_t M, const index_t N, const index_t stride)
{
    constexpr bool kRowMajor = std::is_same_v<DLayout, tensor_layout::gemm::RowMajor>;

    // The layout only decides which problem dimension is listed first.
    const index_t outer_len = kRowMajor ? M : N;
    const index_t inner_len = kRowMajor ? N : M;

    return make_naive_tensor_descriptor(make_tuple(outer_len, inner_len),
                                        make_tuple(stride, 1),
                                        number<VectorSizeD>{},
                                        number<1>{});
}
// Builds one block-level tile window per D tensor and returns them as a tuple.
//
// ds_ptr  : untyped base pointers of the D tensors; entry i is cast to the i-th type of
//           DsDataType.
// ds_desc : indexable collection of tensor descriptors, one per D tensor (ds_desc[i]).
// i_m/i_n : window origin coordinates; their order in the origin depends on the layout
//           of each D tensor (see step 3).
template <typename DsTensorDesc>
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
const DsTensorDesc& ds_desc,
const index_t i_m,
const index_t i_n)
{
// Step 1: wrap every D pointer in a typed tensor view over the global address space,
// using the caller-provided descriptor for that tensor.
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
return make_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]), ds_desc[i]);
},
number<NumDTensor>{});
// Step 2: pad each view to the block tile. RowMajor tensors use (MPerBlock, NPerBlock)
// with flags <false, kPadN>; all other layouts use (NPerBlock, MPerBlock) with flags
// <false, kPadM>. The first flag is always false.
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(ds_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// Step 3: create the tile window on each padded view. The window extent and origin
// follow the same layout rule as step 2: (MPerBlock, NPerBlock) at {i_m, i_n} for
// RowMajor, otherwise (NPerBlock, MPerBlock) at {i_n, i_m}.
const auto& ds_block_window = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{i_n, i_m});
}
},
number<NumDTensor>{});
// Returned by value (the function's return type is deduced with plain `auto`).
return ds_block_window;
}
/// Convenience overload: derives one tensor descriptor per D tensor from the kernel
/// arguments (M, N and the matching stride_Ds entry) and forwards to the
/// descriptor-based MakeDBlockWindows overload.
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
                                             const KernelArgs& kargs,
                                             const index_t i_m,
                                             const index_t i_n)
{
    // Descriptor factory for the D tensor at position d_idx: its layout comes from
    // DsLayout and its vector size from the epilogue pipeline.
    const auto make_desc = [&](auto d_idx) {
        using Layout = remove_cvref_t<std::tuple_element_t<d_idx.value, DsLayout>>;
        return MakeDTensorDescriptor<Layout, EpiloguePipeline::GetVectorSizeD(d_idx)>(
            kargs.M, kargs.N, kargs.stride_Ds[d_idx]);
    };

    const auto& d_descs = generate_tuple(make_desc, number<NumDTensor>{});

    return MakeDBlockWindows(ds_ptr, d_descs, i_m, i_n);
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeCBlockWindow(CDataType* c_ptr,
const QuantGemmKernelArgs& kargs,
const KernelArgs& kargs,
const index_t i_m,
const index_t i_n)
{
@@ -1180,7 +1232,7 @@ struct QuantGemmKernel
return c_block_window;
}
CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
{
// k_batch must be a positive integer.
if(kargs.k_batch <= 0)
@@ -1438,6 +1490,59 @@ struct QuantGemmKernel
}
}
// Validate every D tensor against the problem shape. All D tensors are visited (no
// early exit), so each violated constraint is reported when CK_TILE_LOGGING is enabled.
bool DTensorIsValid = true;
static_for<0, NumDTensor, 1>{}([&](auto index) {
    using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;

    // Every D tensor must share the layout of C. Report the mismatch like the other
    // rejections below instead of failing silently.
    if constexpr(!std::is_same_v<DiLayout, CLayout>)
    {
        if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
        {
            CK_TILE_ERROR("Layout of tensor D must match the layout of tensor C!");
        }
        DTensorIsValid = false;
    }

    if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
    {
        // RowMajor D: the constraints apply to N.
        if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
                              "NPerBlock without padding!");
            }
            DTensorIsValid = false;
        }
        if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
            }
            DTensorIsValid = false;
        }
    }
    else
    {
        // Any other D layout: the constraints apply to M.
        if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
                              "MPerBlock without padding!");
            }
            DTensorIsValid = false;
        }
        if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
        {
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
            }
            DTensorIsValid = false;
        }
    }
});
if(!DTensorIsValid)
{
    return false;
}
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
// For RowMajor C, M is the row dimension - check M alignment here because
@@ -1585,9 +1690,10 @@ struct QuantGemmKernel
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
CDataType* c_ptr,
void* smem_ptr,
const QuantGemmKernelArgs& kargs,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
@@ -1605,6 +1711,7 @@ struct QuantGemmKernel
// the remaining K-groups from the split-K offset position.
const auto& bq_block_window = MakeBQBlockWindow(
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
@@ -1667,13 +1774,13 @@ struct QuantGemmKernel
kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
ds_block_window,
smem_ptr,
aq_block_window,
bq_block_window);
@@ -1683,7 +1790,7 @@ struct QuantGemmKernel
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
c_block_window, c_block_tile, ds_block_window, smem_ptr, aq_scale, bq_scale);
}
}
else
@@ -1695,13 +1802,13 @@ struct QuantGemmKernel
kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::BQuantGrouped)
{
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
EpiloguePipeline{}(c_block_window,
c_block_tile,
c_block_window,
ds_block_window,
smem_ptr,
aq_block_window,
bq_block_window);
@@ -1711,12 +1818,12 @@ struct QuantGemmKernel
const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
EpiloguePipeline{}(
c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale);
c_block_window, c_block_tile, ds_block_window, smem_ptr, aq_scale, bq_scale);
}
}
}
CK_TILE_DEVICE void Run_(const QuantGemmKernelArgs& kargs) const
CK_TILE_DEVICE void Run_(const KernelArgs& kargs) const
{
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
@@ -1739,8 +1846,17 @@ struct QuantGemmKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
RunGemm(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
RunGemm(a_ptr,
b_ptr,
aq_ptr,
bq_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
template <typename T, typename = void>
@@ -1749,7 +1865,7 @@ struct QuantGemmKernel
static constexpr bool kIsAvailableV<T, std::void_t<decltype(T::kIsAvailable)>> =
T::kIsAvailable;
CK_TILE_DEVICE void operator()(const QuantGemmKernelArgs& kargs) const
CK_TILE_DEVICE void operator()(const KernelArgs& kargs) const
{
if constexpr(!kIsAvailableV<GemmPipeline>)
ignore = kargs;
@@ -1758,6 +1874,99 @@ struct QuantGemmKernel
}
};
/// Host-side arguments of the quantized GEMM without any D tensors.
///
/// Thin wrapper over QuantGemmMultiDHostArgs<0> that keeps the positional constructor
/// of the non-multi-D API; the (empty) D pointer and D stride arrays are filled in here.
struct QuantGemmHostArgs : public QuantGemmMultiDHostArgs<0>
{
    /// Default state: null pointers, zero sizes and strides, and k_batch == 1.
    ///
    /// The base aggregate declares no default member initializers, so a defaulted
    /// constructor would leave every field indeterminate. Delegating to the positional
    /// constructor gives a well-defined object and a valid (positive) k_batch.
    CK_TILE_HOST QuantGemmHostArgs()
        : QuantGemmHostArgs(nullptr, // a_ptr
                            nullptr, // b_ptr
                            nullptr, // c_ptr
                            nullptr, // aq_ptr
                            nullptr, // bq_ptr
                            1,       // k_batch
                            0,       // M
                            0,       // N
                            0,       // K
                            0,       // QK_A
                            0,       // QK_B
                            0,       // stride_A
                            0,       // stride_B
                            0,       // stride_C
                            0,       // stride_AQ
                            0)       // stride_BQ
    {
    }

    /// Positional constructor. Note the parameter order (c_ptr_ precedes aq_ptr_ and
    /// bq_ptr_); the values are re-ordered into the base's member order below.
    CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
                                   const void* b_ptr_,
                                   void* c_ptr_,
                                   const void* aq_ptr_,
                                   const void* bq_ptr_,
                                   index_t k_batch_,
                                   index_t M_,
                                   index_t N_,
                                   index_t K_,
                                   index_t QK_A_,
                                   index_t QK_B_,
                                   index_t stride_A_,
                                   index_t stride_B_,
                                   index_t stride_C_,
                                   index_t stride_AQ_,
                                   index_t stride_BQ_)
        : QuantGemmMultiDHostArgs<0>{a_ptr_,
                                     b_ptr_,
                                     std::array<const void*, 0>{}, // no D tensors
                                     c_ptr_,
                                     aq_ptr_,
                                     bq_ptr_,
                                     k_batch_,
                                     M_,
                                     N_,
                                     K_,
                                     QK_A_,
                                     QK_B_,
                                     stride_A_,
                                     stride_B_,
                                     std::array<index_t, 0>{}, // no D strides
                                     stride_C_,
                                     stride_AQ_,
                                     stride_BQ_}
    {
    }
};
/// Kernel arguments of the quantized GEMM without any D tensors.
///
/// Derives from QuantGemmMultiDKernelArgs<0> and offers a positional constructor that
/// omits the D pointer and D stride arrays; both are passed to the base as empty arrays.
struct QuantGemmKernelArgs : public QuantGemmMultiDKernelArgs<0>
{
    CK_TILE_HOST QuantGemmKernelArgs() = default;

    /// Forwards all values to the base aggregate in its member order.
    CK_TILE_HOST QuantGemmKernelArgs(const void* a_ptr_,
                                     const void* b_ptr_,
                                     const void* aq_ptr_,
                                     const void* bq_ptr_,
                                     void* c_ptr_,
                                     index_t M_,
                                     index_t N_,
                                     index_t K_,
                                     index_t QK_A_,
                                     index_t QK_B_,
                                     index_t stride_A_,
                                     index_t stride_B_,
                                     index_t stride_C_,
                                     index_t stride_AQ_,
                                     index_t stride_BQ_,
                                     index_t k_batch_)
        : QuantGemmMultiDKernelArgs<0>{
              // input and scale pointers
              a_ptr_, b_ptr_, aq_ptr_, bq_ptr_,
              // no D tensors, then the output pointer
              std::array<const void*, 0>{}, c_ptr_,
              // problem sizes
              M_, N_, K_, QK_A_, QK_B_,
              // input strides
              stride_A_, stride_B_,
              // no D strides
              std::array<index_t, 0>{},
              // output / scale strides and split-K batch count
              stride_C_, stride_AQ_, stride_BQ_, k_batch_}
    {
    }
};
// Alias kept so that existing code naming QuantGemmKernel keeps compiling: it forwards
// all five template parameters unchanged to QuantGemmMultiDKernel.
template <typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_,
QuantType QuantType_,
bool RuntimeSplitKTail_ = false>
using QuantGemmKernel = QuantGemmMultiDKernel<TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_,
QuantType_,
RuntimeSplitKTail_>;
} // namespace ck_tile
#if __clang_major__ >= 23
#pragma clang diagnostic pop

View File

@@ -366,6 +366,7 @@ struct QuantGroupedGemmKernel
{b_ptr},
aq_ptr,
bq_ptr,
{},
c_ptr,
smem_ptr,
kargs,