From 4fbab6cbee6a95588b244af76b45324964f39023 Mon Sep 17 00:00:00 2001 From: rocking Date: Fri, 28 Feb 2025 14:23:30 +0800 Subject: [PATCH] explicit show no feature in kernel name (#1920) [ROCm/composable_kernel commit: faa2235dad16a32934fb3290baf997555585da70] --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 27 ++++++++++++------- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 27 ++++++++++--------- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 27 ++++++++++--------- 3 files changed, 48 insertions(+), 33 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 8082523f1b..6326a97f8e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -413,20 +413,26 @@ class FmhaBwdDQDKDVKernel: pn = pad_name() n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}' if pn != '' : n += f'_{pn}' - if self.F_bias != 'no' : - n += f'_{self.F_bias}' - else: - n += '_nbias' + else: n += '_npad' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + if self.F_dbias == 't' : n += '_dbias' + else: n += '_ndbias' + if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_dropout != 'no' : - n += f'_{self.F_dropout}' - else: - n += '_ndropout' + else: n += '_nmask' + + if self.F_dropout != 'no' : n += f'_{self.F_dropout}' + else: n += '_ndropout' + if self.F_deterministic == 't' : n += '_deterministic' + else: n += '_ndeterministic' return n @property @@ -635,6 +641,7 @@ class FmhaBwdOGradDotOKernel: pn = pad_name() n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" if pn != '' : n += f'_{pn}' + else: n += '_npad' return n @property @@ -784,7 +791,9 @@ class FmhaBwdConvertQGradKernel: pn = pad_name() n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}" if pn != '' : n += f'_{pn}' - if self.F_deterministic == 't' : n += f'_deterministic' + else: n += '_npad' + if self.F_deterministic == 't' : n += '_deterministic' + else: n += '_ndeterministic' return n @property diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 79ace6d2c3..f2d9216696 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -233,23 +233,26 @@ class FmhaFwdPipeline: pn = pad_name() n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' - if self.F_bias != 'no' : - n += f'_{self.F_bias}' - else: - n += '_nbias' + else: n += '_npad' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_lse == 't' : - n += '_lse' - else: - n += '_nlse' - if self.F_dropout == 't' : - n += '_dropout' - else: - n += '_ndropout' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_dropout == 't' : n += '_dropout' + else: n += '_ndropout' + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' return n class FmhaFwdApiPool: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index b4eea36e86..ba555df88d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -397,23 +397,26 @@ class FmhaFwdSplitKVPipeline: pn = pad_name() n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' - if self.F_bias != 'no' : - n += f'_{self.F_bias}' - else: - n += '_nbias' + else: n += '_npad' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_lse == 't' : - n += '_lse' - else: - n += '_nlse' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + if self.F_squant == 't' : n += '_squant' - if self.F_pagedkv == 't' : - n += '_pagedkv' - else: - n += '_npagedkv' + else: n += '_nsquant' + + if self.F_pagedkv == 't' : n += '_pagedkv' + else: n += '_npagedkv' return n @dataclass