From f7c2b4217034b2167a45eddcc2b195d01e4e6810 Mon Sep 17 00:00:00 2001
From: Thrupti Raj Lakshmana Gowda
Date: Thu, 26 Feb 2026 10:14:40 -0600
Subject: [PATCH] Tile Engine support for gfx950 (#4592)

## Motivation

This PR adds support for the gfx950 GPU architecture to the Tile Engine in the Composable Kernel library, focusing on GEMM operations with FP8 and BF8 data types.

## Technical Details

- Added gfx950-specific MFMA warp GEMM implementations, selected via conditional compilation on `__gfx950__`.
- Updated the default GEMM configuration parameters (tile sizes and warp configurations) for the preshuffle benchmark.
- Added a Jenkins CI pipeline stage that builds and runs the TILE_ENGINE_GEMM benchmarks on gfx950 hardware.
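As a quick orientation for reviewers, below is a minimal sketch of how a pipeline reaches the new specializations. It assumes `ck_tile::WarpGemmMfmaDispatcher` and the `ck_tile/ops/gemm.hpp` umbrella header as the usual entry points; it is illustrative only, not code from this diff.

```cpp
// Illustrative sketch (assumed API, not part of this PR): selecting a warp
// GEMM through the dispatcher extended below. The template parameter order
// follows the comment in warp_gemm_dispatcher.hpp: ADataType, BDataType,
// AccDataType, MPerWave, NPerWave, KPerWave, TransposeC (SwizzleA and
// UseStructuredSparsity are trailing parameters that default to false).
#include <ck_tile/ops/gemm.hpp>

// A 32x32x32 FP8 warp tile. With this PR the lookup lands on the new
// WarpGemmMfma_f32_32x32x32_fp8_fp8<> alias, whose body is compiled from the
// gfx950 branch when targeting gfx950 and from the generic branch otherwise.
using WarpGemmFp8 = ck_tile::WarpGemmMfmaDispatcher<ck_tile::fp8_t, // ADataType
                                                    ck_tile::fp8_t, // BDataType
                                                    float,          // AccDataType
                                                    32,             // MPerWave
                                                    32,             // NPerWave
                                                    32,             // KPerWave
                                                    false>;         // TransposeC
```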
## Test Plan

The Tile Engine is itself a benchmarking utility, so it is exercised automatically once the CI stage passes.

## Test Result

Covered by the CI stage described in the Test Plan.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Thrupti Raj Lakshmana Gowda
Co-authored-by: Thomas Ning
---
 Jenkinsfile                                 | 29 +++++++++++
 include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 50 ++++++++++++++-----
 .../ops/gemm/warp/warp_gemm_dispatcher.hpp  | 14 ++++--
 .../configs/user_provided_config.json       | 19 ++++---
 4 files changed, 86 insertions(+), 26 deletions(-)

diff --git a/Jenkinsfile b/Jenkinsfile
index 6c54eaf8d3..4cf4bb1c77 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -1727,6 +1727,35 @@ pipeline {
                 cleanWs()
             }
         }
+        stage("Run TILE_ENGINE_GEMM Tests on gfx950")
+        {
+            when {
+                beforeAgent true
+                expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
+            }
+            agent{ label rocmnode("gfx950") }
+            environment{
+                setup_args = "NO_CK_BUILD"
+                execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
+                                -D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
+                                -D CMAKE_BUILD_TYPE=Release \
+                                -D GPU_TARGETS="gfx950" \
+                                -D GEMM_UNIVERSAL_DATATYPE="fp8;fp16" \
+                                -D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \
+                                -D GEMM_MULTI_D_DATATYPE="fp16" \
+                                -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
+                                -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
+                                -D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \
+                                ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
+                                python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
+                                python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
+                                python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
+            }
+            steps{
+                buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
+                cleanWs()
+            }
+        }
         stage("Run TILE_ENGINE_GEMM Tests on gfx1201")
         {
             when {
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
index adf6a515f2..2f25ae9bf5 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
@@ -293,6 +293,44 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K16>>;
 
 // fp8
+#if defined(__gfx950__)
+template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
+using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
+                                                                                     2,
+                                                                                     AttrNumAccess>>;
+#else
+template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
+using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
+                                                                                     2>>;
+#endif
+
+#if defined(__gfx950__)
+template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
+using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
+                                                                                     2,
+                                                                                     AttrNumAccess>>;
+#else
+template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
+using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
+                                                                                     2>>;
+#endif
+
+#if defined(__gfx950__)
+template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
+using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
+                                                                                     2,
+                                                                                     AttrNumAccess>>;
+#else
+template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
+using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
+                                                                                     2>>;
+#endif
 
 using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
     WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
@@ -309,14 +347,6 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
     WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
 using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
     WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
 
-using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
-                                                                                     2>>;
-
-using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
-                                                                                     2>>;
-
 using WarpGemmMfma_f32_32x32x32_fp8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>,
                                                                                      2>>;
@@ -335,10 +365,6 @@ using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed = WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
 
-using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
-                                                                                     2>>;
-
 using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
                                                                                      2>>;
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
index ace9a50d01..43f17c1a56 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
@@ -105,10 +105,7 @@ template<> struct Dispatcher { u
 // fp8
 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
 template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
-template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
 template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
-template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
-template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };
 template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
 template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; };
 template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
@@ -118,7 +115,6 @@ template<> struct Dispatcher { using Ty
 template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
 template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
 template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
-template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; };
 template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
 template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; };
 template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
@@ -146,6 +142,16 @@ template<> struct Dispatcher
 template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; };
 template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; };
+template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<>; };
+template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false, true> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<WGAttrNumAccessEnum::Double>; };
+template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<>; };
+template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false, true> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<WGAttrNumAccessEnum::Double>; };
+
+template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; };
+template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<>; };
+template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false, true> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<WGAttrNumAccessEnum::Double>; };
+
+
 //WMMA cases
 template <bool SwizzleA> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 16, false, SwizzleA> { using Type = WarpGemmWmma_f32_16x16x16_f8_f8; };
 template <bool SwizzleA> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 16, false, SwizzleA> { using Type = WarpGemmWmma_f32_16x16x16_bf8_bf8; };
diff --git a/tile_engine/ops/gemm/gemm_preshuffle/configs/user_provided_config.json b/tile_engine/ops/gemm/gemm_preshuffle/configs/user_provided_config.json
index cf7c79462e..9c1033b751 100644
--- a/tile_engine/ops/gemm/gemm_preshuffle/configs/user_provided_config.json
+++ b/tile_engine/ops/gemm/gemm_preshuffle/configs/user_provided_config.json
@@ -2,7 +2,7 @@
     "tile_config": {
         "tile_m": {
             "values": [
-                64
+                128
             ]
         },
         "tile_n": {
@@ -12,17 +12,17 @@
         },
         "tile_k": {
             "values": [
-                192
+                64
             ]
         },
         "warp_m": {
             "values": [
-                2
+                4
             ]
         },
         "warp_n": {
             "values": [
-                2
+                1
             ]
         },
         "warp_k": {
             "values": [
@@ -32,17 +32,17 @@
         },
         "warp_tile_m": {
             "values": [
-                16
+                32
             ]
         },
         "warp_tile_n": {
            "values": [
-                16
+                32
             ]
         },
         "warp_tile_k": {
             "values": [
-                32
+                64
             ]
         }
     },
@@ -59,8 +59,7 @@
         },
         "epilogue": {
             "values": [
-                "default",
-                "cshuffle"
+                "default"
             ]
         },
         "pad_m": {
@@ -80,7 +79,7 @@
         },
         "persistent": {
            "values": [
-                true
+                false
             ]
         }
     },
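P.S. (editor's note, not part of the patch): the new preshuffle defaults are internally consistent under the usual Tile Engine constraint that each block-tile dimension is covered by warps-per-block times the per-warp tile, and the 32x32x64 warp tile matches the FP8/BF8 MFMA shapes dispatched above. A standalone check follows; the constraint itself is an assumption.

```cpp
// Hypothetical sanity check of the new gemm_preshuffle defaults. The
// divisibility rule below is an assumption about the Tile Engine's config
// validation, not code from this PR; the constants come from
// user_provided_config.json in the diff above.
constexpr int tile_m = 128, tile_k = 64;          // block tile
constexpr int warp_m = 4, warp_k = 1;             // warps per block (warp_n = 1)
constexpr int warp_tile_m = 32, warp_tile_k = 64; // per-warp MFMA tile

static_assert(tile_m % (warp_m * warp_tile_m) == 0, "M: 128 = 4 warps x 32");
static_assert(tile_k % (warp_k * warp_tile_k) == 0, "K: 64 = 1 warp x 64");

int main() { return 0; }
```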