mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
Add smem inline assembly based implementation of gms_init/gms_barrier/gms_reset for gfx90a
This commit is contained in:
@@ -23,24 +23,24 @@ __device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size()
|
||||
// get_wave_id() does the same thing as get_warp_local_1d_id(), except that
|
||||
// it tries to save the result in sgpr
|
||||
#if defined(__gfx90a__)
|
||||
// Returns the index of the calling wave inside its thread block, i.e. the same
// value as get_warp_local_1d_id(), but computed with scalar instructions so the
// result is produced in an SGPR.
__device__ inline index_t get_wave_id()
{
    int thread_id = threadIdx.x;
    int wave_id;
    // NOTE(review): get_shift<warpSize>() is presumably log2(warpSize), so that
    // the right shift below equals threadIdx.x / warpSize -- confirm.
    constexpr int shift = get_shift<warpSize>();

    // clang-format off
    // v_readfirstlane_b32 copies thread_id of the first active lane into s16,
    // s_lshr_b32 then shifts it right by `shift` into the output SGPR.
    // s_lshr_b32 also writes SCC, so SCC is declared as clobbered alongside
    // the hard-coded scratch register s16.
    __asm__ volatile("v_readfirstlane_b32 s16, %1\n"
                     "s_lshr_b32 %0, s16, %2"
                     : "=s"(wave_id)
                     : "v"(thread_id), "i"(shift)
                     : "s16", "scc");
    // clang-format on

    return wave_id;
}
|
||||
#else
|
||||
// Fallback used when the gfx90a inline-assembly path is not compiled in:
// simply forwards to get_warp_local_1d_id().
__device__ inline index_t get_wave_id() { return get_warp_local_1d_id(); }
|
||||
#endif
|
||||
|
||||
// Returns blockIdx.x, the x index of the calling thread block, as index_t.
__device__ index_t get_block_1d_id()
{
    const index_t block_id = blockIdx.x;
    return block_id;
}
|
||||
|
||||
@@ -9,8 +9,78 @@ namespace ck {
|
||||
// Initialization flag of Barrier object, can be any value except for zero
|
||||
static constexpr int BarrierInitFlag = 0x7856;
|
||||
|
||||
#if defined(__gfx90a__)
|
||||
|
||||
// only the first thread-block in the synchronizaton group is supposed to call this function
|
||||
static __device__ void gms_init(int NumWarps, int* p_control_bits)
|
||||
__device__ inline void gms_init(int NumWarps, int* p_control_bits)
|
||||
{
|
||||
int wave_id = get_wave_id();
|
||||
|
||||
// clang-format off
|
||||
// regs[0] = BarrierInitFlag, regs[1] = NumWorkgroup, regs[2] = 0, regs[3] = 0
|
||||
// regs[0:3] using s[16:19]
|
||||
__asm__ volatile("s_cmp_lg_i32 %3, 0 \n\
|
||||
s_cbranch_scc1 skip_gms_init%= \n\
|
||||
s_movk_i32 s16, %1 \n\
|
||||
s_mov_b32 s17, %2 \n\
|
||||
s_movk_i32 s18, 0 \n\
|
||||
s_movk_i32 s19, 0 \n\
|
||||
s_atomic_cmpswap_x2 s[16:19], %0, 0 \n\
|
||||
s_waitcnt lgkmcnt(0) \n\
|
||||
skip_gms_init%=:"
|
||||
:
|
||||
: "s"(p_control_bits), "i"(BarrierInitFlag), "s"(NumWarps), "s"(wave_id)
|
||||
: "s16", "s17", "s18", "s19");
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// all the warps in the synchronization group is supposed to call this function
|
||||
__device__ inline void gms_barrier(int* p_control_bits)
|
||||
{
|
||||
// clang-format off
|
||||
__asm__ volatile("wait_initialized%=: \n\
|
||||
s_load_dword s16, %0, 0 glc \n\
|
||||
s_waitcnt lgkmcnt(0) \n\
|
||||
s_cmp_lg_u32 s16 %1 \n\
|
||||
s_cbranch_scc1 wait_initialized%= \n\
|
||||
s_atomic_sub %3, %0, %2 \n\
|
||||
wait_all_arrive%=: \n\
|
||||
s_load_dword s17, %0, %2 glc \n\
|
||||
s_waitcnt lgkmcnt(0) \n\
|
||||
s_cmp_lg_u32 s17, 0 \n\
|
||||
s_cbranch_scc1 wait_all_arrive%= \n\
|
||||
skip_barrier%=:"
|
||||
:
|
||||
: "s"(p_control_bits), "i"(BarrierInitFlag), "i"(sizeof(int)), "s"(1)
|
||||
: "s16", "s17");
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// only the first thread-block in the synchronizaton group is supposed to call this function
|
||||
__device__ inline void gms_reset(int* p_control_bits)
|
||||
{
|
||||
int wave_id = get_wave_id();
|
||||
|
||||
// clang-format off
|
||||
// regs[0] = 0, regs[1] = BarrierInitFlag
|
||||
// regs[0:1] using s[16:17]
|
||||
__asm__ volatile("s_cmp_lg_i32 %2, 0 \n\
|
||||
s_cbranch_scc1 skip_gms_reset%= \n\
|
||||
s_movk_i32 s16, 0 \n\
|
||||
s_movk_i32 s17, %1 \n\
|
||||
s_atomic_cmpswap s[16:17], %0, 0 \n\
|
||||
s_waitcnt lgkmcnt(0) \n\
|
||||
skip_gms_reset%=:"
|
||||
:
|
||||
: "s"(p_control_bits), "i"(BarrierInitFlag), "s"(wave_id)
|
||||
: "s16", "s17");
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
// only the first thread-block in the synchronization group is supposed to call this function
|
||||
__device__ inline void gms_init(int NumWarps, int* p_control_bits)
|
||||
{
|
||||
union
|
||||
{
|
||||
@@ -26,7 +96,7 @@ static __device__ void gms_init(int NumWarps, int* p_control_bits)
|
||||
};
|
||||
|
||||
// all the workgroups in the synchronization group are supposed to call this function
|
||||
static __device__ void gms_barrier(int* p_control_bits)
|
||||
__device__ inline void gms_barrier(int* p_control_bits)
|
||||
{
|
||||
constexpr int mask = warpSize - 1;
|
||||
|
||||
@@ -58,11 +128,13 @@ static __device__ void gms_barrier(int* p_control_bits)
|
||||
};
|
||||
|
||||
// only the first thread-block in the synchronizaton group is supposed to call this function
|
||||
static __device__ void gms_reset(int* p_control_bits)
|
||||
__device__ inline void gms_reset(int* p_control_bits)
|
||||
{
|
||||
// reset the barrier object
|
||||
if(threadIdx.x == 0)
|
||||
(void)atomicCAS(&p_control_bits[0], BarrierInitFlag, 0);
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user