mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[rocm-libraries] ROCm/rocm-libraries#6051 (commit f0838b2)
[CK] Add FP8 per-tensor quantization support for FMHA V3 pipeline (#6051) ## Motivation The existing FMHA V3 pipeline only supports fp16/bf16 data types. This PR extends V3 to handle FP8 inputs with per-tensor descaling on gfx950, enabling higher throughput for FP8 inference workloads using the assembly-optimized V3 code path. ## Technical Details **Warp GEMM:** - Add FP8 32x32x32 warp gemm with C-transposed distribution (`WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed`) and dispatcher entries **V3 Kernel (`fmha_fwd_v3_kernel.hpp`):** - Add per-tensor descale support for Q, K, V tensors, passing descale pointers through to pipeline kargs **V3 Pipeline (`block_fmha_fwd_v3_pipeline.hpp`):** - Add FP8 data path with dtype-aware type selection - Add asm volatile P matrix conversion from f32 to fp8 - Add FP8-aware instruction scheduling in `CoreLoopScheduler` **V3 Pipeline Policy (`block_fmha_fwd_v3_pipeline_default_policy.hpp`):** - Add FP8 QK warp gemm selection (SwizzleB variant for V tile distribution compatibility) **Codegen (`fmha_fwd.py`):** - Add gfx950 FP8BF16 V3 tile size (256x64x128x128x64x128) - Add FP8BF16 V3 pipeline variants (mask: no/causal, qscale: no/pertensor) - Extend `can_dispatch_v3` condition for fp8bf16 + pertensor **Misc:** - Add LLVM scheduler `TRANS` mask to `LLVMSchedGroupMask` enum (`arch.hpp`) - Fix `mask_info` default initialization for `no_mask` case (`mask.hpp`) V3 dispatch for FP8 is disabled by default (`F_is_v3_enabled=false`) pending further validation. 
## Performance: fmha_fwd V3 FP8 (avg runs 2-6, stock ROCm 7.1.1, gfx950) | Problem | Regular (TFlops) | Varlen (TFlops) | |---|---:|---:| | batch=1 heads=6/1 seqlen=1024 causal | 48.9 | 47.6 | | batch=1 heads=6/1 seqlen=2048 causal | 119.8 | 117.4 | | batch=1 heads=6/1 seqlen=4096 causal | 263.7 | 259.2 | | batch=1 heads=6/1 seqlen=8192 causal | 548.9 | 543.6 | | batch=1 heads=6/1 seqlen=16384 causal | 1043.0 | 1063.7 | | batch=1 heads=6/1 seqlen=32768 causal | 1237.2 | 1279.6 | | batch=1 heads=6/1 seqlen=65536 causal | 1315.4 | 1382.7 | | batch=1 heads=6/1 seqlen=131072 causal | 1326.3 | 1402.2 | | batch=1 heads=16/1 seqlen=65536 causal | 1298.7 | 1388.4 | | batch=1 heads=40/40 seqlen=37200 non-causal | 1248.9 | 1326.1 | ## Test Plan Tested with aiter's `test_mha_fp8.py` test suite (176 cases) covering batch sizes (1-2), sequence lengths (113-4096), head counts (5/8/32/40), GQA ratios (1:1, 1:8), and causal/non-causal modes. Verified all cases dispatch to the V3 pipeline by enabling `F_is_v3_enabled` and confirming kernel names contain `qr_async_trload_v3`. ## Test Result 176/176 tests passed with V3 enabled. All cases correctly dispatched to V3 pipeline with `pertensor` quantization. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
020b6f435e
commit
c2ac7aa7b0
@@ -206,22 +206,14 @@ float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fw
"""

FMHA_FWD_API_FOOTER_TEMPLATE = """
float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{
const std::string device_name = ck_tile::get_device_name();

const bool is_swa = (traits.mask_type != mask_enum::no_mask) and
((0 < args.window_size_left) or (0 < args.window_size_right));
const bool can_dispatch_v3 =
(device_name.compare(0, 6, "gfx950") == 0) and
(traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
(not traits.has_lse) and (not traits.has_dropout) and
(traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
(args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
if ({F_is_v3_enabled} and can_dispatch_v3) {{
return fmha_fwd_v3(traits, args, config);
}} else {{
return fmha_fwd_v2(traits, args, config);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunreachable-code"
if ({F_is_v3_enabled}) {{
float r = fmha_fwd_v3(traits, args, config);
if (r >= 0) return r;
}}
#pragma clang diagnostic pop
return fmha_fwd_v2(traits, args, config);
}}
"""

@@ -1059,10 +1051,11 @@ class KernelComponentFactoryGfx950(
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
if dtype in cls._DT_FP16_BF16:
# add tile for qr_async_trload_v3
if (128, 128) in result.keys():
result[(128, 128)].append(
FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
# # add tile for qr_async_trload_v3 (bf16/fp16 V3 not ready)
# if (128, 128) in result.keys():
#     result[(128, 128)].append(
#         FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
pass
elif dtype in cls._DT_MXFP8:
return {
# bm0, bn0, bk0, bn1, bk1,
@@ -1075,6 +1068,10 @@ class KernelComponentFactoryGfx950(
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
} # fmt: skip
elif dtype in cls._DT_FP8BF16:
if (128, 128) in result.keys():
result[(128, 128)].append(
FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1)) # fmt: skip
return result

@classmethod
@@ -1105,12 +1102,19 @@ class KernelComponentFactoryGfx950(
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip

# qr_async_trload_v3 only supports hdim=hdim_v=128 for now
if (hdim, hdim_v) == (128, 128):
# qr_async_trload_v3 only supports (generic) causal mask
for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
# # qr_async_trload_v3 bf16/fp16 not ready
# if (hdim, hdim_v) == (128, 128):
#     for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
#         pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
#             F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
elif dtype in cls._DT_FP8BF16:
# qr_async_trload_v3 only supports (generic) causal mask
for logits, qscale, mask in itertools.product(
["t", "f"],
["no", "pertensor"],
["no", "causal"],
):
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip

elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4:
# no need dropout kernels
@@ -1494,8 +1498,8 @@ def write_fwd_api(
FMHA_FWD_API_FOOTER_TEMPLATE.format(
F_is_v3_enabled=BOOL_MAP[
# NOTE: enable v3 pipelines when ready
# 0 < api_pool.get_num_traits(filter_fn=accept_only_v3)
False
0 < api_pool.get_num_traits(filter_fn=accept_only_v3)
# False
]
),
]

Reference in New Issue
Block a user