mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK_TILE] FMHA Support hdim_v to as a Multiple of 32 (#2114)
* 160+192 * Add splitkv d160 * cleanup * fix * Add change log * Fix CHANGELOG * Use static_cast * Update ignored instance --------- Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
@@ -282,18 +282,19 @@ class FmhaFwdApiPool:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
hdim = trait.hdim, trait.bn1
|
||||
if hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
self.pool[trait.dtype][hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][(hdim, hdim_v)]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
@@ -306,7 +307,7 @@ class FmhaFwdApiPool:
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
@@ -435,18 +436,20 @@ class FmhaFwdKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(32, 32) : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(64, 64) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (96, 128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (160,160) : FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
(192,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (192,192) : FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(64,64 ) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -454,7 +457,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
def get_pipelines(dtype, hdim, hdim_v) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
@@ -463,7 +466,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
@@ -507,15 +510,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
for ((hdim, hdim_v), tile), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
for pipeline in get_pipelines(dtype, hdim, hdim_v):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
if (hdim, hdim_v) == (192, 128) or hdim == 160:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
|
||||
@@ -34,6 +34,7 @@ K0_MAX_SUBMAX_MAP = {
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
# 160: 160,
|
||||
256: 256
|
||||
}
|
||||
|
||||
@@ -638,6 +639,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -656,6 +658,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -683,7 +686,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
# TODO: use async pipeline when compiler is more stable
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128, 160]:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
Reference in New Issue
Block a user