From 802a5e7373f03d281e37758bb8af83cf87b629d8 Mon Sep 17 00:00:00 2001 From: Tianyuan Wu Date: Tue, 26 Aug 2025 03:55:35 +0800 Subject: [PATCH] [CK_TILE] Fix the Wrong Output Generated by Gemm Examples on GFX11/12 (#2713) * Introduce macro CK_TILE_USE_WMMA Signed-off-by: Tianyuan Wu * Make CK_TILE_USE_WMMA global for all examples Signed-off-by: Tianyuan Wu * Remove CK_TILE_USE_WMMA from config.hpp Signed-off-by: Tianyuan Wu --------- Signed-off-by: Tianyuan Wu [ROCm/composable_kernel commit: e9605ed36db7948491d21911267127823351991d] --- CMakeLists.txt | 13 ++----------- example/ck_tile/03_gemm/gemm_basic.cpp | 10 ++++++++++ example/ck_tile/03_gemm/gemm_utils.hpp | 2 ++ example/ck_tile/03_gemm/universal_gemm.cpp | 4 ++++ 4 files changed, 18 insertions(+), 11 deletions(-) mode change 100755 => 100644 example/ck_tile/03_gemm/gemm_utils.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f77a41371f..f148f31d25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -225,6 +225,8 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1 message(STATUS "Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") + add_definitions(-DCK_TILE_USE_WMMA) + set(CK_TILE_USE_WMMA "ON") endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") message(STATUS "Enabling WMMA FP8 gemms on native architectures") @@ -324,23 +326,12 @@ if(USE_BITINT_EXTENSION_INT4) message(STATUS "CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() -if(USE_OPT_GFX11) - add_compile_options(-mcumode) - add_compile_options(-mno-wavefrontsize64) - message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") -endif() - if(ENABLE_ASM_DUMP) add_compile_options(--save-temps) add_compile_options(-Wno-gnu-line-marker) message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") endif() -if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12")) - add_compile_options(-mno-wavefrontsize64) - message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}") -endif() - ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 8cdbe39e86..99c943a7f1 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -26,6 +26,15 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t K_Tile = 64; +#if CK_TILE_USE_WMMA + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = 16; +#else constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; @@ -33,6 +42,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; +#endif using CodegenGemmShape = ck_tile::TileGemmShape, diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp old mode 100755 new mode 100644 index eb0a6de8aa..ed2006d4b9 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -172,6 +172,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr int kBlockPerCu = 2; }; +#if CK_TILE_USE_WMMA template struct GemmConfigComputeV3_WMMA : public GemmConfigBase { @@ -192,6 +193,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase static constexpr int kBlockPerCu = 2; }; +#endif template struct GemmConfigComputeV4 : public GemmConfigBase diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 4e01710b4d..b80d9991d4 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -335,7 +335,11 @@ int main(int argc, char* argv[]) try { +#if CK_TILE_USE_WMMA + return !run_gemm_example(arg_parser); +#else return !run_gemm_example(arg_parser); +#endif } catch(const std::runtime_error& e) {