[CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API (#3153)

* Let fmha_fwd_v3() compatible with fmha_fwd()

* Decouple get_fwd_blobs() and FmhaFwdKernel

* Decouple compatibility checks from get_fwd_blobs()

* Extract product feature checks out from get_fwd_blobs()

* Remove duplicated code in factories and redundant checks

* Remove FmhaFwdKernel<>::GetName()

* Let FmhaFwdApiPool support pipelines with different mask_impl

* Add tile setting for fmha fwd v3 pipeline

* Add fwd v3 instances to tile_example_fmha_fwd manually

* Remove unused function import

* Undo irrelevant changes

* Remove fwd v3 instances from tile_example_fmha_fwd

* Finish fmha fwd v3 kernel instance codegen

* Fix formatting

* Remove unused F_idx attribute

* Add is_generic_attention_mask<> traits

* Add constraints to the fmha fwd v3 pipeline

* Unify traits & problem used for fmha fwd v3

* Unify kernel launch code for fmha fwd v2 & v3

* Unify kernel template selection logic

* Use same kernel codegen template for both v2 & v3

* Rename api() property as render() method

* Allow specifying filter for fmha fwd api pool

* Allow specifying function name when rendering api pool items

* Separate fmha fwd v3 kernel dispatching logic from v2

* Remove lambda assignment

* Add simple v2/v3 dispatch logic

* Stop generating empty if-clauses

Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them.

* Use "".join() to concatenate fmha fwd api string content

* Add more feature checks for fmha fwd v3 pipeline

* Check features before dispatch to fmha_fwd_v3()

* Add more feature checks for fmha_fwd_v3()

* Add missing filter call

* Use Tuple to reserve the dtype orders

* Fix wrong pipeline matching logic

* Add fmha fwd v3 group mode instances

* Add functor_transform<>

* Add type constraints to make_tile_window()

* Remove fmha fwd v3 example

* Fix wrong product(aiter mha_fwd()) config

* Fix wrong fmha fwd v2/v3 selection logic

* Fix formatting

* Add comment to warning v3 kernel users

* Fix wrong codegen logics

* Remove unnecessary param

* Fix format

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Po Yen Chen
2025-12-05 10:31:12 +08:00
committed by GitHub
parent d1193e8637
commit 05292b3604
22 changed files with 890 additions and 1449 deletions

View File

@@ -30,16 +30,24 @@ _MASK_MAP = {
}
def get_mask_map(mask: str):
if mask == "generic":
def get_mask_map(mask_impl: str):
if mask_impl == "generic":
return _MASK_MAP
elif mask == "simplified":
elif mask_impl == "simplified":
return _MASK_SIMPLIFIED_MAP
else:
assert False
return None
def get_mask_impl(mask: str) -> str:
return "simplified" if mask.startswith("s_") else "generic"
def get_mask_cpp_type(mask: str) -> str:
return get_mask_map(get_mask_impl(mask))[mask]
_MASK_CHECK_MAP = {
"no": "t.mask_type == mask_enum::no_mask",
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
@@ -62,6 +70,10 @@ def get_mask_check_map(mask: str):
return None
def get_mask_cpp_check_expr(mask: str) -> str:
return get_mask_check_map(get_mask_impl(mask))[mask]
QSCALE_MAP = {
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
@@ -122,6 +134,7 @@ PIPELINE_MAP = {
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
"qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline",
}
PIPELINE_ENUM_MAP = {
@@ -131,6 +144,7 @@ PIPELINE_ENUM_MAP = {
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
"qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
}
BOOL_MAP = {

File diff suppressed because it is too large Load Diff