diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
index d6f158116e..9a528e4a3e 100644
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
@@ -1017,6 +1017,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
                 pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
                 pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
             else:
+                pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
                 pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
                 pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
             if receipt == 1 and bias != "bias":
diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
index acc0f46fa9..8b580ed921 100644
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
@@ -127,7 +127,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
 namespace {{
 template
 void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
-    if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
+    if constexpr (({F_hdim} == 64 || {F_hdim} == 128) && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
         && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask>
             || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
         if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
@@ -280,7 +280,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
 """

 FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
-        ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
+        ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (a.block_table_ptr == nullptr || a.page_block_size % {F_bn0} == 0)) {{
     using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

     // get combine kernel tile sizes
@@ -820,17 +820,18 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
     def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
         if dtype in ["fp16", "bf16"]:
             return {
-                "32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-                "64" : FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-                "96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-                "128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-                # "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-                "256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
+                "32" : [FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                "64" : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
+                        FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                "96" : [FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                "128": [FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                # "160" : [FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                "256": [FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
             } # fmt: skip
         elif dtype in ["fp8", "bf8"]:
             return {
-                "64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
-                "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
+                "64" : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
+                "128": [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
             } # fmt: skip
         else:
             return None
@@ -861,16 +862,17 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
         if dtype in ["fp16", "bf16"]:
             return {
                 #                        bm0, bn0, bk0, bn1, bk1,
-                "32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-                "64" : FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-                "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-                "256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
+                "32" : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                "64" : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
+                        FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                "128": [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                "256": [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
             } # fmt: skip
         elif dtype in ["fp8", "bf8"]:
             return {
                 #                       bm0, bn0, bk0, bn1, bk1,
-                "64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
-                "128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
+                "64" : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
+                "128": [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
             } # fmt: skip
         else:
             return None
@@ -905,77 +907,78 @@ def get_fwd_splitkv_blobs(
             continue
         # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
         for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
-            tile = d[hdim_str]
+            tiles = d[hdim_str]
+            if not isinstance(tiles, list):
+                tiles = [tiles]
             hdim = int(hdim_str)
-            for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
-                if mode == "group":
-                    if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
-                        # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
-                        continue
-                # logits_soft_cap is only allowed if no bias
-                if not (
-                    (pipeline.F_logits == "t" and pipeline.F_bias == "no")
-                    or pipeline.F_logits == "f"
-                ):
-                    continue
-                k = Kernel(
-                    F_arch=factory.arch,
-                    F_idx=0,
-                    F_hdim=hdim,
-                    F_dtype=dtype,
-                    F_mode=mode,
-                    F_tile=tile,
-                    F_pipeline=pipeline,
-                    mask_impl=mask_impl,
-                )
-                if kernel_filter != "":
-                    if not fnmatch.fnmatch(k.name, kernel_filter):
-                        continue
-                if optdim_list != [-1]:
-                    if hdim not in optdim_list:
-                        continue
-                # Flash attention integration
-                if receipt == 2:
-                    cond = dtype in ["fp16", "bf16"]
-                    cond &= pipeline.F_vlayout == "row"
-                    cond &= pipeline.F_bias in ["no", "alibi"]
-                    cond &= pipeline.F_squant == "f"
-                    cond &= pipeline.F_sink == "f"
-                    if not cond:
-                        continue
-                # PyTorch integration
-                elif receipt == 4:
-                    cond = dtype in ["fp16, bf16"]
-                    cond &= pipeline.F_vlayout == "row"
-                    cond &= pipeline.F_bias in ["no", "bias"]
-                    cond &= pipeline.F_squant == "f"
-                    cond &= mode == "batch"
-                    cond &= pipeline.F_sink == "f"
-                    if not cond:
-                        continue
-                # Aiter(mha_varlen_fwd) integration
-                elif receipt == 200:
-                    cond = dtype in ["fp16", "bf16"]
-                    cond &= mode == "group"
-                    cond &= pipeline.F_vlayout == "row"
-                    cond &= pipeline.F_squant == "f"
-                    if not cond:
-                        continue
-                # aiter::mha_fwd_splikv C++ api integration
-                elif receipt == 600:
-                    cond = dtype in ["fp16", "bf16"]
-                    cond &= pipeline.F_vlayout == "row"
-                    cond &= pipeline.F_squant == "f"
-                    if not cond:
+            for tile in tiles:
+                for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
+                    if mode == "group":
+                        if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
+                            # in group mode, spad/skpad must be true, since we can't predict if the seqlen of the current batch needs padding or not
+                            continue
+                    # logits_soft_cap is only allowed if no bias
+                    if not (
+                        (pipeline.F_logits == "t" and pipeline.F_bias == "no")
+                        or pipeline.F_logits == "f"
+                    ):
                         continue
+                    k = Kernel(
+                        F_arch=factory.arch,
+                        F_idx=0,
+                        F_hdim=hdim,
+                        F_dtype=dtype,
+                        F_mode=mode,
+                        F_tile=tile,
+                        F_pipeline=pipeline,
+                        mask_impl=mask_impl,
+                    )
+                    if kernel_filter != "":
+                        if not fnmatch.fnmatch(k.name, kernel_filter):
+                            continue
+                    if optdim_list != [-1]:
+                        if hdim not in optdim_list:
+                            continue
+                    # Flash attention integration
+                    if receipt == 2:
+                        cond = dtype in ["fp16", "bf16"]
+                        cond &= pipeline.F_vlayout == "row"
+                        cond &= pipeline.F_bias in ["no", "alibi"]
+                        cond &= pipeline.F_squant == "f"
+                        cond &= pipeline.F_sink == "f"
+                        if not cond:
+                            continue
+                    # PyTorch integration
+                    elif receipt == 4:
+                        cond = dtype in ["fp16", "bf16"]
+                        cond &= pipeline.F_vlayout == "row"
+                        cond &= pipeline.F_bias in ["no", "bias"]
+                        cond &= pipeline.F_squant == "f"
+                        cond &= mode == "batch"
+                        cond &= pipeline.F_sink == "f"
+                        if not cond:
+                            continue
+                    # Aiter(mha_varlen_fwd) integration
+                    elif receipt == 200:
+                        cond = dtype in ["fp16", "bf16"]
+                        cond &= mode == "group"
+                        cond &= pipeline.F_vlayout == "row"
+                        cond &= pipeline.F_squant == "f"
+                        if not cond:
+                            continue
+                    # aiter::mha_fwd_splikv C++ api integration
+                    elif receipt == 600:
+                        cond = dtype in ["fp16", "bf16"]
+                        cond &= pipeline.F_vlayout == "row"
+                        cond &= pipeline.F_squant == "f"
+                        if not cond:
+                            continue

-                # fp32 only
-                if receipt == 800 or receipt == 801:
-                    cond = dtype == "fp32"
-                    if not cond:
-                        continue
+                    # fp32 only
+                    if receipt == 800 or receipt == 801:
+                        cond = dtype == "fp32"
+                        if not cond:
+                            continue

-                gen.append(k)
+                    gen.append(k)
     return gen
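
In short, this patch does two things. In the codegen scripts, get_hdim_tile_size_dict now maps each head-dim key to a list of FmhaFwdTileSize entries instead of a single one (hdim 64 gains a second bn0 variant), and get_fwd_splitkv_blobs normalizes scalar entries and iterates "for tile in tiles", so one head dim can emit several kernel instances. Below is a minimal, self-contained sketch of that consumption pattern; TileSize, HDIM_TILES, and gen_kernel_names are hypothetical stand-ins for the real FmhaFwdTileSize, the factory dicts, and Kernel naming, and the name format is invented for illustration:

    from collections import namedtuple

    # Hypothetical, trimmed stand-in for FmhaFwdTileSize; the real class carries many more fields.
    TileSize = namedtuple("TileSize", ["bm0", "bn0", "bk0", "bn1", "bk1"])

    # A value may be a single TileSize (legacy form) or a list of variants,
    # mirroring the new shape of get_hdim_tile_size_dict.
    HDIM_TILES = {
        "32": TileSize(32, 64, 16, 32, 32),
        "64": [TileSize(64, 64, 32, 64, 32),
               TileSize(64, 32, 32, 64, 32)],
    }

    def gen_kernel_names():
        names = []
        for hdim_str, tiles in HDIM_TILES.items():
            if not isinstance(tiles, list):  # normalize scalars, as the patch does
                tiles = [tiles]
            for tile in tiles:               # one kernel instance per tile variant
                names.append(f"fmha_fwd_splitkv_hdim{hdim_str}_bn0_{tile.bn0}")
        return names

    print(gen_kernel_names())

Second, in the generated C++ dispatch, an instance is now selected for a paged KV-cache only when a.page_block_size % {F_bn0} == 0, i.e. the page block size must be a multiple of that instance's bn0 (KV-direction) tile length; calls without a block table (a.block_table_ptr == nullptr) are unaffected. Presumably this guard is also what keeps the two hdim-64 bn0 variants from both matching the same paged-KV call at dispatch time.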