From 4016b27f5aa8026399c1945e09484c97bf344faf Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 18 Dec 2025 00:16:54 +0800 Subject: [PATCH] Fix FMHA fp8 hdim=64 incorrect result in MI200 (#3423) * Fix incorrect result in hdim=64 * Add change log [ROCm/composable_kernel commit: 292f87aa03a97be56082be95ab593160c3910629] --- CHANGELOG.md | 3 ++- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a69ce2260e..b502bfaf3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel. * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. * Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM. * Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM @@ -91,7 +92,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Optimized * Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. -* Added Vectorize Transpose optimization for CK Tile +* Added Vectorize Transpose optimization for CK Tile * Added the asynchronous copy for gfx950 ### Changed diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 4d6900a802..d157a165fc 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1014,8 +1014,12 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): ["no"], ["f", "t"], ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip + if hdim == 64: + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip + else: + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip elif dtype in ["fp8", "fp8fp16", "bf8"]: # TODO pass