mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
update generated filenames
This commit is contained in:
@@ -815,7 +815,9 @@ class FmhaBwdDQDKDVTileSize:
|
||||
F_occupancy : int # occupancy
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}"
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
|
||||
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}"
|
||||
|
||||
@dataclass
|
||||
class FmhaBwdDQDKDVKernel:
|
||||
@@ -879,19 +881,24 @@ class FmhaBwdDQDKDVKernel:
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def mask_name() -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
# TODO: we don't encode idx here
|
||||
mn = mask_name()
|
||||
n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name +\
|
||||
f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_skpad][0]}{BOOL_MAP[self.F_dpad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\
|
||||
(f'_{self.F_bias}' if self.F_bias != 'no' else '') + f"_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}"
|
||||
if mn != '' : n += f'{mn}'
|
||||
pn = pad_name()
|
||||
n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
if self.F_dbias == 't' : n += '_dbias'
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
return n
|
||||
|
||||
@property
|
||||
@@ -1025,10 +1032,16 @@ class FmhaBwdOGradDotOKernel:
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}" +\
|
||||
f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\
|
||||
f"_o{self.F_occupancy}"
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}"
|
||||
if pn != '' : n += f'_{pn}'
|
||||
return n
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
|
||||
Reference in New Issue
Block a user