Add SMEM (scalar memory) inline-assembly implementations of gms_init/gms_barrier/gms_reset for gfx90a

This commit is contained in:
Qianfeng Zhang
2023-06-25 15:13:30 +00:00
parent 5046970f6f
commit d16d00919c
2 changed files with 83 additions and 11 deletions

View File

@@ -23,24 +23,24 @@ __device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size()
// get_wave_id() does the same thing as get_warp_local_1d_id(), except that
// it tries to save the result in sgpr
#if defined(__gfx90a__)
// Returns the index of the calling wavefront inside its workgroup as a wave-uniform value.
// The thread id is first moved into a scalar register and then shifted on the scalar ALU, so the
// result is produced directly in an SGPR.
__device__ inline index_t get_wave_id()
{
    int thread_id = threadIdx.x;
    int wave_id;

    // presumably log2(warpSize); get_shift is defined elsewhere -- verify against its definition
    constexpr int shift = get_shift<warpSize>();

    // clang-format off
    // v_readfirstlane_b32 copies the thread id of the first active lane into the output SGPR,
    // s_lshr_b32 then shifts it in place. The output register doubles as the temporary, so no
    // fixed scratch register is needed. s_lshr_b32 overwrites SCC, hence the "scc" clobber.
    __asm__ volatile("v_readfirstlane_b32 %0, %1 \n\
                      s_lshr_b32 %0, %0, %2"
                     : "=s"(wave_id)
                     : "v"(thread_id), "i"(shift)
                     : "scc");
    // clang-format on

    return wave_id;
}
#else
// Generic fallback for targets other than gfx90a: simply forwards to get_warp_local_1d_id().
__device__ inline index_t get_wave_id() { return get_warp_local_1d_id(); }
#endif
// Returns the x component of the block index of the calling thread.
__device__ index_t get_block_1d_id()
{
    return blockIdx.x;
}

View File

