From 4a6638adcfef6993f0c84713f9c8c9db61e4d557 Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Wed, 4 Mar 2026 09:50:05 +0500 Subject: [PATCH] [CK_TILE][FMHA] Extend pipelines with pssk for gfx11/12 (#4957) ## Motivation Build pipelines that pad only the seqlen dimensions, so that loads along the hdim dimension can stay vectorized. The existing pipelines have either all dims padded or no dims padded. These pipelines can be used in ComfyUI for slightly better performance. ## Technical Details Also includes a fix for the FLOPS calculation in `tile_example_fmha_fwd`, which was incorrect when `seqlen_q * seqlen_k` overflows the capacity of `index_t` (signed int32). ## Test Plan The existing test cases will use the new pipelines when parameters allow (seqlens padded, hdims not padded): ``` ninja test_ck_tile_fmha_fwd bin/test_ck_tile_fmha_fwd_fp16 bin/test_ck_tile_fmha_fwd_bf16 bin/test_ck_tile_fmha_fwd_fp8bf16 # for gfx12 ``` ## Test Result All tests must pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 3 +++ example/ck_tile/01_fmha/mask.hpp | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 1d0d04df77..e888cbd383 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1134,6 +1134,7 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): ["t", "f"], ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip return pipelines @@ -1192,6 +1193,7 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): ["t", "f"], ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels @@ -1202,6 +1204,7 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): ["no"], ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", 
qscale, mask, "f", "f", "f")) # fmt: skip return pipelines diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index bcaa7d596d..03e1537c5d 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -178,11 +178,11 @@ struct mask_info return tmp; } - ck_tile::index_t get_unmaskarea() const + std::size_t get_unmaskarea() const { if(type == mask_enum::no_mask) - return seqlen_q * seqlen_k; - ck_tile::index_t area = 0; + return static_cast<std::size_t>(seqlen_q) * seqlen_k; + std::size_t area = 0; for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y) { ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));