explicit show no feature in kernel name (#1920)

This commit is contained in:
rocking
2025-02-28 14:23:30 +08:00
committed by GitHub
parent a9bcd3c98d
commit faa2235dad
3 changed files with 48 additions and 33 deletions

View File

@@ -413,20 +413,26 @@ class FmhaBwdDQDKDVKernel:
pn = pad_name()
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}'
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' :
n += f'_{self.F_bias}'
else:
n += '_nbias'
else: n += '_npad'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_dbias == 't' : n += '_dbias'
else: n += '_ndbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_dropout != 'no' :
n += f'_{self.F_dropout}'
else:
n += '_ndropout'
else: n += '_nmask'
if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
else: n += '_ndropout'
if self.F_deterministic == 't' : n += '_deterministic'
else: n += '_ndeterministic'
return n
@property
@@ -635,6 +641,7 @@ class FmhaBwdOGradDotOKernel:
pn = pad_name()
n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}"
if pn != '' : n += f'_{pn}'
else: n += '_npad'
return n
@property
@@ -784,7 +791,9 @@ class FmhaBwdConvertQGradKernel:
pn = pad_name()
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
if pn != '' : n += f'_{pn}'
if self.F_deterministic == 't' : n += f'_deterministic'
else: n += '_npad'
if self.F_deterministic == 't' : n += '_deterministic'
else: n += '_ndeterministic'
return n
@property

View File

@@ -233,23 +233,26 @@ class FmhaFwdPipeline:
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' :
n += f'_{self.F_bias}'
else:
n += '_nbias'
else: n += '_npad'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' :
n += '_lse'
else:
n += '_nlse'
if self.F_dropout == 't' :
n += '_dropout'
else:
n += '_ndropout'
else: n += '_nmask'
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_dropout == 't' : n += '_dropout'
else: n += '_ndropout'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
return n
class FmhaFwdApiPool:

View File

@@ -397,23 +397,26 @@ class FmhaFwdSplitKVPipeline:
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_bias != 'no' :
n += f'_{self.F_bias}'
else:
n += '_nbias'
else: n += '_npad'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' :
n += '_lse'
else:
n += '_nlse'
else: n += '_nmask'
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_squant == 't' : n += '_squant'
if self.F_pagedkv == 't' :
n += '_pagedkv'
else:
n += '_npagedkv'
else: n += '_nsquant'
if self.F_pagedkv == 't' : n += '_pagedkv'
else: n += '_npagedkv'
return n
@dataclass