diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 1d0d04df77..e888cbd383 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1134,6 +1134,7 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory): ["t", "f"], ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip return pipelines @@ -1192,6 +1193,7 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): ["t", "f"], ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels @@ -1202,6 +1204,7 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory): ["no"], ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip return pipelines diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index bcaa7d596d..03e1537c5d 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -178,11 +178,11 @@ struct mask_info return tmp; } - ck_tile::index_t get_unmaskarea() const + std::size_t get_unmaskarea() const { if(type == mask_enum::no_mask) - return seqlen_q * seqlen_k; - ck_tile::index_t area = 0; + return static_cast(seqlen_q) * seqlen_k; + std::size_t area = 0; for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y) { ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast(0));