[CK_TILE] FMHA BWD Pad HDim to a Multiple of 8 (#2918)

This commit is contained in:
Yi DING
2025-09-26 16:42:59 +08:00
committed by GitHub
parent 518d24e662
commit 32773fe5cb
12 changed files with 110 additions and 88 deletions

View File

@@ -50,16 +50,10 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx}
fmha_warp_tile2_{F_idx},
{F_maxq}>;
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<false, /* kPadSeqLenQ */
false, /* kPadSeqLenK */
{F_dpad},
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaBwdTraits<{F_dpad},
{F_dvpad},
false,
{F_bias},
{F_dbias},
false,
false,
false,
{F_occupancy}>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_dropout_{F_idx} = {F_dropout};
@@ -94,19 +88,19 @@ using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
false,
{F_dpad}>>;
({F_dpad} > 0)>>;
using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
false,
{F_dvpad}>>;
({F_dvpad} > 0)>>;
using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType,
false,
{F_dpad}>>;
({F_dpad} > 0)>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
@@ -220,9 +214,9 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0)
FMHA_BWD_API_INNER_DISPATCH="""
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_convert_dq_bn0}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
return r;
}}
@@ -278,8 +272,8 @@ class FmhaBwdDQDKDVKernel:
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaBwdDQDKDVTileSize
F_dpad : str #
F_dvpad : str #
F_dpad : Literal[0, 8 ,1]
F_dvpad : Literal[0, 8 ,1]
F_bias : str #
F_dbias : str #
F_dropout : str #
@@ -320,8 +314,8 @@ class FmhaBwdDQDKDVKernel:
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_dpad = BOOL_MAP[self.F_dpad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_dpad = self.F_dpad,
F_dvpad = self.F_dvpad,
F_bias = BIAS_MAP[self.F_bias],
F_dbias = BOOL_MAP[self.F_dbias],
F_dropout = DROPOUT_MAP[self.F_dropout],
@@ -337,8 +331,8 @@ class FmhaBwdDQDKDVKernel:
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if self.F_dpad : n += f'd{self.F_dpad}'
if self.F_dvpad : n += f'dv{self.F_dvpad}'
if n != '' : n = 'p' + n
return n
pn = pad_name()
@@ -622,8 +616,8 @@ class FmhaBwdApiTrait:
dbias : str
dropout : str
spad1d : str # spad for 1d kernels (dot/convert)
dpad : str
dvpad : str
dpad : Literal[0, 1, 8]
dvpad : Literal[0, 1, 8]
deterministic : str
mask_impl : str
tr_load : str
@@ -652,13 +646,13 @@ class FmhaBwdApiTrait:
@property
def dcheck(self) -> str:
if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0'
else : return f'a.hdim_q % {self.bhdq} == 0'
if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0'
else: return f'a.hdim_q % {self.dpad} == 0'
@property
def dvcheck(self) -> str:
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
else : return f'a.hdim_v % {self.bhdv} == 0'
if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0'
else: return f'a.hdim_v % {self.dvpad} == 0'
@property
def extra_cond(self) -> str:
@@ -678,8 +672,9 @@ class FmhaBwdApiTrait:
def get_occupancy(dtype, hdim):
return 2
F_dvpad = 't' if self.dvpad else 'f'
return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d,
F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim))
F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim))
@property
def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
@@ -694,8 +689,9 @@ class FmhaBwdApiTrait:
def get_occupancy(dtype, hdim):
return 2
F_dpad = 't' if self.dpad else 'f'
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad,
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
@@ -721,7 +717,7 @@ class FmhaBwdApiPool:
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad,
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond,
F_convert_dq_bn0=trait.convert_dq_bn0)
@@ -794,7 +790,10 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
dpad_options = itertools.product(*([[0, 8, 1]] * 2))
tf = ["t", "f"]
for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product(
tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf):
assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize"
hdim = tile.F_bhdq
if (mode == "group") and (spad1d == "f"):
@@ -805,8 +804,12 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
continue
if ("wg32" in dropout):
continue
if tr_load == "t" and (dpad == "t" or dvpad == "t"):
if tr_load == "t":
continue # tr_load cannot work with dpad or dvpad
else: # tr_load == "f"
# do not generate instance with only 1 of dpad/dvpad being 8
if dpad != dvpad and dpad == 8:
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue