Fix FMHA fp8 hdim=64 incorrect result in MI200 (#3423)

* Fix incorrect result in hdim=64

* Add change log

[ROCm/composable_kernel commit: 292f87aa03]
This commit is contained in:
rocking
2025-12-18 00:16:54 +08:00
committed by GitHub
parent 2de39368c2
commit 97b2015929
2 changed files with 8 additions and 3 deletions

View File

@@ -17,6 +17,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
## Composable Kernel 1.2.0 for ROCm 7.2.0
### Added
* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.
* Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM
@@ -91,7 +92,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
### Optimized
* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout.
* Added Vectorize Transpose optimization for CK Tile
* Added Vectorize Transpose optimization for CK Tile
* Added the asynchronous copy for gfx950
### Changed

View File

@@ -1014,8 +1014,12 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
["no"],
["f", "t"],
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
if hdim == 64:
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
else:
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
elif dtype in ["fp8", "fp8fp16", "bf8"]:
# TODO
pass