mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4584 (commit 42efd1d)
[CK_TILE][FMHA] Support gfx11 ## Motivation Add support for gfx11 architectures (RDNA3) to FMHA. ## Technical Details Distributions (matrix elements to lane registers mapping) of gfx11 WMMA are completely different from distributions of gfx9 MFMA and gfx12 WMMA. There are two cases in FMHA where this difference matters: * usage of results (matrix C) of one GEMM as input (matrix A) of another GEMM. * random number generation for dropout (the implementations for gfx9 MFMA, gfx12 WMMA and host validation produce the same results). Both cases are solved by a special remapping implemented using `__builtin_amdgcn_permlanex16` and `__builtin_amdgcn_perm`. Additional changes: * FMHA tests are now built and run only for those types for which instances exist (gfx11 supports only fp16 and bf16). * Two fixes for uninitialized values (`mask.sink` and `do_fp8_static_quant`): they may contain garbage, resulting in incorrect dispatching logic; sometimes tests report that there are no instances available for the current parameters. * Small fix to remove expcnt(0) from the s_waitcnt instruction on gfx11 when it is not requested (i.e. every time); this likely has no effect on performance but makes the disassembly a bit clearer. ## Test Plan ``` ninja test_ck_tile_fmha bin/test_ck_tile_fmha_fwd_fp16 bin/test_ck_tile_fmha_fwd_bf16 bin/test_ck_tile_fmha_bwd_fp16 bin/test_ck_tile_fmha_bwd_bf16 ``` ## Test Result All tests must pass (some tests may be skipped). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
1915cdfcc2
commit
0d92fffedb
@@ -3,9 +3,9 @@
|
||||
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
# Currently only gfx9, gfx11 and gfx12 archs are supported by FMHA
|
||||
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")
|
||||
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx1[12]")
|
||||
if(NOT INST_TARGETS)
|
||||
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx11, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
|
||||
@@ -457,6 +457,24 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
|
||||
return results
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx11")
|
||||
|
||||
@staticmethod
|
||||
def get_dq_dk_dv_tiles(dtype: str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
|
||||
if tr_load == "t":
|
||||
return []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return [
|
||||
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
|
||||
FmhaBwdDQDKDVTileSize( 32, 64, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 64, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
] # fmt: skip
|
||||
return []
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx12")
|
||||
|
||||
@@ -483,6 +501,8 @@ def get_factory(target: str):
|
||||
if target.startswith("gfx9"):
|
||||
return KernelComponentFactoryGfx9
|
||||
|
||||
if target.startswith("gfx11"):
|
||||
return KernelComponentFactoryGfx11
|
||||
if target.startswith("gfx12"):
|
||||
return KernelComponentFactoryGfx12
|
||||
|
||||
|
||||
@@ -1094,6 +1094,50 @@ class KernelComponentFactoryGfx950(
|
||||
return pipelines
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
arch = ArchTrait("gfx11")
|
||||
|
||||
_DT_FP16_BF16 = ("fp16", "bf16")
|
||||
|
||||
@classmethod
|
||||
def supported_dtypes(cls) -> Tuple[str]:
|
||||
return cls._DT_FP16_BF16
|
||||
|
||||
@classmethod
|
||||
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
|
||||
if dtype in cls._DT_FP16_BF16:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
} # fmt: skip
|
||||
else:
|
||||
raise ValueError(f"unsupported dtype={dtype}")
|
||||
|
||||
@classmethod
|
||||
def get_pipelines(
|
||||
cls, dtype, hdim, hdim_v, receipt, mask_impl
|
||||
) -> List[FmhaFwdPipeline]:
|
||||
pipelines = []
|
||||
if dtype in cls._DT_FP16_BF16:
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||
return pipelines
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
|
||||
arch = ArchTrait("gfx12")
|
||||
|
||||
@@ -1183,6 +1227,8 @@ def get_factory(target: str):
|
||||
if target.startswith("gfx9"):
|
||||
return KernelComponentFactoryGfx9
|
||||
|
||||
if target.startswith("gfx11"):
|
||||
return KernelComponentFactoryGfx11
|
||||
if target.startswith("gfx12"):
|
||||
return KernelComponentFactoryGfx12
|
||||
|
||||
|
||||
@@ -388,6 +388,22 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx9")
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx11")
|
||||
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return KernelComponentFactoryBase.get_hdim_tile_size_dict(dtype)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return KernelComponentFactoryBase.get_pipelines(dtype, hdim)
|
||||
return []
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx12")
|
||||
|
||||
@@ -398,6 +414,8 @@ def get_factory(target: str):
|
||||
if target.startswith("gfx9"):
|
||||
return KernelComponentFactoryGfx9
|
||||
|
||||
if target.startswith("gfx11"):
|
||||
return KernelComponentFactoryGfx11
|
||||
if target.startswith("gfx12"):
|
||||
return KernelComponentFactoryGfx12
|
||||
|
||||
|
||||
@@ -836,6 +836,23 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
|
||||
return None
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx11")
|
||||
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
"32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx12")
|
||||
|
||||
@@ -865,6 +882,8 @@ def get_factory(target: str):
|
||||
if target.startswith("gfx9"):
|
||||
return KernelComponentFactoryGfx9
|
||||
|
||||
if target.startswith("gfx11"):
|
||||
return KernelComponentFactoryGfx11
|
||||
if target.startswith("gfx12"):
|
||||
return KernelComponentFactoryGfx12
|
||||
|
||||
|
||||
@@ -597,6 +597,24 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
|
||||
return None
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx11")
|
||||
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
# "32": FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# "64": FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# "192": FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
|
||||
arch = ArchTrait("gfx12")
|
||||
|
||||
@@ -628,6 +646,8 @@ def get_factory(target: str):
|
||||
if target.startswith("gfx9"):
|
||||
return KernelComponentFactoryGfx9
|
||||
|
||||
if target.startswith("gfx11"):
|
||||
return KernelComponentFactoryGfx11
|
||||
if target.startswith("gfx12"):
|
||||
return KernelComponentFactoryGfx12
|
||||
|
||||
|
||||
@@ -1635,8 +1635,8 @@ struct fmha_fwd_splitkv_traits
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_lse;
|
||||
bool do_fp8_static_quant;
|
||||
bool has_sink = false;
|
||||
bool do_fp8_static_quant = false;
|
||||
bool has_sink = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
|
||||
|
||||
@@ -73,6 +73,7 @@ struct mask_info
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = left_size;
|
||||
tmp.right = right_size;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
else if(t == "t" || t == "b" || t == "g")
|
||||
{
|
||||
@@ -147,7 +148,10 @@ struct mask_info
|
||||
}
|
||||
else if(str == "0")
|
||||
{
|
||||
tmp.type = mask_enum::no_mask;
|
||||
tmp.type = mask_enum::no_mask;
|
||||
tmp.left = -1;
|
||||
tmp.right = -1;
|
||||
tmp.sink = 0;
|
||||
}
|
||||
else if(str == "1" || str == "t")
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user