From f9e76244d8cd6f2b200411f5f93d4c21f1b4a3b6 Mon Sep 17 00:00:00 2001 From: joyeamd Date: Tue, 19 Aug 2025 16:20:43 +0800 Subject: [PATCH] fix grouped gemm example when wave32 enabled (#2707) 1, delete some unused variables 2, fix BlockSize when wave32 enabled [ROCm/composable_kernel commit: a1589a9667517ddc73048c05c6f3c859db99851d] --- example/ck_tile/17_grouped_gemm/grouped_gemm.cpp | 7 ------- .../ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp | 12 +++++++++++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 1e6844261f..527ef1e466 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -29,10 +29,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, void* kargs_ptr, bool splitk) { - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; @@ -44,7 +40,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; float ave_time{0}; diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index c35435ee5e..eac7f547c1 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -155,7 +155,17 @@ struct GroupedGemmKernel return group_count * sizeof(GemmTransKernelArg); } - CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); } + CK_TILE_HOST static auto BlockSize() -> dim3 + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } /** * @brief Get the maximum occupancy grid size for the persistent kernel on the current device.