diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 3b12cf061b..31ba053796 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -139,6 +139,34 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) // https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html struct waitcnt_arg { +#if defined(__gfx12__) + // immediate layout for s_wait_loadcnt_dscnt: ds (lgkm) count in bits [5:0], mem (load) count in bits [13:8] + CK_TILE_DEVICE static constexpr index_t MAX = 0b00'111111'00'111111; + + CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111; + CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111; + CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b111111; + + template + CK_TILE_DEVICE static constexpr index_t from_vmcnt() + { + static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); + return MAX & (cnt << 8); + } + + template + CK_TILE_DEVICE static constexpr index_t from_expcnt() + { + return 0; // the s_wait_loadcnt_dscnt immediate has no export-count field, so this contributes nothing to the mask + } + + template + CK_TILE_DEVICE static constexpr index_t from_lgkmcnt() + { + static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]"); + return MAX & cnt; + } +#else // bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210 // [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111; @@ -167,6 +195,7 @@ struct waitcnt_arg static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]"); return MAX & (cnt << 8); } +#endif }; template CK_TILE_DEVICE void s_waitcnt() { +#if defined(__gfx12__) + // GFX12 does not use __builtin_amdgcn_s_waitcnt; emit the combined wait instruction directly + constexpr index_t wait_mask = waitcnt_arg::from_vmcnt() | + waitcnt_arg::from_expcnt() | + waitcnt_arg::from_lgkmcnt(); + + asm volatile("s_wait_loadcnt_dscnt %0" : : "n"(wait_mask) : "memory"); +#else __builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt() | waitcnt_arg::from_expcnt() | waitcnt_arg::from_lgkmcnt()); +#endif } template
CK_TILE_DEVICE void s_waitcnt_barrier() { +#if defined(__gfx12__) + // GFX12 optimization: manual barrier implementation avoids the performance penalty + // from __builtin_amdgcn_s_barrier, which inserts an extra s_wait_loadcnt_dscnt 0x0 + constexpr index_t wait_mask = waitcnt_arg::from_vmcnt() | + waitcnt_arg::from_expcnt() | + waitcnt_arg::from_lgkmcnt(); + + asm volatile("s_wait_loadcnt_dscnt %0\n" + "s_barrier_signal -1\n" + "s_barrier_wait -1" + : + : "n"(wait_mask) + : "memory"); +#else s_waitcnt(); __builtin_amdgcn_s_barrier(); +#endif } template diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 42e2fad236..09c2510d3e 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -797,7 +797,7 @@ struct MoeSortingKernel else smem_tokens(curr_token_id, eid)++; } - __builtin_amdgcn_s_waitcnt(0xc07f); + s_waitcnt(); // NOTE(review): assumes s_waitcnt's default template arguments reproduce the old 0xc07f mask (lgkmcnt(0) only, vm/exp at max) - TODO confirm } __syncthreads(); // make sure different i_token iteration not overlap by different wave } @@ -922,7 +922,7 @@ struct MoeSortingKernel // NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt() // for above write however __syncthreads will cause barrier with waves other // than 0(which is not we want) - __builtin_amdgcn_s_waitcnt(0xc07f); + s_waitcnt(); } if((lid + i_e_ - get_warp_size()) == (num_experts - 1)) {