Merge commit 'bfac64953fd4a91d1f37a473d5849e38a9ce6852' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-18 08:16:15 +00:00
parent ba29aebebd
commit a84d4d52bd
4 changed files with 154 additions and 28 deletions

View File

@@ -211,11 +211,10 @@ float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream
const bool can_dispatch_v3 =
(device_name.compare(0, 6, "gfx950") == 0) and
(traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and
(traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and
(not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and
(not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and
(args.hdim_v == 128);
traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
(not traits.has_lse) and (not traits.has_dropout) and
(traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
(args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
if ({F_is_v3_enabled} and can_dispatch_v3) {{
return fmha_fwd_v3(traits, args, config);
}} else {{
@@ -1082,9 +1081,9 @@ class KernelComponentFactoryGfx950(
# qr_async_trload_v3 only supports hdim=hdim_v=128 for now
if (hdim, hdim_v) == (128, 128):
# qr_async_trload_v3 only supports (generic) causal mask
for mask in ["no", "causal"]:
for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
return pipelines

View File

@@ -728,6 +728,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
@@ -758,6 +759,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,