Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-12 17:26:00 +00:00
[CK] Add FP8 per-tensor quantization support for FMHA V3 pipeline (#6051)

## Motivation

The existing FMHA V3 pipeline only supports fp16/bf16 data types. This PR extends V3 to handle FP8 inputs with per-tensor descaling on gfx950, enabling higher throughput for FP8 inference workloads using the assembly-optimized V3 code path.

## Technical Details

**Warp GEMM:**
- Add FP8 32x32x32 warp gemm with C-transposed distribution (`WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed`) and dispatcher entries

**V3 Kernel (`fmha_fwd_v3_kernel.hpp`):**
- Add per-tensor descale support for Q, K, V tensors, passing descale pointers through to pipeline kargs

**V3 Pipeline (`block_fmha_fwd_v3_pipeline.hpp`):**
- Add FP8 data path with dtype-aware type selection
- Add asm volatile P matrix conversion from f32 to fp8
- Add FP8-aware instruction scheduling in `CoreLoopScheduler`

**V3 Pipeline Policy (`block_fmha_fwd_v3_pipeline_default_policy.hpp`):**
- Add FP8 QK warp gemm selection (SwizzleB variant for V tile distribution compatibility)

**Codegen (`fmha_fwd.py`):**
- Add gfx950 FP8BF16 V3 tile size (256x64x128x128x64x128)
- Add FP8BF16 V3 pipeline variants (mask: no/causal, qscale: no/pertensor)
- Extend `can_dispatch_v3` condition for fp8bf16 + pertensor

**Misc:**
- Add LLVM scheduler `TRANS` mask to `LLVMSchedGroupMask` enum (`arch.hpp`)
- Fix `mask_info` default initialization for `no_mask` case (`mask.hpp`)

V3 dispatch for FP8 is disabled by default (`F_is_v3_enabled=false`) pending further validation.

## Performance: fmha_fwd V3 FP8 (avg runs 2-6, stock ROCm 7.1.1, gfx950)

| Problem | Regular (TFlops) | Varlen (TFlops) |
|---|---:|---:|
| batch=1 heads=6/1 seqlen=1024 causal | 48.9 | 47.6 |
| batch=1 heads=6/1 seqlen=2048 causal | 119.8 | 117.4 |
| batch=1 heads=6/1 seqlen=4096 causal | 263.7 | 259.2 |
| batch=1 heads=6/1 seqlen=8192 causal | 548.9 | 543.6 |
| batch=1 heads=6/1 seqlen=16384 causal | 1043.0 | 1063.7 |
| batch=1 heads=6/1 seqlen=32768 causal | 1237.2 | 1279.6 |
| batch=1 heads=6/1 seqlen=65536 causal | 1315.4 | 1382.7 |
| batch=1 heads=6/1 seqlen=131072 causal | 1326.3 | 1402.2 |
| batch=1 heads=16/1 seqlen=65536 causal | 1298.7 | 1388.4 |
| batch=1 heads=40/40 seqlen=37200 non-causal | 1248.9 | 1326.1 |

## Test Plan

Tested with aiter's `test_mha_fp8.py` test suite (176 cases) covering batch sizes (1-2), sequence lengths (113-4096), head counts (5/8/32/40), GQA ratios (1:1, 1:8), and causal/non-causal modes. Verified all cases dispatch to the V3 pipeline by enabling `F_is_v3_enabled` and confirming kernel names contain `qr_async_trload_v3`.

## Test Result

176/176 tests passed with V3 enabled. All cases correctly dispatched to V3 pipeline with `pertensor` quantization.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
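The per-tensor descaling described in this PR can be illustrated with a small pure-Python sketch. It is not the V3 pipeline's code. It emulates rounding to FP8 E4M3, assuming the OCP "e4m3fn" encoding with a largest finite value of 448. It then quantizes Q and K with one scale each and shows why a single scale per tensor is cheap: the scale is one scalar, so it factors out of every QK^T dot product and one multiply per score undoes it. The helper names (`fp8_e4m3_round`, `per_tensor_scale`, `fp8_scores`) are hypothetical. The sketch omits softmax, the P conversion, and the V descale.

```python
import math

# Assumption: OCP FP8 E4M3 ("e4m3fn"): 3 mantissa bits, largest finite value 448.
FP8_E4M3_MAX = 448.0


def fp8_e4m3_round(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating at +/-448."""
    if x == 0.0:
        return 0.0
    x = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x))
    _, exp = math.frexp(x)    # x = m * 2**exp with 0.5 <= |m| < 1
    e = max(exp - 1, -6)      # floor(log2|x|), floored at the smallest normal exponent
    quantum = 2.0 ** (e - 3)  # spacing of E4M3 values at this exponent
    return round(x / quantum) * quantum


def per_tensor_scale(t):
    """One scale for the whole tensor so that its largest |element| maps to 448."""
    amax = max(abs(v) for row in t for v in row)
    return amax / FP8_E4M3_MAX if amax > 0 else 1.0


def quantize(t, scale):
    """Divide by the per-tensor scale and round every element to E4M3."""
    return [[fp8_e4m3_round(v / scale) for v in row] for row in t]


def matmul_nt(a, b):
    """a @ b^T for row-major lists of lists."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


def fp8_scores(q, k):
    """Q K^T computed from FP8-quantized Q and K, then descaled per tensor."""
    q_descale, k_descale = per_tensor_scale(q), per_tensor_scale(k)
    s = matmul_nt(quantize(q, q_descale), quantize(k, k_descale))
    # Each descale is a single scalar, so it factors out of every dot product:
    # one multiply per score restores the original magnitude.
    return [[v * q_descale * k_descale for v in row] for row in s]


if __name__ == "__main__":
    q = [[1.0, 2.0], [3.0, 4.0]]
    k = [[0.5, 1.5], [2.5, 3.5]]
    print("exact:", matmul_nt(q, k))
    print("fp8 + per-tensor descale:", fp8_scores(q, k))
```

A finer-grained scheme, such as per-block scales, would need the scale applied inside the reduction. A per-tensor scale can be applied once to the accumulated scores.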
9.5 KiB
Changelog for Composable Kernel
Documentation for Composable Kernel is available at https://rocm.docs.amd.com/projects/composable_kernel/en/latest/.
(Unreleased) Composable Kernel 1.3.0
Added
- Added an overload of load_tile_transpose that takes a reference to the output tensor as an output parameter.
- Used the data type from the LDS tensor view when determining the tile distribution for transpose in the GEMM pipeline.
- Added eightwarps support for abquant mode in blockscale GEMM.
- Added preshuffleB support for abquant mode in blockscale GEMM.
- Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
- Added TF32 convolution support on gfx942 and gfx950 in CK. It can be enabled or disabled through the DTYPES CMake option value "tf32".
- Added StreamingLLM sink support for FMHA FWD, including the qr_ks_vs, qr_async, and splitkv pipelines.
- Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
- Added support for FP8 dynamic tensor-wise quantization to the FP8 FMHA forward kernel.
- Added FP8 KV cache support for FMHA batch prefill.
- Added support for gfx1153 target.
- Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
- Added gpt-oss sink support for FMHA FWD, including the qr_ks_vs, qr_async, qr_async_trload, and splitkv pipelines.
- Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
- Added FP8 block scale quantization for FMHA forward kernel.
- Added gfx11 support for FMHA.
- Added microscaling (MX) FP8/FP4 support on gfx950 for FMHA forward kernel ("qr" pipeline only).
- Added FP8 per-tensor quantization support for FMHA forward V3 pipeline on gfx950.
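For the TF32 convolution entry above: TF32 keeps FP32's 8-bit exponent but stores only 10 explicit mantissa bits. The Python sketch below uses a hypothetical helper, `round_to_tf32`. It emulates that precision loss by clearing the 13 low mantissa bits of an FP32 value with round-to-nearest-even. The rounding mode is an assumption, and the conversion that gfx942/gfx950 hardware or CK actually performs may differ. NaN, infinity, and overflow are not handled.

```python
import struct


def _f32_to_bits(x: float) -> int:
    """Bit pattern of x after rounding it to IEEE-754 single precision."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def _bits_to_f32(b: int) -> float:
    """Reinterpret a 32-bit pattern as a single-precision value."""
    return struct.unpack("<f", struct.pack("<I", b & 0xFFFFFFFF))[0]


def round_to_tf32(x: float) -> float:
    """Keep 10 of FP32's 23 mantissa bits, rounding to nearest even (assumed mode)."""
    b = _f32_to_bits(x)
    lsb = (b >> 13) & 1                 # lowest mantissa bit that survives
    b = (b + 0x0FFF + lsb) & ~0x1FFF    # round on the 13 dropped bits, then clear them
    return _bits_to_f32(b)


if __name__ == "__main__":
    for v in (1.0, 1.0 + 2**-10, 1.0 + 2**-11, 0.1):
        print(v, "->", round_to_tf32(v))
```

The relative rounding error is therefore at most 2^-11, compared with 2^-24 for FP32. That difference is the accuracy given up for TF32's higher matrix throughput.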
Changed
Upcoming changes
Composable Kernel 1.2.0 for ROCm 7.2.0
Added
- Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4
- Added the CK-Tile dispatcher: a unified kernel dispatch, code generation, and architecture-based kernel filtering system with C++ and Python frontends, starting with GEMM support.
- Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
- Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.
- Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM.
- Added a compute async pipeline in the CK Tile universal GEMM on gfx950.
- Added support for B tensor type pk_int4_t in the CK Tile weight preshuffle GEMM.
- Added a new API to load different memory sizes into SGPRs.
- Added support for B Tensor preshuffle in CK Tile grouped GEMM.
- Added a basic copy kernel example and supporting documentation for new CK Tile developers.
- Added support for grouped GEMM kernels to perform Multi D elementwise operation.
- Added support for multiple ABD GEMM.
- Added benchmarking support for tile engine GEMM Multi D.
- Added block scaling support in CK Tile GEMM, allowing flexible use of quantization matrices from either A or B operands.
- Added the row-wise column-wise quantization for CK Tile GEMM and CK Tile grouped GEMM.
- Added support for f32 to FMHA (forward and backward).
- Added tensor-wise quantization for CK Tile GEMM.
- Added support for batched contraction kernel.
- Added WMMA (gfx12) support for FMHA.
- Added a pooling kernel in CK_TILE.
- Added a top-k sigmoid kernel in CK_TILE.
- Added blockscale 2D support for CK_TILE GEMM.
- Added a Flatmm pipeline for microscaling (MX) FP8/FP4 data types.
- Added reduce and multi-reduction kernels.
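The row-wise/column-wise quantization entry for CK Tile GEMM rests on one idea. A gets one scale per row and B one scale per column. Both scales stay constant along the K reduction, so each output element needs only an integer dot product followed by one multiply by `a_scale[i] * b_scale[j]`. The plain-Python sketch below shows that idea. The int8 format, the symmetric absmax scaling, and the helper names are assumptions for illustration; this is not CK Tile's API.

```python
def quantize_int8(values, scale):
    """Symmetric int8 quantization: round(v / scale), clamped to [-127, 127]."""
    return [max(-127, min(127, round(v / scale))) for v in values]


def rowcol_quant_matmul(a, b):
    """Approximate A @ B using per-row scales for A and per-column scales for B."""
    k, n = len(b), len(b[0])
    b_cols = [[b[p][j] for p in range(k)] for j in range(n)]
    # One absmax scale per row of A and per column of B; 1.0 guards all-zero vectors.
    a_scale = [max(abs(v) for v in row) / 127 or 1.0 for row in a]
    b_scale = [max(abs(v) for v in col) / 127 or 1.0 for col in b_cols]
    a_q = [quantize_int8(row, s) for row, s in zip(a, a_scale)]
    b_q = [quantize_int8(col, s) for col, s in zip(b_cols, b_scale)]
    # Exact integer dot product over K, then a single rescale per output element.
    return [
        [sum(x * y for x, y in zip(a_q[i], b_q[j])) * a_scale[i] * b_scale[j] for j in range(n)]
        for i in range(len(a))
    ]


if __name__ == "__main__":
    a = [[1.0, 2.0], [3.0, 4.0]]
    b = [[1.0, 0.5], [2.0, -1.0]]
    print(rowcol_quant_matmul(a, b))
```

Tensor-wise quantization is the special case where every row of A shares one scale and every column of B shares one scale.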
Changed
- Removed BlockSize in make_kernel and CShuffleEpilogueProblem to support Wave32 in CK Tile (#2594).
- Added an optional template parameter Arch (gfx9_t, gfx12_t, etc.) to make_kernel to support linking multiple object files that have the same kernel compiled for different architectures.
- FMHA examples and tests can be built for multiple architectures (gfx9, gfx950, gfx12) at the same time.
Upcoming changes
- Composable Kernel will be adopting C++20 features in an upcoming ROCm release, updating the minimum compiler requirement to C++20. Ensure that your development environment complies with this requirement to facilitate a seamless transition.
Composable Kernel 1.1.0 for ROCm 7.1.1
Upcoming changes
- Composable Kernel will be adopting C++20 features in an upcoming ROCm release, updating the minimum compiler requirement to C++20. Ensure that your development environment complies with this requirement to facilitate a seamless transition.
Composable Kernel 1.1.0 for ROCm 7.1.0
Added
- Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv/bwd)
- Added support for elementwise kernel.
Upcoming changes
- Non-grouped convolutions are deprecated. Their functionality is supported by grouped convolution.
Composable Kernel 1.1.0 for ROCm 7.0.0
Added
- Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
- Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels.
- Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW); the number of instances in the instance factory for NGCHW/GKYXC/NGKHW has been reduced.
- Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW).
- Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
- Added support for Stream-K version of mixed fp8/bf16 GEMM
- Added support for Multiple D GEMM
- Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types
- Added support for FP16 2:4 structured sparsity to universal GEMM.
- Added support for Split K for grouped convolution backward data.
- Added logit soft-capping support for FMHA forward kernels.
- Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv)
- Added benchmarking support for tile engine GEMM.
- Added Ping-pong scheduler support for GEMM operation along the K dimension.
- Added rotating buffer feature for CK_Tile GEMM.
- Added int8 support for CK_TILE GEMM.
- Added CK Tile Epilogue Chainer framework for composable epilogue sequences in GEMM operations
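The FP16 2:4 structured sparsity entry refers to a specific pattern. Each group of four consecutive values along a row holds at most two nonzeros, so only two values plus their in-group positions (two bits each) need to be stored. The Python sketch below prunes by magnitude, checks the pattern, and round-trips a compressed (values, positions) form. The magnitude-based pruning and this compressed layout are assumptions for illustration. They do not describe the operand format CK's universal GEMM expects.

```python
def prune_2_4(row):
    """Zero all but the two largest-magnitude values in each group of four."""
    if len(row) % 4:
        raise ValueError("row length must be a multiple of 4")
    out = []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        out.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return out


def is_2_4_sparse(row):
    """True if every group of four consecutive values has at most two nonzeros."""
    return len(row) % 4 == 0 and all(
        sum(v != 0 for v in row[g:g + 4]) <= 2 for g in range(0, len(row), 4)
    )


def compress_2_4(row):
    """Split a 2:4-sparse row into kept values and their in-group positions (0..3)."""
    if not is_2_4_sparse(row):
        raise ValueError("row is not 2:4 sparse")
    values, positions = [], []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        pos = [i for i, v in enumerate(group) if v != 0]
        # Pad groups with fewer than two nonzeros using zero-valued slots.
        pos += [i for i in range(4) if i not in pos][: 2 - len(pos)]
        pos.sort()
        values.extend(group[i] for i in pos)
        positions.extend(pos)
    return values, positions


def decompress_2_4(values, positions):
    """Rebuild the dense row: two (value, position) pairs per group of four."""
    row = []
    for g in range(0, len(values), 2):
        group = [0.0] * 4
        for v, i in zip(values[g:g + 2], positions[g:g + 2]):
            group[i] = v
        row.extend(group)
    return row
```

The compressed form stores half the values of the dense row, which is where the memory and compute savings of 2:4 sparsity come from.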
Optimized
- Optimized the GEMM multiply multiply preshuffle and LDS bypass with KGroup packing and a better instruction layout.
- Added Vectorize Transpose optimization for CK Tile
- Added the asynchronous copy for gfx950
Changed
- Removed support for gfx940 and gfx941 targets (#1944)
- Replaced the raw buffer load/store intrinsics with Clang20 built-ins (#1876)
- DL and DPP kernels are now enabled by default.
- Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced.
- Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced.
- Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced.
Composable Kernel 1.1.0 for ROCm 6.1.0
Additions
- Added generic instances for GEMM XDL operations (#1161)
- Added gamma and beta parameters for the layernorm and groupnorm bwd operations (#1133)
- Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126)
- Added an option to vary the number of warm-up cycles and iterations for ckProfiler (#1124)
Optimizations
- New performance optimizations for GEMM operations on MI200 and MI300 architectures (#1135)
Fixes
- Reduced the build time for most GPU architectures (#1084)
- Fixed some conversion issues for fp8 data type (#1099)
Changes
None
Known issues
None
Composable Kernel 1.1.0 for ROCm 6.0.0
Fixes
- Fixed a hazard associated with inline v_dot (#808)
- Fixed two bugs in grouped convolution backward data without K padding (#848 #876)
Optimizations
None
Additions
- Added an image-to-column kernel (#867)
- Added a column-to-image kernel (#930)
- Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985)
- Grouped convolution support for small K and C (#822 #879 #897)
- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
- Support for Batched GEMM DL (#732)
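The image-to-column and column-to-image kernels listed above correspond to the classic im2col/col2im transforms. The minimal single-channel Python sketch below uses hypothetical helper names and assumes stride 1 and no padding. It shows how im2col turns a 2D cross-correlation (the "convolution" of deep learning) into one dot product per output pixel, and how col2im scatters the columns back, accumulating wherever patches overlap. CK's kernels also handle batch, channel, and group dimensions and their layouts, which this sketch ignores.

```python
def im2col(image, kh, kw):
    """Unfold every kh x kw patch of a 2D image into one row (stride 1, no padding)."""
    h, w = len(image), len(image[0])
    return [
        [image[i + di][j + dj] for di in range(kh) for dj in range(kw)]
        for i in range(h - kh + 1)
        for j in range(w - kw + 1)
    ]


def conv2d_via_im2col(image, kernel):
    """2D cross-correlation computed as a dot product of each patch with the kernel."""
    kh, kw = len(kernel), len(kernel[0])
    flat_k = [v for row in kernel for v in row]
    out_w = len(image[0]) - kw + 1
    flat = [sum(a * b for a, b in zip(patch, flat_k)) for patch in im2col(image, kh, kw)]
    return [flat[r:r + out_w] for r in range(0, len(flat), out_w)]


def col2im(cols, h, w, kh, kw):
    """Scatter patch rows back into an h x w image, summing overlapping entries."""
    image = [[0] * w for _ in range(h)]
    out_w = w - kw + 1
    for p, patch in enumerate(cols):
        i, j = divmod(p, out_w)          # top-left corner of this patch
        for idx, v in enumerate(patch):
            di, dj = divmod(idx, kw)     # offset inside the patch
            image[i + di][j + dj] += v
    return image
```

Because overlapping entries are summed, col2im is the adjoint of im2col rather than its exact inverse. A pixel covered by several patches comes back multiplied by its coverage count.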
Changes
- Changed the grouped convolution API to maintain consistency with other convolution kernels (#817)
Composable Kernel 0.2.0 for ROCm 5.7.0
Fixes
- Fixed a bug in 6-dimensional kernels (#555)
- Fixed a test case failure with grouped convolution backward weight (#524)
Optimizations
- Improved the performance of the normalization kernel
Additions
- New CMake flags:
  - "DL_KERNELS" -- Must be set to "ON" to build the GEMM DL and batched_gemm_multi_d_dl instances
  - "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build instances of only the specified data types
  - "INSTANCES_ONLY" -- Only builds the CK library and instances, without tests, examples, or the profiler
- New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler
- Support for MI300A/MI300X
- Support for AMD RDNA 3
- New user tutorial (#563)
- Additional instances for irregular GEMM sizes (#560)
- New inter-wave consumer-producer programming model for GEMM kernels (#310)
- GEMM with support for multiple elementwise fusions (multi-D) (#534)
- Multi-embeddings support (#542)
- AMD RDNA 3 blockwise GEMM and real GEMM support (#541)
- AMD RDNA grouped convolution backward weight support (#505)
- MaxPool and AvgPool forward (#815); MaxPool backward (#750)
Changes
None