add mask support in hdim=192/128 (#1999)

This commit is contained in:
carlushuang
2025-03-21 18:28:43 +08:00
committed by GitHub
parent 5a0d693b86
commit 6c08c5c46d

View File

@@ -492,7 +492,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't' or (pipeline.F_mask not in ['no', 's_no']):
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,