From b6916590c56205154a5963d30cc8f866aebca7f7 Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 20 Feb 2025 04:27:01 +0800 Subject: [PATCH] only output the deterministic bwd kernel for aiter (#1903) * only output the deterministic kernel * Add comment [ROCm/composable_kernel commit: e4358c01d96f53af94713a1c488dbecb4bcbc4d4] --- example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 8 ++++++-- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 4 ++++ example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py | 2 ++ example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | 3 +++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index c56399b346..4c23250d05 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -492,6 +492,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue + # Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'alibi'] @@ -506,6 +507,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= deterministic == "f" if not cond: continue + # PyTorch integration elif receipt == 4: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'bias'] @@ -514,22 +516,24 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= deterministic == "f" if not cond: continue + # Aiter (mha_bwd) integration elif receipt == 10: cond = dtype in ['fp16', 'bf16'] cond &= mode == "batch" cond &= bias in ['no', 'alibi'] cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] cond &= dpad == dvpad - cond &= deterministic == "f" + cond &= deterministic == "t" if not cond: continue + # Aiter (mha_varlen_bwd) integration elif receipt == 11: cond = dtype in ['fp16', 'bf16'] cond &= mode == "group" cond &= bias in ['no', 'alibi'] cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] cond &= dpad == dvpad - cond &= deterministic == "f" + cond &= deterministic == "t" if not cond: continue api_pool.register_dq_dk_dv_traits(k.api_trait()) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c1d8f9a309..b72627ed5d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -487,6 +487,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue + # 2 - Flash attention integration if receipt in (2, 3): cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' @@ -494,6 +495,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= pipeline.F_squant == 'f' if not cond: continue + # PyTorch integration elif receipt == 4: cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' @@ -501,6 +503,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= pipeline.F_squant == 'f' if not cond: continue + # Aiter(mha_fwd) integration elif receipt == 10: cond = dtype in ['fp16', 'bf16'] cond &= mode == "batch" @@ -509,6 +512,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= pipeline.F_squant == 'f' if not cond: continue + # Aiter(mha_varlen_fwd) integration elif receipt == 11: cond = dtype in ['fp16', 'bf16'] cond &= mode == "group" diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 405140d160..f8a89448ba 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -326,6 +326,8 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue + # 2 - Flash attention integration + # 12 - Aiter(mha_fwd_kvcache) integration if receipt in (2, 12): cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index cac75f302b..c0ca666b11 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -705,6 +705,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue + # Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' @@ -712,6 +713,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_squant == 'f' if not cond: continue + # Aiter(mha_varlen_fwd) integration elif receipt == 11: cond = dtype in ['fp16', 'bf16'] cond &= mode == "group" @@ -720,6 +722,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_squant == 'f' if not cond: continue + # Aiter(mha_fwd_kvcache) integration elif receipt == 12: cond = dtype in ['fp16', 'bf16'] cond &= mode == "batch"