mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Fix FMHA fp8 hdim=64 incorrect result in MI200 (#3423)
* Fix incorrect result in hdim=64
* Add change log
[ROCm/composable_kernel commit: 292f87aa03]
This commit is contained in:
@@ -17,6 +17,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
## Composable Kernel 1.2.0 for ROCm 7.2.0
|
||||
|
||||
### Added
|
||||
* Added support for fp8 dynamic tensor-wise quantization of the fp8 FMHA forward kernel.
|
||||
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
|
||||
* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.
|
||||
* Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM.
|
||||
@@ -91,7 +92,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
### Optimized
|
||||
|
||||
* Optimized the GEMM multiply-multiply preshuffle and LDS bypass with Pack of KGroup and a better instruction layout.
|
||||
* Added Vectorize Transpose optimization for CK Tile
|
||||
* Added Vectorize Transpose optimization for CK Tile
|
||||
* Added the asynchronous copy for gfx950
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -1014,8 +1014,12 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
["no"],
|
||||
["f", "t"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
if hdim == 64:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
elif dtype in ["fp8", "fp8fp16", "bf8"]:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user