diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 9aac846ff9..3cd4d2759e 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -815,7 +815,9 @@ class FmhaBwdDQDKDVTileSize: F_occupancy : int # occupancy @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}" + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ + f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" @dataclass class FmhaBwdDQDKDVKernel: @@ -879,19 +881,24 @@ class FmhaBwdDQDKDVKernel: @property def name(self) -> str: - def mask_name() -> str: + def pad_name() -> str: n = '' - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n return n - # TODO: we don't encode idx here - mn = mask_name() - n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name +\ - f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_skpad][0]}{BOOL_MAP[self.F_dpad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\ - (f'_{self.F_bias}' if self.F_bias != 'no' else '') + f"_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}" - if mn != '' : n += f'{mn}' + pn = pad_name() + n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_dbias == 't' : n += '_dbias' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_dropout == 't' : n += '_dropout' return n @property @@ -1025,10 +1032,16 @@ class FmhaBwdOGradDotOKernel: @property def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}" +\ - f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\ - f"_o{self.F_occupancy}" + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" + if pn != '' : n += f'_{pn}' + return n @property def filename(self) -> str: