mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK] Add FP8 per-tensor quantization support for FMHA V3 pipeline (#6051)

## Motivation

The existing FMHA V3 pipeline only supports fp16/bf16 data types. This PR extends V3 to handle FP8 inputs with per-tensor descaling on gfx950, enabling higher throughput for FP8 inference workloads using the assembly-optimized V3 code path.

## Technical Details

**Warp GEMM:**
- Add FP8 32x32x32 warp gemm with C-transposed distribution (`WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed`) and dispatcher entries

**V3 Kernel (`fmha_fwd_v3_kernel.hpp`):**
- Add per-tensor descale support for Q, K, V tensors, passing descale pointers through to pipeline kargs

**V3 Pipeline (`block_fmha_fwd_v3_pipeline.hpp`):**
- Add FP8 data path with dtype-aware type selection
- Add `asm volatile` P matrix conversion from f32 to fp8
- Add FP8-aware instruction scheduling in `CoreLoopScheduler`

**V3 Pipeline Policy (`block_fmha_fwd_v3_pipeline_default_policy.hpp`):**
- Add FP8 QK warp gemm selection (SwizzleB variant for V tile distribution compatibility)

**Codegen (`fmha_fwd.py`):**
- Add gfx950 FP8BF16 V3 tile size (256x64x128x128x64x128)
- Add FP8BF16 V3 pipeline variants (mask: no/causal, qscale: no/pertensor)
- Extend `can_dispatch_v3` condition for fp8bf16 + pertensor

**Misc:**
- Add LLVM scheduler `TRANS` mask to `LLVMSchedGroupMask` enum (`arch.hpp`)
- Fix `mask_info` default initialization for `no_mask` case (`mask.hpp`)

V3 dispatch for FP8 is disabled by default (`F_is_v3_enabled=false`) pending further validation.
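
For readers unfamiliar with per-tensor descaling, the following is a minimal reference sketch of the arithmetic only, not the kernel's implementation. It assumes each of Q, K, V carries a single scalar descale factor (`dq`, `dk`, `dv`, names chosen here for illustration) and shows that folding `dq * dk` into the softmax scale and applying `dv` to the output gives the same result as dequantizing the inputs first. It works on plain Python floats and does not model FP8 rounding or the f32-to-fp8 conversion of P.

```python
import math

def matmul(a, b):
    # Plain row-by-column matrix product on nested lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(r) for r in zip(*m)]

def scale_mat(m, f):
    return [[v * f for v in row] for row in m]

def softmax_rows(s):
    # Numerically stable row-wise softmax.
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(v - m) for v in row]
        z = sum(e)
        out.append([v / z for v in e])
    return out

def attention_dequant_first(q8, k8, v8, dq, dk, dv, sm_scale):
    # Reference path: dequantize every tensor, then run ordinary attention.
    q, k, v = scale_mat(q8, dq), scale_mat(k8, dk), scale_mat(v8, dv)
    s = scale_mat(matmul(q, transpose(k)), sm_scale)
    return matmul(softmax_rows(s), v)

def attention_folded(q8, k8, v8, dq, dk, dv, sm_scale):
    # Folded path (an assumption about how descales can be applied):
    # dq * dk merges into the softmax scale, dv is applied once to the output.
    s = scale_mat(matmul(q8, transpose(k8)), sm_scale * dq * dk)
    return scale_mat(matmul(softmax_rows(s), v8), dv)

# Stand-in "quantized" values; real inputs would be FP8-encoded.
q8 = [[1.0, -2.0], [0.5, 3.0]]
k8 = [[2.0, 1.0], [-1.0, 0.5]]
v8 = [[4.0, 0.0], [1.0, -3.0]]
ref = attention_dequant_first(q8, k8, v8, 0.5, 0.25, 2.0, 1.0 / math.sqrt(2))
folded = attention_folded(q8, k8, v8, 0.5, 0.25, 2.0, 1.0 / math.sqrt(2))
```

The folded form is attractive because the GEMMs run directly on the quantized operands and the descale costs only two scalar multiplies per tile; whether the V3 pipeline applies the factors at exactly these points is not stated in this PR description.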

## Performance: fmha_fwd V3 FP8 (avg runs 2-6, stock ROCm 7.1.1, gfx950)

| Problem | Regular (TFlops) | Varlen (TFlops) |
|---|---:|---:|
| batch=1 heads=6/1 seqlen=1024 causal | 48.9 | 47.6 |
| batch=1 heads=6/1 seqlen=2048 causal | 119.8 | 117.4 |
| batch=1 heads=6/1 seqlen=4096 causal | 263.7 | 259.2 |
| batch=1 heads=6/1 seqlen=8192 causal | 548.9 | 543.6 |
| batch=1 heads=6/1 seqlen=16384 causal | 1043.0 | 1063.7 |
| batch=1 heads=6/1 seqlen=32768 causal | 1237.2 | 1279.6 |
| batch=1 heads=6/1 seqlen=65536 causal | 1315.4 | 1382.7 |
| batch=1 heads=6/1 seqlen=131072 causal | 1326.3 | 1402.2 |
| batch=1 heads=16/1 seqlen=65536 causal | 1298.7 | 1388.4 |
| batch=1 heads=40/40 seqlen=37200 non-causal | 1248.9 | 1326.1 |

## Test Plan

Tested with aiter's `test_mha_fp8.py` test suite (176 cases) covering batch sizes (1-2), sequence lengths (113-4096), head counts (5/8/32/40), GQA ratios (1:1, 1:8), and causal/non-causal modes. Verified all cases dispatch to the V3 pipeline by enabling `F_is_v3_enabled` and confirming kernel names contain `qr_async_trload_v3`.

## Test Result

176/176 tests passed with V3 enabled. All cases correctly dispatched to V3 pipeline with `pertensor` quantization.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
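
The dispatch check described in the Test Plan can be sketched as a simple substring filter. This is an illustrative helper, not part of aiter or CK: the function name `non_v3_kernels` and the sample kernel names are hypothetical, and only the `qr_async_trload_v3` tag comes from the PR text.

```python
V3_TAG = "qr_async_trload_v3"  # substring quoted in the Test Plan

def non_v3_kernels(kernel_names):
    """Return the kernel names that do not carry the V3 pipeline tag."""
    return [name for name in kernel_names if V3_TAG not in name]

# Hypothetical names for illustration only; real names come from the
# test run's kernel log.
observed = [
    "fmha_fwd_example_qr_async_trload_v3_pertensor",
    "fmha_fwd_example_qr_async_pertensor",
]
offenders = non_v3_kernels(observed)
```

An empty `offenders` list would correspond to the "all cases dispatched to V3" result reported above.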