diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index d6f158116e..a5fffb5159 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1304,7 +1304,23 @@ class Product: def get_product(receipt: int) -> Product: # Flash attention integration - if receipt in (2, 3): + if receipt == 2: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_sink == "f" + # FlashAttention direct fwd wrappers always use softcap disabled and LSE enabled. + cond &= kernel_ctx.pipeline.F_logits == "f" + cond &= kernel_ctx.pipeline.F_lse == "t" + return cond + + return Product(name="Flash attention integration", rule=fit) + # Receipt 3 forward coverage used by CK library / smoke tests + elif receipt == 3: def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: cond = problem_ctx.dtype in ["fp16", "bf16"] diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index e0ccde8a6b..c9bac50da1 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -939,6 +939,8 @@ def get_fwd_splitkv_blobs( cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_bias in ["no", "alibi"] + # FlashAttention splitkv paths use softcap-disabled kernels only. + cond &= pipeline.F_logits == "f" cond &= pipeline.F_squant == "f" cond &= pipeline.F_sink == "f" if not cond: @@ -1142,4 +1144,7 @@ def list_blobs( ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") - f.write((file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix() + "\n") + f.write( + (file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix() + + "\n" + )