mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Integrate Multi D GEMMs into Grouped GEMMs along with unit tests (#2923)
* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * feat: generalized grouped_gemm_kernel.hpp * feat: generalized grouped_gemm_kernel.hpp even further by removing hardcoded 0 * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * docs: add multi d info and trim down outdated content * tests: add unit tests for grouped_gemm_multi_d and minor changes in grouped_gemm related test for compatibility * style: clang format * fix: incorrect validation method and Dtensor layout in test suite
This commit is contained in:
@@ -116,19 +116,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
template <typename GemmConfig, bool PadM, bool PadN, bool PadK, bool Preshuffle>
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// TODO: This should be parameterized in tests
|
||||
// constexpr ck_tile::index_t M_Tile = 128;
|
||||
// constexpr ck_tile::index_t N_Tile = 128;
|
||||
// constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
// constexpr ck_tile::index_t M_Warp = 1;
|
||||
// constexpr ck_tile::index_t N_Warp = 4;
|
||||
// constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
// constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
// constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
// constexpr ck_tile::index_t K_Warp_Tile = sizeof(ADataType) == 2 ? 16 : 32;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
|
||||
Reference in New Issue
Block a user