fix async copytest bug (#2509)

* fix async copytest bug

* Add block_sync_lds_direct_load utility

* fix the s_waitcnt_imm calculation

* Improve s_waitcnt_imm calculation

* fix vmcnt shift

* add input validation and bug fix

* remove unnecessary output

* move test_copy into test

* change bit width check

* refactor macros into constexpr functions

which still get inlined

* wrap s_waitcnt api

* parameterize test

* cleanup

* cleanup fp8 stub

* add fp8 test cases; TODO: determine which input parameters are valid

* replace n for fp8 in test cases

* add large shapes; fp8 fails again

* change input init

* test sync/async

* time the test

* clang-format test

* use float instead of bfloat to cover a 4-byte type

* fix logic - arg sections should be 'or'd

* make block_sync_lds_direct_load interface similar to old ck

* fix a few comment typos

* name common shapes

* revert the example to the original logic of not waiting on LDS

* clang-format

---------

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>

[ROCm/composable_kernel commit: a5fdc663c8]
This commit is contained in:
Haocong WANG
2025-07-23 15:14:02 +08:00
committed by GitHub
parent 29e1e00edd
commit 60ff19fb4e
9 changed files with 313 additions and 191 deletions

View File

@@ -10,6 +10,15 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
// Macro encoders for the individual counter fields of an s_waitcnt immediate.
// Each macro takes an integral constant expression, rejects at compile time any
// value that does not fit the counter's bit width, and evaluates to that
// counter's contribution to the immediate:
//   CK_TILE_VMCNT   - 6 bits; value bits 3:0 stay at bits 3:0, value bits 5:4
//                     are moved up to bits 15:14
//   CK_TILE_EXPCNT  - 3 bits, shifted to start at bit 4
//   CK_TILE_LGKMCNT - 4 bits, shifted to start at bit 8
// CK_TILE_S_CNT_MAX is the value with every one of those counter bits set.
#define CK_TILE_S_CNT_MAX 0xCF7F

#define CK_TILE_VMCNT(cnt)                                                  \
    ([]() { static_assert(((cnt) >> 6) == 0, "VMCNT only has 6 bits"); }(), \
     (((cnt) & 0x30) << 10) | ((cnt) & 0x0F))

#define CK_TILE_EXPCNT(cnt) \
    ([]() { static_assert(((cnt) >> 3) == 0, "EXP only has 3 bits"); }(), ((cnt) << 4))

#define CK_TILE_LGKMCNT(cnt) \
    ([]() { static_assert(((cnt) >> 4) == 0, "LGKM only has 4 bits"); }(), ((cnt) << 8))
namespace ck_tile {
template <typename, bool>
@@ -113,13 +122,72 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
#endif
}
// Compile-time encoder for the immediate operand passed to `s_waitcnt`.
// Encoding reference: https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html
struct waitcnt_arg
{
    // bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210
    // [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV
    // MAX has every counter bit set and every unused bit clear; the encoders
    // below use it as a final mask.
    CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111;

    // Largest value each counter field can hold (6, 3 and 4 bits respectively).
    CK_TILE_DEVICE static constexpr index_t kMaxVmCnt   = 0x3F;
    CK_TILE_DEVICE static constexpr index_t kMaxExpCnt  = 0x7;
    CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0xF;

    // Encodes a VM counter value. The field is split: value bits 3:0 map to
    // immediate bits 3:0 and value bits 5:4 map to immediate bits F:E.
    template <index_t cnt>
    CK_TILE_DEVICE static constexpr index_t from_vmcnt()
    {
        static_assert(0 <= cnt && cnt <= kMaxVmCnt, "valid range is [0..63]");
        constexpr index_t kLowMask   = 0xF;  // value bits 3:0
        constexpr index_t kHighMask  = 0x30; // value bits 5:4
        constexpr index_t kHighShift = 10;   // bit 4 -> bit E, bit 5 -> bit F
        return (((cnt & kHighMask) << kHighShift) | (cnt & kLowMask)) & MAX;
    }

    // Encodes an EXP counter value into immediate bits 6:4.
    template <index_t cnt>
    CK_TILE_DEVICE static constexpr index_t from_expcnt()
    {
        static_assert(0 <= cnt && cnt <= kMaxExpCnt, "valid range is [0..7]");
        constexpr index_t kShift = 4;
        return (cnt << kShift) & MAX;
    }

    // Encodes an LGKM counter value into immediate bits B:8.
    template <index_t cnt>
    CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
    {
        static_assert(0 <= cnt && cnt <= kMaxLgkmCnt, "valid range is [0..15]");
        constexpr index_t kShift = 8;
        return (cnt << kShift) & MAX;
    }
};
// Issues a single `s_waitcnt` whose immediate is the bitwise OR of the three
// encoded counter fields. Each template argument defaults to the largest value
// its field can hold, so callers only spell out the counters they care about.
// NOTE(review): presumably a counter left at its maximum imposes no wait on
// that counter - confirm against the ISA documentation.
template <index_t vmcnt   = waitcnt_arg::kMaxVmCnt,
          index_t expcnt  = waitcnt_arg::kMaxExpCnt,
          index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt()
{
    // Range validation happens inside the from_*() encoders via static_assert.
    constexpr index_t vm_bits   = waitcnt_arg::from_vmcnt<vmcnt>();
    constexpr index_t exp_bits  = waitcnt_arg::from_expcnt<expcnt>();
    constexpr index_t lgkm_bits = waitcnt_arg::from_lgkmcnt<lgkmcnt>();
    __builtin_amdgcn_s_waitcnt(vm_bits | exp_bits | lgkm_bits);
}
// Issues s_waitcnt<vmcnt, expcnt, lgkmcnt>() and then the `s_barrier` builtin, in that order.
// Template arguments have the same meaning and defaults as those of s_waitcnt: each one
// defaults to the largest value its counter field can hold.
// NOTE(review): the barrier's synchronization scope is not visible here - presumably the
// workgroup; confirm against the ISA documentation.
template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
index_t expcnt = waitcnt_arg::kMaxExpCnt,
index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt_barrier()
{
// wait first, so the counters have drained before this wave reaches the barrier
s_waitcnt<vmcnt, expcnt, lgkmcnt>();
__builtin_amdgcn_s_barrier();
}
// Waits with vmcnt = 0 and lgkmcnt = 0 (expcnt is passed at its maximum field value) and
// then issues a barrier, by calling s_waitcnt_barrier<0, waitcnt_arg::kMaxExpCnt, 0>().
// NOTE(review): judging by the name, this is presumably meant to be called after direct
// loads into LDS - confirm against callers.
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
#if 1
// active path: invoke clang builtins which *should* produce the same result as the
// inline asm below
// difference: inline asm is being compiled to wait vmcnt(0) after the barrier
s_waitcnt_barrier<0, waitcnt_arg::kMaxExpCnt, 0>();
#else
// disabled reference path: same content as in old CK (#999)
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
#endif
}
CK_TILE_DEVICE void s_nop(index_t cnt = 0)