From 5d56dde0e06c3edea74458ad4933fe72ffdda020 Mon Sep 17 00:00:00 2001
From: Po Yen Chen
Date: Fri, 22 Aug 2025 10:17:05 +0800
Subject: [PATCH] [CK_TILE] Allow switching between SGPR/VGPR get_warp_id()
 return values (#2669)

* Allow return VGPR get_warp_id() value

* Avoid using SALU in async_load_raw()

[ROCm/composable_kernel commit: 0db21053e68817a50b0ed0ceea87e88228ab2475]
---
 include/ck_tile/core/arch/arch.hpp          | 13 +++++++++++--
 include/ck_tile/core/tensor/tile_window.hpp |  7 +++++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp
index 234929d6e6..42f2390cde 100644
--- a/include/ck_tile/core/arch/arch.hpp
+++ b/include/ck_tile/core/arch/arch.hpp
@@ -98,9 +98,18 @@ CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
 // Use these instead
 CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
 
-CK_TILE_DEVICE index_t get_warp_id()
+template <bool ReturnSgpr = true>
+CK_TILE_DEVICE index_t get_warp_id(bool_constant<ReturnSgpr> = {})
 {
-    return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
+    const index_t warp_id = threadIdx.x / get_warp_size();
+    if constexpr(ReturnSgpr)
+    {
+        return __builtin_amdgcn_readfirstlane(warp_id);
+    }
+    else
+    {
+        return warp_id;
+    }
 }
 
 CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp
index ad5902f16e..f5ddcd278c 100644
--- a/include/ck_tile/core/tensor/tile_window.hpp
+++ b/include/ck_tile/core/tensor/tile_window.hpp
@@ -288,8 +288,11 @@ struct tile_window_with_static_distribution
                 sizeof(LdsDataType) -
             size_per_buf;
 
-        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
-        m0_set_with_memory(m0_init_value); // This should be wave independent
+        // Use VALU so the compiler can optimize redundant/repeated computations
+        const index_t m0_init_value =
+            size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant<false>{});
+        m0_set_with_memory(
+            __builtin_amdgcn_readfirstlane(m0_init_value)); // This should be wave independent
 
         using Traits = typename Base::Traits;
 