diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index c05660c8ab..c56399b346 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -499,14 +499,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= dpad == dvpad if not cond: continue - if receipt == 3: + elif receipt == 3: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'alibi'] cond &= dpad == dvpad cond &= deterministic == "f" if not cond: continue - if receipt == 4: + elif receipt == 4: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'bias'] cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] @@ -514,6 +514,24 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= deterministic == "f" if not cond: continue + elif receipt == 10: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue + elif receipt == 11: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index ad8daba17e..c1d8f9a309 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -494,13 +494,29 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= pipeline.F_squant == 'f' if not cond: continue - if receipt == 4: + elif receipt == 4: cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'bias'] cond &= pipeline.F_squant == 'f' if not cond: continue + elif receipt == 10: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + elif receipt == 11: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 2f20819302..405140d160 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -326,7 +326,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue - if receipt == 2: + if receipt in (2, 12): cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' if not cond: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 37745dd382..cac75f302b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -268,7 +268,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - + // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; @@ -712,6 +712,22 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_squant == 'f' if not cond: continue + elif receipt == 11: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + elif receipt == 12: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index a0fb42aa11..0c2cef1ce7 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -17,7 +17,7 @@ class HandlerId(IntEnum): LIST_BLOBS = 0 WRITE_BLOBS = 1 -# inspect all modules under 'codegen.ops' and register API handlers +# inspect all modules under 'codegen.ops' and register API handlers ops = [] for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) @@ -104,7 +104,11 @@ if __name__ == "__main__": help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ " 1: generate more instance to cover all hdim\n" + \ " 2: Only generate instance for Flash attention integration\n" + \ - " 4: Only generate instance for PyTorch integration" + " 4: Only generate instance for PyTorch integration\n" + \ + " 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration\n" + \ + " 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration\n" + \ + " 12: Only generate instance for Aiter(mha_fwd_kvcache) integration" + ) args = parser.parse_args()