diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index c251460a9a..89fbcff40c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -527,6 +527,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= bias in ['no', 'bias'] cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] cond &= dpad == dvpad + cond &= mode == 'batch' cond &= deterministic == "f" if not cond: continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index ee74cb8fb2..06a012d277 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -629,7 +629,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'bias'] cond &= pipeline.F_squant == 'f' + cond &= mode == 'batch' cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' if not cond: continue # Aiter(mha_fwd) integration diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index dc7ef712e2..517e84f380 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -332,6 +332,12 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_vlayout == 'row' if not cond: continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 2d2d71555d..edc1532a05 100644 ---
a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -754,6 +754,15 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_squant == 'f' if not cond: continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= mode == 'batch' + if not cond: + continue # Aiter(mha_varlen_fwd) integration elif receipt == 200: cond = dtype in ['fp16', 'bf16']