From 6d6402d7be1730269ffaa1877b32a5a2869aa134 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Fri, 1 Aug 2025 09:32:24 -0700 Subject: [PATCH] Fix the GFX 950 Universal GEMM (#2597) * solve the gfx950 error * clang format * fix a typo error --------- Co-authored-by: ThomasNing [ROCm/composable_kernel commit: 7c44a763fa9719ba1b18d3b6a37b6138c78d97fd] --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 24 ++++++++++++------- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 16 +++++++++---- .../test_gemm_pipeline_universal_run_test.inc | 2 ++ 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 47b91ccbf7..fb191d565d 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -265,17 +265,25 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, 2>>; -using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl>>; +template +using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma, + AttrNumAccess>>; -using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl>>; +template +using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma, + AttrNumAccess>>; -using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl>>; +template +using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma, + AttrNumAccess>>; -using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl>>; +template +using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma, + AttrNumAccess>>; template using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 4e5d102e35..e91d505c8e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -92,10 +92,10 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; @@ -110,6 +110,14 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { + using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; // int8 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; }; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc index 7d89dda684..a22ecf2486 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc @@ -375,6 +375,8 @@ int run_gemm_combinations(std::string const& data_type) { is_success = run_gemm_test(ARG_COUNT, argv) && is_success; + is_success = + run_gemm_test(ARG_COUNT, argv) && is_success; } catch(const ArgumentsNotSupportedException& e) {