mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4584 (commit 42efd1d)
[CK_TILE][FMHA] Support gfx11 ## Motivation Add support of gfx11 architectures (RDNA3) to FMHA. ## Technical Details Distributions (matrix elements to lane registers mapping) of gfx11 WMMA are completely different from distributions of gfx9 MFMA and gfx12 WMMA. There are two cases in FMHA where this difference matters: * usage of results (matrix C) of one GEMM as input (matrix A) of another GEMM. * random number generation for dropout (implementation for gfx9 MFMA, gfx12 WMMA and host validation produce the same results). Both cases are solved by a special remapping implemented using `__builtin_amdgcn_permlanex16` and `__builtin_amdgcn_perm`. Additional changes: * FMHA tests are now build and run only for those types for which instances exist (gfx11 supports only fp16 and bf16). * Two fixes for uninitialized values (`mask.sink` and `do_fp8_static_quant`): they may contain garbage resulting in incorrect dispatching logic, sometimes tests report that there are no instance available for current parameters. * Small fix to remove expcnt(0) from s_waitcnt instruction on gfx11 when they are not requested (i.e. every time), likely has no effect on performance but makes disassembly a bit clearer. ## Test Plan ``` ninja test_ck_tile_fmha bin/test_ck_tile_fmha_fwd_fp16 bin/test_ck_tile_fmha_fwd_bf16 bin/test_ck_tile_fmha_bwd_fp16 bin/test_ck_tile_fmha_bwd_bf16 ``` ## Test Result All tests must pass (some tests may be skipped). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
1915cdfcc2
commit
0d92fffedb
@@ -911,14 +911,15 @@ struct WaitcntLayoutGfx12
|
||||
};
|
||||
|
||||
struct WaitcntLayoutGfx11
|
||||
{ // vm[15:10] (6), lgkm[9:4] (6), exp unused
|
||||
{ // vm[15:10] (6), lgkm[9:4] (6), exp [2:0] (3)
|
||||
CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F;
|
||||
CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F;
|
||||
CK_TILE_DEVICE static constexpr bool HAS_EXP = false;
|
||||
CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07;
|
||||
CK_TILE_DEVICE static constexpr bool HAS_EXP = true;
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 10); }
|
||||
CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 4); }
|
||||
CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; }
|
||||
CK_TILE_DEVICE static constexpr index_t pack_exp(index_t c) { return (c & EXP_MASK); }
|
||||
};
|
||||
|
||||
struct WaitcntLayoutLegacy
|
||||
@@ -952,10 +953,14 @@ using Waitcnt = WaitcntLayoutLegacy;
|
||||
struct waitcnt_arg
|
||||
{
|
||||
// kMax* exposed for callers; match field widths per-arch
|
||||
#if defined(__gfx12__) || defined(__gfx11__)
|
||||
#if defined(__gfx12__)
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none
|
||||
#elif defined(__gfx11__)
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits
|
||||
#else
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits (split)
|
||||
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits
|
||||
@@ -981,8 +986,8 @@ struct waitcnt_arg
|
||||
{
|
||||
if constexpr(Waitcnt::HAS_EXP)
|
||||
{
|
||||
// EXP_MASK only exists on legacy
|
||||
#if !defined(__gfx12__) && !defined(__gfx11__)
|
||||
// EXP_MASK only exists on pre-gfx12
|
||||
#if !defined(__gfx12__)
|
||||
static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range");
|
||||
return Waitcnt::pack_exp(cnt);
|
||||
#else
|
||||
|
||||
Reference in New Issue
Block a user