mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Integrate Multi D GEMMs into Grouped GEMMs along with unit tests (#2923)
* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * feat: generalized grouped_gemm_kernel.hpp * feat: generalized grouped_gemm_kernel.hpp even further by removing hardcoded 0 * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * docs: add multi d info and trim down outdated content * tests: add unit tests for grouped_gemm_multi_d and minor changes in grouped_gemm related test for compatibility * style: clang format * fix: incorrect validation method and Dtensor layout in test suite
This commit is contained in:
@@ -23,10 +23,13 @@ namespace ck_tile {
|
||||
/// arguments object. It contain all necessary information required to build proper kernel
|
||||
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
|
||||
/// stating all required information like M,N,K sizes and respective strides.
|
||||
|
||||
template <index_t NumDTensor = 0>
|
||||
struct GroupedGemmHostArgs
|
||||
{
|
||||
CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
@@ -34,15 +37,18 @@ struct GroupedGemmHostArgs
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
e_ptr(e_ptr_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
stride_A(stride_A_),
|
||||
stride_B(stride_B_),
|
||||
stride_Ds(stride_Ds_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
@@ -50,6 +56,7 @@ struct GroupedGemmHostArgs
|
||||
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
@@ -61,7 +68,7 @@ struct GroupedGemmHostArgs
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
union
|
||||
{
|
||||
index_t stride_E;
|
||||
@@ -71,20 +78,23 @@ struct GroupedGemmHostArgs
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <index_t NumDTensor = 0>
|
||||
struct GemmTransKernelArg
|
||||
{
|
||||
UniversalGemmKernelArgs<> group_karg;
|
||||
UniversalGemmKernelArgs<1, 1, NumDTensor> group_karg;
|
||||
ck_tile::index_t block_start;
|
||||
ck_tile::index_t block_end;
|
||||
|
||||
GemmTransKernelArg() = delete;
|
||||
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
|
||||
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
|
||||
GemmTransKernelArg(UniversalGemmKernelArgs<1, 1, NumDTensor>&& karg,
|
||||
index_t bl_start,
|
||||
index_t bl_end)
|
||||
: group_karg{std::move(karg)}, block_start{bl_start}, block_end{bl_end}
|
||||
{
|
||||
}
|
||||
|
||||
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg)
|
||||
: group_karg{karg}, block_start{0}, block_end{0}
|
||||
GemmTransKernelArg(UniversalGemmKernelArgs<1, 1, NumDTensor>&& karg)
|
||||
: group_karg{std::move(karg)}, block_start{0}, block_end{0}
|
||||
{
|
||||
}
|
||||
};
|
||||
@@ -106,9 +116,12 @@ struct GroupedGemmKernel
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, C/E
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
|
||||
static constexpr index_t NumDTensor_ = DsDataType::size();
|
||||
|
||||
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
||||
static_assert(
|
||||
@@ -140,19 +153,21 @@ struct GroupedGemmKernel
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
|
||||
(UsePersistentKernel ? "Persistent" : "NonPersistent"));
|
||||
(UsePersistentKernel ? "Persistent" : "NonPersistent"),
|
||||
(NumDTensor_ == 2 ? "MultiD" : "NoMultiD"),
|
||||
(GemmPipeline::DoubleSmemBuffer ? "DoubleSmemBuffer" : "SingleSmemBuffer"));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::size_t
|
||||
GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs<>>& gemm_descs) -> std::size_t
|
||||
{
|
||||
return gemm_descs.size() * sizeof(GemmTransKernelArg);
|
||||
return gemm_descs.size() * sizeof(GemmTransKernelArg<NumDTensor_>);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
|
||||
{
|
||||
return group_count * sizeof(GemmTransKernelArg);
|
||||
return group_count * sizeof(GemmTransKernelArg<NumDTensor_>);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto BlockSize() -> dim3
|
||||
@@ -184,7 +199,8 @@ struct GroupedGemmKernel
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
CK_TILE_HOST static auto
|
||||
GridSize(const std::vector<GroupedGemmHostArgs<NumDTensor_>>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
for(const auto& it_desc : gemm_descs)
|
||||
@@ -196,9 +212,10 @@ struct GroupedGemmKernel
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::vector<GemmTransKernelArg>
|
||||
MakeKargs(const std::vector<GroupedGemmHostArgs<NumDTensor_>>& gemm_descs)
|
||||
-> std::vector<GemmTransKernelArg<NumDTensor_>>
|
||||
{
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
std::vector<GemmTransKernelArg<NumDTensor_>> gemm_kernel_args_;
|
||||
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
|
||||
index_t grid_size = 0;
|
||||
gemm_kernel_args_.reserve(group_count);
|
||||
@@ -217,6 +234,7 @@ struct GroupedGemmKernel
|
||||
const index_t stride_a = gemm_descs[i].stride_A;
|
||||
const index_t stride_b = gemm_descs[i].stride_B;
|
||||
const index_t stride_e = gemm_descs[i].stride_E;
|
||||
auto stride_ds = gemm_descs[i].stride_Ds;
|
||||
|
||||
const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
|
||||
|
||||
@@ -225,19 +243,19 @@ struct GroupedGemmKernel
|
||||
|
||||
grid_size += grid_size_grp;
|
||||
|
||||
auto karg =
|
||||
UniversalGemmKernelArgs<>{{type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
|
||||
{type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
|
||||
{/*ds_ptr*/},
|
||||
type_convert<CDataType*>(gemm_descs[i].e_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
{stride_a},
|
||||
{stride_b},
|
||||
{/*stride_ds*/},
|
||||
stride_e,
|
||||
gemm_descs[i].k_batch};
|
||||
auto karg = UniversalGemmKernelArgs<1, 1, NumDTensor_>{
|
||||
{type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
|
||||
{type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
|
||||
{gemm_descs[i].ds_ptr},
|
||||
type_convert<CDataType*>(gemm_descs[i].e_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
{stride_a},
|
||||
{stride_b},
|
||||
stride_ds,
|
||||
stride_e,
|
||||
gemm_descs[i].k_batch};
|
||||
|
||||
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
|
||||
}
|
||||
@@ -245,7 +263,8 @@ struct GroupedGemmKernel
|
||||
return gemm_kernel_args_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const std::vector<GemmTransKernelArg>& kargs)
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const std::vector<GemmTransKernelArg<NumDTensor_>>& kargs)
|
||||
{
|
||||
for(const auto& karg : kargs)
|
||||
{
|
||||
@@ -262,7 +281,7 @@ struct GroupedGemmKernel
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs,
|
||||
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
|
||||
const tuple<index_t, index_t>& block_idx_2d,
|
||||
const index_t block_idx_z) const
|
||||
{
|
||||
@@ -292,8 +311,16 @@ struct GroupedGemmKernel
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n);
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
kargs.ds_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else // SingleSmemBuffer
|
||||
{
|
||||
@@ -306,7 +333,7 @@ struct GroupedGemmKernel
|
||||
{
|
||||
Base::RunGemm({a_ptr},
|
||||
{b_ptr},
|
||||
{/*ds_ptr*/},
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
@@ -340,7 +367,7 @@ struct GroupedGemmKernel
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const UniversalGemmKernelArgs<>& kargs,
|
||||
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
|
||||
const typename Base::SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
@@ -396,9 +423,10 @@ struct GroupedGemmKernel
|
||||
RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const std::array<const void*, NumDTensor_>& ds_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const UniversalGemmKernelArgs<>& kargs,
|
||||
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
|
||||
const typename Base::SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
@@ -406,7 +434,7 @@ struct GroupedGemmKernel
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
@@ -453,7 +481,7 @@ struct GroupedGemmKernel
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr,
|
||||
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg<NumDTensor_>* gemm_desc_ptr,
|
||||
index_t block_id,
|
||||
index_t group_count) const
|
||||
{
|
||||
@@ -485,7 +513,7 @@ struct GroupedGemmKernel
|
||||
index_t group_count) const
|
||||
{
|
||||
const index_t block_id = ck_tile::get_block_1d_id();
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg<NumDTensor_>*>(
|
||||
cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
|
||||
@@ -508,7 +536,7 @@ struct GroupedGemmKernel
|
||||
const index_t group_count) const
|
||||
{
|
||||
const index_t grid_size = ck_tile::get_grid_size();
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg<NumDTensor_>*>(
|
||||
cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
|
||||
index_t cum_grid_size = 0;
|
||||
|
||||
Reference in New Issue
Block a user