mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
[CK_TILE][FMHA] Add logits soft-capping support for FAv3 (WIP) (#3355)
* Let fmha_fwd_v3() compatible with fmha_fwd()
* Decouple get_fwd_blobs() and FmhaFwdKernel
* Decouple compatibility checks from get_fwd_blobs()
* Extract product feature checks out from get_fwd_blobs()
* Remove duplicated code in factories and redundant checks
* Remove FmhaFwdKernel<>::GetName()
* Let FmhaFwdApiPool support pipelines with different mask_impl
* Add tile setting for fmha fwd v3 pipeline
* Add fwd v3 instances to tile_example_fmha_fwd manually
* Remove unused function import
* Undo irrelevant changes
* Remove fwd v3 instances from tile_example_fmha_fwd
* Finish fmha fwd v3 kernel instance codegen
* Fix formatting
* Remove unused F_idx attribute
* Add is_generic_attention_mask<> traits
* Add constraints to the fmha fwd v3 pipeline
* Unify traits & problem used for fmha fwd v3
* Unify kernel launch code for fmha fwd v2 & v3
* Unify kernel template selection logic
* Use same kernel codegen template for both v2 & v3
* Rename api() property as render() method
* Allow specifying filter for fmha fwd api pool
* Allow specifying function name when rendering api pool items
* Separate fmha fwd v3 kernel dispatching logic from v2
* Remove lambda assignment
* Add simple v2/v3 dispatch logic
* Stop generating empty if-clauses
Skip iterating over dictionaries that have no traits, and avoid assigning i_* to them.
* Use "".join() to concatenate fmha fwd api string content
* Add more feature checks for fmha fwd v3 pipeline
* Check features before dispatch to fmha_fwd_v3()
* Add more feature checks for fmha_fwd_v3()
* Add missing filter call
* Use Tuple to reserve the dtype orders
* Fix wrong pipeline matching logic
* Add fmha fwd v3 group mode instances
* Add functor_transform<>
* Add type constraints to make_tile_window()
* Remove fmha fwd v3 example
* Fix wrong product(aiter mha_fwd()) config
* Fix wrong fmha fwd v2/v3 selection logic
* Fix formatting
* Add comment to warning v3 kernel users
* Fix wrong codegen logics
* Remove unnecessary param
* Fix format
* Add logits soft-capping support for fmha fwd v3 pipeline (WIP)
* Add missing Kargs base type
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
[ROCm/composable_kernel commit: bfac64953f]
This commit is contained in:
@@ -211,11 +211,10 @@ float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream
|
||||
const bool can_dispatch_v3 =
|
||||
(device_name.compare(0, 6, "gfx950") == 0) and
|
||||
(traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
|
||||
traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and
|
||||
(traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and
|
||||
(not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and
|
||||
(not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and
|
||||
(args.hdim_v == 128);
|
||||
traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
|
||||
(not traits.has_lse) and (not traits.has_dropout) and
|
||||
(traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
|
||||
(args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
|
||||
if ({F_is_v3_enabled} and can_dispatch_v3) {{
|
||||
return fmha_fwd_v3(traits, args, config);
|
||||
}} else {{
|
||||
@@ -1082,9 +1081,9 @@ class KernelComponentFactoryGfx950(
|
||||
# qr_async_trload_v3 only supports hdim=hdim_v=128 for now
|
||||
if (hdim, hdim_v) == (128, 128):
|
||||
# qr_async_trload_v3 only supports (generic) causal mask
|
||||
for mask in ["no", "causal"]:
|
||||
for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
|
||||
F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
|
||||
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
|
||||
|
||||
return pipelines
|
||||
|
||||
|
||||
@@ -728,6 +728,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
@@ -758,6 +759,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
|
||||
Reference in New Issue
Block a user