[CK_TILE] FMHA BWD Optimization For GFX950 (#2628)

* simplify fmha_bwd_kernel MakeKargs & dq_dram_window

* simply duplicate

* trload pipeline

* Try two-stage

* add prefetch

* optimize & iglp
Author: Yi DING
Committed: 2025-08-12 11:11:55 +08:00 (committed by GitHub)
Parent: a7badc6ec5
Commit: 4fde1646e5
16 changed files with 2216 additions and 586 deletions


@@ -83,6 +83,7 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
 {F_deterministic},
 fmha_mask_{F_idx},
 fmha_dropout_{F_idx},
+{F_trload},
 fmha_bwd_trait_{F_idx}>;
 using fmha_bwd_pipeline_{F_idx} = ck_tile::BlockFmhaBwdDQDKDVPipeline<fmha_bwd_pipeline_problem_{F_idx}>;
@@ -113,7 +114,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
 {F_dbias},
 {F_dpad},
 {F_dvpad},
-{F_deterministic}>;
+{F_deterministic},
+{F_trload}>;
 #include <iostream>
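The two hunks above thread a new transpose-load flag through the generated pipeline problem and the dq/dk/dv trait alias. As a rough illustration of what one rendered instance could look like with the flag enabled (a sketch only: the index, flag values, and the elided leading parameters are placeholders, not the real argument list of `ck_tile::BlockFmhaBwdPipelineProblem`):

```cpp
// Hypothetical rendering for one kernel variant (idx 0, trload on).
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
    /* dtype / tile-shape / padding parameters elided */
    false,           // {F_deterministic}
    fmha_mask_0,     // {F_mask} instantiation
    fmha_dropout_0,  // {F_dropout} instantiation
    true,            // {F_trload}: take the transpose-load data path
    fmha_bwd_trait_0>;

using fmha_bwd_pipeline_0 =
    ck_tile::BlockFmhaBwdDQDKDVPipeline<fmha_bwd_pipeline_problem_0>;
```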
@@ -168,29 +170,35 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
 template <>
 float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
+    const bool has_load_tr = ck_tile::is_load_tr_supported();
     float r = -1;
     {F_dispatch}
     return r;
 }}
 """
-FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
-{F_hdim_case}
+FMHA_BWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{
+{F_body}
 }}
 """
-FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
-{F_inner_dispatch}
-}}
+FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
+{F_body}
+}}
+"""
+FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
+{F_body}
 }}
 """
-FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
-    ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
-    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
-    using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}>;
-    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
-    r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
-    return r;
-}}
+FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
+    ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
+    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
+    using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}>;
+    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
+    r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
+    return r;
+}}
 """
 # M0 size for 1d kernels (dot/convert)
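With these templates, `fmha_bwd<2>` now queries transpose-load support once at runtime and wraps the whole dtype/hdim dispatch in one extra level of nesting. A minimal sketch of what the rendered entry point could expand to, assuming a bf16, hdim-128 trload build (branch bodies abbreviated; this is illustrative, not the verbatim generated file):

```cpp
// Sketch of the generated dispatcher.
template <>
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s)
{
    // Evaluated once per call; gates all trload kernel variants.
    const bool has_load_tr = ck_tile::is_load_tr_supported();
    float r = -1;
    if(has_load_tr){                                  // FMHA_BWD_API_PER_TRLOAD, tr_load == "t"
        if(t.data_type.compare("bf16") == 0){         // FMHA_BWD_API_PER_DTYPE
            if (t.hdim_q <= 128 && t.hdim_v <= 128) { // FMHA_BWD_API_PER_HDIM_CASE
                // FMHA_BWD_API_INNER_DISPATCH: trait checks, then
                // r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
                // return r;
            }
        }
    }
    if(true){                                         // FMHA_BWD_API_PER_TRLOAD, tr_load == "f"
        // non-trload dtype/hdim dispatch, generated the same way
    }
    return r;
}
```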
@@ -250,6 +258,7 @@ class FmhaBwdDQDKDVKernel:
     F_mode : str # value from MODE_MAP
     F_deterministic : str #
     mask_impl : str #
+    F_trload : str #
 
     @property
     def template(self) -> str:
@@ -291,6 +300,7 @@ class FmhaBwdDQDKDVKernel:
             F_mask = get_mask_map(self.mask_impl)[self.F_mask],
             F_mode = MODE_MAP[self.F_mode],
             F_deterministic = BOOL_MAP[self.F_deterministic],
+            F_trload = BOOL_MAP[self.F_trload],
         )
 
     @property
@@ -324,6 +334,9 @@ class FmhaBwdDQDKDVKernel:
         if self.F_deterministic == 't' : n += '_deterministic'
         else: n += '_ndeterministic'
+        if self.F_trload == 't' : n += '_trload'
+        else: n += '_ntrload'
         return n
 
     @property
@@ -332,8 +345,8 @@ class FmhaBwdDQDKDVKernel:
 # TODO: design a more practical way to do it
 # this is current supported tile size.
-def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
-    if dtype == 'fp16' or dtype == 'bf16':
+def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str, tr_load: str) -> Optional[dict]:
+    if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
         return {
         '32' : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
         '64' : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
@@ -341,6 +354,10 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
         # '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
         '256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
         }
+    elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't':
+        return {
+        '128' : FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
+        }
     else:
         return None
@@ -573,6 +590,7 @@ class FmhaBwdApiTrait:
     dvpad : str
     deterministic : str
     mask_impl : str
+    tr_load : str # 't' or 'f', used as a BOOL_MAP key and pool key
 
     @property
     def bm0(self) -> int:
@@ -620,7 +638,7 @@ class FmhaBwdApiTrait:
     def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
         return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile,
             F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout,
-            F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl)
+            F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load)
 
     @property
     def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel:
@@ -636,12 +654,13 @@ class FmhaBwdApiTrait:
 class FmhaBwdApiPool:
     def __init__(self, mask_impl):
-        self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(list))
+        self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
         self.mask_impl = mask_impl
 
     def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
         # TODO: do we need to check duplication?
-        self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
+        self.dq_dk_dv_pool[trait.tr_load][trait.dtype][trait.hdim].append(copy.copy(trait))
 
     @staticmethod
     def if_(i: int) -> str:
@@ -656,24 +675,31 @@ class FmhaBwdApiPool:
                 F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
                 F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
                 F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
-                F_deterministic=BOOL_MAP[trait.deterministic])
+                F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load])
             i += 1
         return inners
     @property
     def api(self) -> str:
-        per_dtypes=str()
-        for i, dtype in enumerate(self.dq_dk_dv_pool):
-            per_hdim_case=str()
-            for j, hdim in enumerate(self.dq_dk_dv_pool[dtype]):
-                traits=self.dq_dk_dv_pool[dtype][hdim]
-                inners = self._api_innders(traits)
-                per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(j), F_hdim=hdim, F_inner_dispatch=inners)
-            per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(i), F_dtype=dtype, F_hdim_case=per_hdim_case)
-        if not per_dtypes:
+        tr_load_cond_map = {
+            "t": "has_load_tr",
+            "f": "true"
+        }
+        per_tr_load = ''
+        for tr_load in ["t", "f"]:
+            per_dtypes = ''
+            for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load]):
+                per_hdim_case = ''
+                for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][dtype]):
+                    traits = self.dq_dk_dv_pool[tr_load][dtype][hdim]
+                    inners = self._api_innders(traits)
+                    per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(k), F_hdim=hdim, F_body=inners)
+                per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(j), F_dtype=dtype, F_body=per_hdim_case)
+            per_tr_load += FMHA_BWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_body=per_dtypes)
+        if not per_tr_load:
             # empty string we add some ignore to suppress warning in api
-            per_dtypes += ' (void)t ; (void)s ; (void)a;'
-        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
+            per_tr_load += ' (void)t ; (void)s ; (void)a;'
+        return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
 
 def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]:
     if filter_list == '':
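Emission order matters here: the loop renders the `"t"` pool first, guarded by the runtime `has_load_tr` flag via `tr_load_cond_map`, and the `"f"` pool second under a constant `true` guard. Since every matched FMHA_BWD_API_INNER_DISPATCH body ends in `return r;`, the second branch only runs when no trload kernel was selected. Schematically (a sketch of the generated guards, not verbatim output):

```cpp
// Top-level structure emitted by FMHA_BWD_API_PER_TRLOAD, in loop order:
if(has_load_tr){
    // dispatch for kernels registered under dq_dk_dv_pool["t"]
    // (currently bf16/fp16 at hdim 128, no head-dim padding)
}
if(true){
    // dispatch for kernels registered under dq_dk_dv_pool["f"];
    // reached only if nothing above returned
}
```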
@@ -690,8 +716,8 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
     gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {}
     api_pool = FmhaBwdApiPool(mask_impl)
 
-    for dtype in BWD_DTYPE_MAP.keys():
-        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
+    for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
+        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype, tr_load)
         if d is None:
             continue
         for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
@@ -703,7 +729,9 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
                 continue
             if ("wg32" in dropout):
                 continue
-            t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl)
+            if tr_load == "t" and (dpad == "t" or dvpad == "t"):
+                continue # tr_load cannot work with dpad or dvpad
+            t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load)
             if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
                 continue