From 8c42eaf3f8e9c3a4a6943dfa7f0efb455eb33601 Mon Sep 17 00:00:00 2001 From: Thrupti Raj Lakshmana Gowda Date: Thu, 2 Apr 2026 21:55:03 -0500 Subject: [PATCH] [CK Tile] Fix architecture-dependent EightWave assignment in cshuffle_epilogue (#6102) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Tile engine CI build on the develop branch started failing after a recent change(https://github.com/ROCm/rocm-libraries/pull/5218) in `cshuffle_epilogue.hpp`. The `EightWave` constant was unconditionally computed as `(MWave * NWave == 8)` for all architectures, but this logic is only valid for gfx9*. On other architectures (e.g., gfx1201), `EightWave` must always be `false`, otherwise it leads to incorrect `BlockedXDLN_PerWarp` computation and build failures. ## Technical Details In `cshuffle_epilogue.hpp`, the `EightWave` static constexpr was set as: ```cpp static constexpr bool EightWave = (MWave * NWave == 8); ``` This was applied regardless of the target GPU architecture. The fix uses a preprocessor guard to make this architecture-aware: - **gfx9* (`__gfx9__`):** `EightWave` is evaluated as `(MWave * NWave == 8)` — true or false depending on the wave configuration - **All other architectures:** `EightWave` defaults to `false` ## Test Plan - Tile engine CI build on develop branch ## Test Result - *Pending CI* ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- > **Note:** This PR supersedes ROCm/rocm-libraries#5436, which is blocked pending a review approval from a reviewer currently on PTO. The same changes have been applied to this branch (`users/tlakshma/ck/develop-clone`) to allow merging. Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Co-authored-by: Po Yen Chen --- include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp | 8 +++++++- tile_engine/ops/gemm_streamk/CMakeLists.txt | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 5d1ac2fd2f..fba831e205 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -129,7 +129,13 @@ struct CShuffleEpilogue static constexpr index_t isCTransposed = Problem::isCTransposed; static constexpr bool FixedVectorSize = Problem::FixedVectorSize; static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; - static constexpr bool EightWave = (MWave * NWave == 8); + +#if defined(__gfx9__) + static constexpr bool EightWave = (MWave * NWave == 8); +#else + static constexpr bool EightWave = false; +#endif + static constexpr index_t BlockedXDLN_PerWarp = EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; diff --git a/tile_engine/ops/gemm_streamk/CMakeLists.txt b/tile_engine/ops/gemm_streamk/CMakeLists.txt index 1a4e18f881..748b025514 100644 --- a/tile_engine/ops/gemm_streamk/CMakeLists.txt +++ b/tile_engine/ops/gemm_streamk/CMakeLists.txt @@ -216,7 +216,7 @@ message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") # Filter GPU targets to only gfx90a, gfx942 set(GEMM_GPU_TARGETS_INDIVIDUAL "") -set(DESIRED_TARGETS "gfx90a;gfx942;gfx12-generic") # TODO: Add gfx950 when supported +set(DESIRED_TARGETS "gfx90a;gfx942") # TODO: Add gfx950 when supported foreach(target IN LISTS SUPPORTED_GPU_TARGETS) if(target IN_LIST DESIRED_TARGETS) @@ -227,7 +227,7 @@ endforeach() # Skip build if no matching targets found if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) - message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx12-generic) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") else() message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")