From d16d00919c43f10759e7b4e4d112125221ed9064 Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Sun, 25 Jun 2023 15:13:30 +0000
Subject: [PATCH] Add smem inline assembly based implementation of gms_init/gms_barrier/gms_reset for gfx90a

---
 include/ck/utility/get_id.hpp                 | 16 ++--
 .../ck/utility/workgroup_synchronization.hpp  | 78 ++++++++++++++++++-
 2 files changed, 83 insertions(+), 11 deletions(-)

diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp
index c872a1a0e5..0cbbc13d97 100644
--- a/include/ck/utility/get_id.hpp
+++ b/include/ck/utility/get_id.hpp
@@ -23,24 +23,24 @@ __device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size()
 // get_wave_id() does the same thing as get_warp_local_1d_id(), except that
 // it tries to save the result in sgpr
 #if defined(__gfx90a__)
-static __device__ inline index_t get_wave_id()
+__device__ inline index_t get_wave_id()
 {
     int thread_id = threadIdx.x;
-    int tmp_int;
     int wave_id;
-    constexpr index_t shift = get_shift();
+    constexpr int shift = get_shift();
 
     // clang-format off
-    __asm__ volatile("v_lshrrev_b32 %1, %3, %2 \n\
-                      v_readfirstlane_b32 %0, %1"
-                     : "=s"(wave_id), "=v"(tmp_int)
-                     : "v"(thread_id), "i"(shift));
+    __asm__ volatile("v_readfirstlane_b32 s16, %1 \n\
+                      s_lshr_b32 %0, s16, %2"
+                     : "=s"(wave_id)
+                     : "v"(thread_id), "i"(shift)
+                     : "s16", "scc"); // s16 is scratch for the first lane's thread id; s_lshr_b32 also writes SCC
     // clang-format on
 
     return wave_id;
 };
 #else
-static __device__ inline index_t get_wave_id() { return get_warp_local_1d_id(); };
+__device__ inline index_t get_wave_id() { return get_warp_local_1d_id(); };
 #endif
 
 __device__ index_t get_block_1d_id() { return blockIdx.x; }
diff --git a/include/ck/utility/workgroup_synchronization.hpp b/include/ck/utility/workgroup_synchronization.hpp
index c5caee5a07..da7e344b48 100644
--- a/include/ck/utility/workgroup_synchronization.hpp
+++ b/include/ck/utility/workgroup_synchronization.hpp
@@ -9,8 +9,78 @@ namespace ck {
 // Initialization flag of Barrier object, can be any value except for zero
static constexpr int BarrierInitFlag = 0x7856; +#if defined(__gfx90a__) + // only the first thread-block in the synchronizaton group is supposed to call this function -static __device__ void gms_init(int NumWarps, int* p_control_bits) +__device__ inline void gms_init(int NumWarps, int* p_control_bits) +{ + int wave_id = get_wave_id(); + + // clang-format off + // regs[0] = BarrierInitFlag, regs[1] = NumWorkgroup, regs[2] = 0, regs[3] = 0 + // regs[0:3] using s[16:19] + __asm__ volatile("s_cmp_lg_i32 %3, 0 \n\ + s_cbranch_scc1 skip_gms_init%= \n\ + s_movk_i32 s16, %1 \n\ + s_mov_b32 s17, %2 \n\ + s_movk_i32 s18, 0 \n\ + s_movk_i32 s19, 0 \n\ + s_atomic_cmpswap_x2 s[16:19], %0, 0 \n\ + s_waitcnt lgkmcnt(0) \n\ + skip_gms_init%=:" + : + : "s"(p_control_bits), "i"(BarrierInitFlag), "s"(NumWarps), "s"(wave_id) + : "s16", "s17", "s18", "s19"); + // clang-format on +}; + +// all the warps in the synchronization group is supposed to call this function +__device__ inline void gms_barrier(int* p_control_bits) +{ + // clang-format off + __asm__ volatile("wait_initialized%=: \n\ + s_load_dword s16, %0, 0 glc \n\ + s_waitcnt lgkmcnt(0) \n\ + s_cmp_lg_u32 s16 %1 \n\ + s_cbranch_scc1 wait_initialized%= \n\ + s_atomic_sub %3, %0, %2 \n\ + wait_all_arrive%=: \n\ + s_load_dword s17, %0, %2 glc \n\ + s_waitcnt lgkmcnt(0) \n\ + s_cmp_lg_u32 s17, 0 \n\ + s_cbranch_scc1 wait_all_arrive%= \n\ + skip_barrier%=:" + : + : "s"(p_control_bits), "i"(BarrierInitFlag), "i"(sizeof(int)), "s"(1) + : "s16", "s17"); + // clang-format on +}; + +// only the first thread-block in the synchronizaton group is supposed to call this function +__device__ inline void gms_reset(int* p_control_bits) +{ + int wave_id = get_wave_id(); + + // clang-format off + // regs[0] = 0, regs[1] = BarrierInitFlag + // regs[0:1] using s[16:17] + __asm__ volatile("s_cmp_lg_i32 %2, 0 \n\ + s_cbranch_scc1 skip_gms_reset%= \n\ + s_movk_i32 s16, 0 \n\ + s_movk_i32 s17, %1 \n\ + s_atomic_cmpswap s[16:17], %0, 0 \n\ + s_waitcnt 
lgkmcnt(0) \n\ + skip_gms_reset%=:" + : + : "s"(p_control_bits), "i"(BarrierInitFlag), "s"(wave_id) + : "s16", "s17"); + // clang-format on +}; + +#else + +// only the first thread-block in the synchronizaton group is supposed to call this function +__device__ inline void gms_init(int NumWarps, int* p_control_bits) { union { @@ -26,7 +96,7 @@ static __device__ void gms_init(int NumWarps, int* p_control_bits) }; // all the workgroups in the synchronization group is supposed to call this function -static __device__ void gms_barrier(int* p_control_bits) +__device__ inline void gms_barrier(int* p_control_bits) { constexpr int mask = warpSize - 1; @@ -58,11 +128,13 @@ static __device__ void gms_barrier(int* p_control_bits) }; // only the first thread-block in the synchronizaton group is supposed to call this function -static __device__ void gms_reset(int* p_control_bits) +__device__ inline void gms_reset(int* p_control_bits) { // reset the barrier object if(threadIdx.x == 0) (void)atomicCAS(&p_control_bits[0], BarrierInitFlag, 0); }; +#endif + } // namespace ck