[rocm-libraries] ROCm/rocm-libraries#4584 (commit 42efd1d)

[CK_TILE][FMHA] Support gfx11

## Motivation

Add support of gfx11 architectures (RDNA3) to FMHA.

## Technical Details

Distributions (matrix elements to lane registers mapping) of gfx11 WMMA
are completely different from distributions of gfx9 MFMA and gfx12 WMMA.
There are two cases in FMHA where this difference matters:
* usage of results (matrix C) of one GEMM as input (matrix A) of another
GEMM.
* random number generation for dropout (implementation for gfx9 MFMA,
gfx12 WMMA and host validation produce the same results).

Both cases are solved by a special remapping implemented using
`__builtin_amdgcn_permlanex16` and `__builtin_amdgcn_perm`.

Additional changes:
* FMHA tests are now build and run only for those types for which
instances exist (gfx11 supports only fp16 and bf16).
* Two fixes for uninitialized values (`mask.sink` and
`do_fp8_static_quant`): they may contain garbage resulting in incorrect
dispatching logic, sometimes tests report that there are no instance
available for current parameters.
* Small fix to remove expcnt(0) from s_waitcnt instruction on gfx11 when
they are not requested (i.e. every time), likely has no effect on
performance but makes disassembly a bit clearer.

## Test Plan

```
ninja test_ck_tile_fmha

bin/test_ck_tile_fmha_fwd_fp16
bin/test_ck_tile_fmha_fwd_bf16
bin/test_ck_tile_fmha_bwd_fp16
bin/test_ck_tile_fmha_bwd_bf16
```

## Test Result

All tests must pass (some tests may be skipped).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Anton Gorenko
2026-02-21 01:15:57 +00:00
committed by assistant-librarian[bot]
parent 1915cdfcc2
commit 0d92fffedb
19 changed files with 296 additions and 21 deletions

View File

@@ -457,6 +457,24 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
return results
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
arch = ArchTrait("gfx11")
@staticmethod
def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
if tr_load == "t":
return []
if dtype in ["fp16", "bf16"]:
return [
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
FmhaBwdDQDKDVTileSize( 32, 64, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaBwdDQDKDVTileSize( 32, 64, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
] # fmt: skip
return []
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
@@ -483,6 +501,8 @@ def get_factory(target: str):
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx11"):
return KernelComponentFactoryGfx11
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12

View File

@@ -1094,6 +1094,50 @@ class KernelComponentFactoryGfx950(
return pipelines
class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
arch = ArchTrait("gfx11")
_DT_FP16_BF16 = ("fp16", "bf16")
@classmethod
def supported_dtypes(cls) -> Tuple[str]:
return cls._DT_FP16_BF16
@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
if dtype in cls._DT_FP16_BF16:
return {
# bm0, bn0, bk0, bn1, bk1,
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
else:
raise ValueError(f"unsupported dtype={dtype}")
@classmethod
def get_pipelines(
cls, dtype, hdim, hdim_v, receipt, mask_impl
) -> List[FmhaFwdPipeline]:
pipelines = []
if dtype in cls._DT_FP16_BF16:
qscale = "no"
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
return pipelines
class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
arch = ArchTrait("gfx12")
@@ -1183,6 +1227,8 @@ def get_factory(target: str):
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx11"):
return KernelComponentFactoryGfx11
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12

View File

@@ -388,6 +388,22 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
arch = ArchTrait("gfx9")
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
arch = ArchTrait("gfx11")
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return KernelComponentFactoryBase.get_hdim_tile_size_dict(dtype)
return None
@staticmethod
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
if dtype in ["fp16", "bf16"]:
return KernelComponentFactoryBase.get_pipelines(dtype, hdim)
return []
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
@@ -398,6 +414,8 @@ def get_factory(target: str):
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx11"):
return KernelComponentFactoryGfx11
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12

View File

@@ -836,6 +836,23 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
return None
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
arch = ArchTrait("gfx11")
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
# bm0, bn0, bk0, bn1, bk1,
"32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
else:
return None
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
@@ -865,6 +882,8 @@ def get_factory(target: str):
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx11"):
return KernelComponentFactoryGfx11
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12

View File

@@ -597,6 +597,24 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
return None
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
arch = ArchTrait("gfx11")
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
# bm0, bn0, bk0, bn1, bk1,
# "32": FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "64": FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "192": FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
else:
return None
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
@@ -628,6 +646,8 @@ def get_factory(target: str):
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9
if target.startswith("gfx11"):
return KernelComponentFactoryGfx11
if target.startswith("gfx12"):
return KernelComponentFactoryGfx12