From 85ec3c75848e31d3a3dd032bc86590f3d0ba6e48 Mon Sep 17 00:00:00 2001 From: rocking Date: Wed, 24 Jul 2024 12:38:31 +0000 Subject: [PATCH] Do not store storerandval in bwd for flash attention integration --- example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index d4e778877e..43a2c9fcfe 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -279,8 +279,8 @@ class FmhaBwdApiPool: continue if (spad1 == "t" and trait.spad == "f" and hdim_int <= 64): continue - inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], @@ -490,6 +490,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] if not cond: continue api_pool.register_dq_dk_dv_traits(k.api_trait()) @@ -693,7 +694,7 @@ class FmhaBwdConvertQGradKernel: F_spad : str # true/false F_dpad : str # F_mode : str # value from MODE_MAP - F_occupancy : int # + F_occupancy : int # F_deterministic : str # @property