[CK_TILE] Fix the Wrong Output Generated by Gemm Examples on GFX11/12 (#2713)

* Introduce macro CK_TILE_USE_WMMA

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Make CK_TILE_USE_WMMA global for all examples

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Remove CK_TILE_USE_WMMA from config.hpp

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

---------

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

[ROCm/composable_kernel commit: e9605ed36d]
This commit is contained in:
Tianyuan Wu
2025-08-26 03:55:35 +08:00
committed by GitHub
parent 5bac0f8933
commit 17c71940ca
4 changed files with 18 additions and 11 deletions

View File

@@ -225,6 +225,8 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1
message(STATUS "Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
add_definitions(-DCK_TILE_USE_WMMA)
set(CK_TILE_USE_WMMA "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
@@ -324,23 +326,12 @@ if(USE_BITINT_EXTENSION_INT4)
message(STATUS "CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
endif()
if(USE_OPT_GFX11)
add_compile_options(-mcumode)
add_compile_options(-mno-wavefrontsize64)
message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
endif()
if(ENABLE_ASM_DUMP)
add_compile_options(--save-temps)
add_compile_options(-Wno-gnu-line-marker)
message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}")
endif()
if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12"))
add_compile_options(-mno-wavefrontsize64)
message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}")
endif()
## Threads
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)

View File

@@ -26,6 +26,15 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
#if CK_TILE_USE_WMMA
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
@@ -33,6 +42,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,

2
example/ck_tile/03_gemm/gemm_utils.hpp Executable file → Normal file
View File

@@ -172,6 +172,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
#if CK_TILE_USE_WMMA
template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
@@ -192,6 +193,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
#endif
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase

View File

@@ -335,7 +335,11 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
#else
return !run_gemm_example<GemmConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{