[CK_TILE][FMHA] Add logits soft-capping support for FAv3 (WIP) (#3355)

* Make fmha_fwd_v3() compatible with fmha_fwd()

* Decouple get_fwd_blobs() and FmhaFwdKernel

* Decouple compatibility checks from get_fwd_blobs()

* Extract product feature checks from get_fwd_blobs()

* Remove duplicated code in factories and redundant checks

* Remove FmhaFwdKernel<>::GetName()

* Let FmhaFwdApiPool support pipelines with different mask_impl

* Add tile setting for fmha fwd v3 pipeline

* Add fwd v3 instances to tile_example_fmha_fwd manually

* Remove unused function import

* Undo irrelevant changes

* Remove fwd v3 instances from tile_example_fmha_fwd

* Finish fmha fwd v3 kernel instance codegen

* Fix formatting

* Remove unused F_idx attribute

* Add is_generic_attention_mask<> traits

* Add constraints to the fmha fwd v3 pipeline

* Unify traits & problem used for fmha fwd v3

* Unify kernel launch code for fmha fwd v2 & v3

* Unify kernel template selection logic

* Use same kernel codegen template for both v2 & v3

* Rename api() property as render() method

* Allow specifying filter for fmha fwd api pool

* Allow specifying function name when rendering api pool items

* Separate fmha fwd v3 kernel dispatching logic from v2

* Remove lambda assignment

* Add simple v2/v3 dispatch logic

* Stop generating empty if-clauses

Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them.

* Use "".join() to concatenate fmha fwd api string content

* Add more feature checks for fmha fwd v3 pipeline

* Check features before dispatch to fmha_fwd_v3()

* Add more feature checks for fmha_fwd_v3()

* Add missing filter call

* Use Tuple to preserve the dtype order

* Fix wrong pipeline matching logic

* Add fmha fwd v3 group mode instances

* Add functor_transform<>

* Add type constraints to make_tile_window()

* Remove fmha fwd v3 example

* Fix wrong product (aiter mha_fwd()) config

* Fix wrong fmha fwd v2/v3 selection logic

* Fix formatting

* Add comment to warn v3 kernel users

* Fix wrong codegen logic

* Remove unnecessary param

* Fix format

* Add logits soft-capping support for fmha fwd v3 pipeline (WIP)

* Add missing Kargs base type

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Po Yen Chen
2025-12-18 16:08:45 +08:00
committed by GitHub
parent bb8445dca8
commit bfac64953f
4 changed files with 154 additions and 28 deletions
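
For readers unfamiliar with the feature named in the commit title, the following is a minimal reference sketch of logits soft-capping as it is commonly defined: each raw attention logit `x` is replaced by `cap * tanh(x / cap)` before the softmax. The helper names `soft_cap` and `softmax` are illustrative only, and the exact formulation used inside the v3 kernel is not shown in this diff, so treat this as an assumption about the general technique rather than a description of the kernel.

```python
import math


def soft_cap(logit: float, cap: float) -> float:
    """Squash a raw attention logit smoothly into the open interval (-cap, cap)."""
    if cap <= 0.0:
        # Assumption for this sketch: a non-positive cap means "feature disabled".
        return logit
    return cap * math.tanh(logit / cap)


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


raw_logits = [0.5, 10.0, 80.0]
capped = [soft_cap(x, cap=30.0) for x in raw_logits]
probs = softmax(capped)
```

Small logits pass through almost unchanged, while large ones saturate just below the cap, which bounds how sharply the softmax can peak.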


@@ -211,11 +211,10 @@ float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream
     const bool can_dispatch_v3 =
         (device_name.compare(0, 6, "gfx950") == 0) and
         (traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
-        traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and
-        (traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and
-        (not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and
-        (not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and
-        (args.hdim_v == 128);
+        traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
+        (not traits.has_lse) and (not traits.has_dropout) and
+        (traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
+        (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
     if ({F_is_v3_enabled} and can_dispatch_v3) {{
         return fmha_fwd_v3(traits, args, config);
     }} else {{
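
The net effect of this hunk is that a request with logits soft-capping enabled is no longer excluded from the v3 path. A minimal Python sketch of the resulting predicate follows. The `Request` dataclass and its boolean fields (`has_bias`, `has_qscale`) are hypothetical simplifications of `fmha_fwd_traits`, `fmha_fwd_args`, and the enum comparisons in the generated C++, not part of the codebase.

```python
from dataclasses import dataclass


@dataclass
class Request:
    """Hypothetical flattened view of the traits/args consulted by the dispatcher."""
    device_name: str = "gfx950"
    data_type: str = "fp16"
    is_v_rowmajor: bool = True
    has_logits_soft_cap: bool = False
    has_bias: bool = False      # stands in for bias_type != bias_enum::no_bias
    has_lse: bool = False
    has_dropout: bool = False
    has_qscale: bool = False    # stands in for qscale_type != quant_scale_enum::no_scale
    is_swa: bool = False
    nhead_q: int = 32
    nhead_k: int = 8
    hdim_q: int = 128
    hdim_v: int = 128


def can_dispatch_v3(r: Request) -> bool:
    """Mirror of the post-change condition; has_logits_soft_cap is deliberately not checked."""
    return (
        r.device_name.startswith("gfx950")
        and r.data_type in ("fp16", "bf16")
        and r.is_v_rowmajor
        and not r.has_bias
        and not r.has_lse
        and not r.has_dropout
        and not r.has_qscale
        and not r.is_swa
        and r.nhead_q % r.nhead_k == 0
        and r.hdim_q == 128
        and r.hdim_v == 128
    )
```

Under these assumptions, `can_dispatch_v3(Request(has_logits_soft_cap=True))` is true after the change, whereas the removed condition would have rejected it.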
@@ -1082,9 +1081,9 @@ class KernelComponentFactoryGfx950(
 # qr_async_trload_v3 only supports hdim=hdim_v=128 for now
 if (hdim, hdim_v) == (128, 128):
     # qr_async_trload_v3 only supports (generic) causal mask
-    for mask in ["no", "causal"]:
+    for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
         pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
-            F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
+            F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
 return pipelines
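
To illustrate what the switch to `itertools.product` does to the set of generated instances, here is a small self-contained sketch. `PipelineSpec` and `v3_pipelines` are hypothetical stand-ins for `FmhaFwdPipeline` and the factory method; only the three fields relevant to this hunk are modelled.

```python
import itertools
from typing import List, NamedTuple


class PipelineSpec(NamedTuple):
    """Hypothetical, reduced stand-in for FmhaFwdPipeline."""
    tag: str
    logits: str  # "t" if logits soft-capping is compiled in, "f" otherwise
    mask: str    # "no" or "causal"


def v3_pipelines(hdim: int, hdim_v: int) -> List[PipelineSpec]:
    specs: List[PipelineSpec] = []
    # As in the diff, v3 instances are only generated for hdim == hdim_v == 128.
    if (hdim, hdim_v) == (128, 128):
        # Cartesian product: every soft-cap setting paired with every mask type.
        for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
            specs.append(PipelineSpec("qr_async_trload_v3", logits, mask))
    return specs


specs = v3_pipelines(128, 128)
```

In this sketch the loop yields four variants (two soft-cap settings times two mask types) where the previous single loop over masks yielded two, all with `F_logits="f"`.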