@@ -9,8 +9,78 @@ namespace ck {
// Initialization flag of Barrier object, can be any value except for zero.
// gms_barrier spins until p_control_bits[0] equals this value, and gms_reset swaps
// p_control_bits[0] from this value back to zero.
static constexpr int BarrierInitFlag = 0x7856;
#if defined(__gfx90a__)
// only the first thread-block in the synchronization group is supposed to call this function
//
// Arms the barrier with a single 64-bit scalar compare-and-swap on p_control_bits[0:1]:
// if both words are zero, p_control_bits[0] becomes BarrierInitFlag and p_control_bits[1]
// becomes NumWarps. Only the wave whose get_wave_id() is 0 issues the atomic; every other
// wave branches straight to the end.
// NOTE(review): the 64-bit atomic presumably requires p_control_bits to be 8-byte aligned --
// confirm against the allocation site.
__device__ inline void gms_init(int NumWarps, int* p_control_bits)
{
    int wave_id = get_wave_id();

    // clang-format off
    // regs[0] = BarrierInitFlag, regs[1] = NumWarps   (value to be stored)
    // regs[2] = 0,               regs[3] = 0          (value expected in memory)
    // regs[0:3] using s[16:19]
    //
    // s_mov_b32 is used for the flag so that any 32-bit BarrierInitFlag can be encoded
    // (s_movk_i32 only takes a signed 16-bit immediate).
    // s_cmp_lg_i32 overwrites SCC and the atomic writes memory, so both are declared as
    // clobbers to keep the compiler from caching SCC or moving memory accesses across the asm.
    __asm__ volatile("s_cmp_lg_i32 %3, 0 \n\
                      s_cbranch_scc1 skip_gms_init%= \n\
                      s_mov_b32 s16, %1 \n\
                      s_mov_b32 s17, %2 \n\
                      s_mov_b32 s18, 0 \n\
                      s_mov_b32 s19, 0 \n\
                      s_atomic_cmpswap_x2 s[16:19], %0, 0 \n\
                      s_waitcnt lgkmcnt(0) \n\
                      skip_gms_init%=:"
                     :
                     : "s"(p_control_bits), "i"(BarrierInitFlag), "s"(NumWarps), "s"(wave_id)
                     : "s16", "s17", "s18", "s19", "scc", "memory");
    // clang-format on
}
// all the warps in the synchronization group are supposed to call this function
//
// Phase 1: spin on p_control_bits[0] until it equals BarrierInitFlag (barrier has been armed).
// Phase 2: atomically subtract 1 from p_control_bits[1] to announce the arrival of this wave.
// Phase 3: spin on p_control_bits[1] until it reads zero (every wave has arrived).
// All memory accesses are scalar instructions, so each of them is issued once per wave.
__device__ inline void gms_barrier(int* p_control_bits)
{
    // clang-format off
    // %0 = base address, %1 = BarrierInitFlag, %2 = byte offset of p_control_bits[1],
    // %3 = SGPR holding the constant 1 that is subtracted from the arrival counter.
    // s16/s17 receive the loaded values; "glc" is set on the loads that are polled in a loop.
    // s_cmp_lg_u32 overwrites SCC and the atomic writes memory; the "memory" clobber also keeps
    // the compiler from moving ordinary loads/stores across the barrier.
    __asm__ volatile("wait_initialized%=: \n\
                      s_load_dword s16, %0, 0 glc \n\
                      s_waitcnt lgkmcnt(0) \n\
                      s_cmp_lg_u32 s16, %1 \n\
                      s_cbranch_scc1 wait_initialized%= \n\
                      s_atomic_sub %3, %0, %2 \n\
                      wait_all_arrive%=: \n\
                      s_load_dword s17, %0, %2 glc \n\
                      s_waitcnt lgkmcnt(0) \n\
                      s_cmp_lg_u32 s17, 0 \n\
                      s_cbranch_scc1 wait_all_arrive%="
                     :
                     : "s"(p_control_bits), "i"(BarrierInitFlag), "i"(sizeof(int)), "s"(1)
                     : "s16", "s17", "scc", "memory");
    // clang-format on
}
// only the first thread-block in the synchronization group is supposed to call this function
//
// Disarms the barrier: a scalar compare-and-swap replaces p_control_bits[0] with zero if it
// currently holds BarrierInitFlag. Only the wave whose get_wave_id() is 0 issues the atomic.
__device__ inline void gms_reset(int* p_control_bits)
{
    int wave_id = get_wave_id();

    // clang-format off
    // regs[0] = 0                (value to be stored)
    // regs[1] = BarrierInitFlag  (value expected in memory)
    // regs[0:1] using s[16:17]
    //
    // s_mov_b32 is used so that any 32-bit BarrierInitFlag can be encoded.
    // s_cmp_lg_i32 overwrites SCC and the atomic writes memory, so both are declared as clobbers.
    __asm__ volatile("s_cmp_lg_i32 %2, 0 \n\
                      s_cbranch_scc1 skip_gms_reset%= \n\
                      s_mov_b32 s16, 0 \n\
                      s_mov_b32 s17, %1 \n\
                      s_atomic_cmpswap s[16:17], %0, 0 \n\
                      s_waitcnt lgkmcnt(0) \n\
                      skip_gms_reset%=:"
                     :
                     : "s"(p_control_bits), "i"(BarrierInitFlag), "s"(wave_id)
                     : "s16", "s17", "scc", "memory");
    // clang-format on
}
#else
// only the first thread-block in the synchronization group is supposed to call this function
__device__ inline void gms_init(int NumWarps, int* p_control_bits)
{
union
{
@@ -26,7 +96,7 @@ static __device__ void gms_init(int NumWarps, int* p_control_bits)
};
// all the workgroups in the synchronization group are supposed to call this function
static __device__ void gms_barrier(int* p_control_bits)
__device__ inline void gms_barrier(int* p_control_bits)
{
constexpr int mask = warpSize - 1;
@@ -58,11 +128,13 @@ static __device__ void gms_barrier(int* p_control_bits)
};
// only the first thread-block in the synchronization group is supposed to call this function
//
// Generic fallback: thread 0 of the block swaps p_control_bits[0] from BarrierInitFlag back to
// zero; the value returned by atomicCAS is intentionally discarded.
__device__ inline void gms_reset(int* p_control_bits)
{
    // reset the barrier object
    if(threadIdx.x == 0)
        (void)atomicCAS(&p_control_bits[0], BarrierInitFlag, 0);
}
#endif
} // namespace ck