mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Integrate Multi D GEMMs into Grouped GEMMs along with unit tests (#2923)
* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * feat: generalized grouped_gemm_kernel.hpp * feat: generalized grouped_gemm_kernel.hpp even further by removing hardcoded 0 * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * docs: add multi d info and trim down outdated content * tests: add unit tests for grouped_gemm_multi_d and minor changes in grouped_gemm related test for compatibility * style: clang format * fix: incorrect validation method and Dtensor layout in test suite
This commit is contained in:
@@ -62,10 +62,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
static const ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>;
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
|
||||
}
|
||||
|
||||
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
|
||||
@@ -436,8 +436,18 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
gemm_descs.push_back({p_a,
|
||||
p_b,
|
||||
{/*ds_ptr*/},
|
||||
p_c,
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_As[i],
|
||||
stride_Bs[i],
|
||||
{/*stride_Ds*/},
|
||||
stride_Cs[i]});
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
@@ -446,7 +456,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
// Generate kernel arguments
|
||||
std::vector<ck_tile::GemmTransKernelArg> kargs;
|
||||
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = gemm_descs[0].k_batch > 1;
|
||||
for(const auto& arg : gemm_descs)
|
||||
@@ -468,7 +478,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
ck_tile::hip_check_error(
|
||||
hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
#if CK_TILE_USE_WMMA
|
||||
|
||||
Reference in New Issue
Block a user