diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 1e6844261f..527ef1e466 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -29,10 +29,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, void* kargs_ptr, bool splitk) { - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - constexpr ck_tile::index_t TileParitionerGroupNum = 8; constexpr ck_tile::index_t TileParitionerM01 = 4; @@ -44,7 +40,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits; - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; float ave_time{0}; diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index c35435ee5e..eac7f547c1 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -155,7 +155,17 @@ struct GroupedGemmKernel return group_count * sizeof(GemmTransKernelArg); } - CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); } + CK_TILE_HOST static auto BlockSize() -> dim3 + { + if(is_wave32()) + { + return dim3(kBlockSize / 2); + } + else + { + return dim3(kBlockSize); + } + } /** * @brief Get the maximum occupancy grid size for the persistent kernel on the current device.