composable_kernel/CHANGELOG.md
Po Yen Chen 341fb33386 [CK] Add FP8 per-tensor quantization support for FMHA V3 pipeline (#6051)
## Motivation

The existing FMHA V3 pipeline only supports fp16/bf16 data types. This
PR extends V3 to handle FP8 inputs with per-tensor descaling on gfx950,
enabling higher throughput for FP8 inference workloads using the
assembly-optimized V3 code path.

## Technical Details

**Warp GEMM:**
- Add FP8 32x32x32 warp gemm with C-transposed distribution
  (`WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed`) and dispatcher entries

**V3 Kernel (`fmha_fwd_v3_kernel.hpp`):**
- Add per-tensor descale support for Q, K, V tensors, passing descale
  pointers through to pipeline kargs

**V3 Pipeline (`block_fmha_fwd_v3_pipeline.hpp`):**
- Add FP8 data path with dtype-aware type selection
- Add asm volatile P matrix conversion from f32 to fp8
- Add FP8-aware instruction scheduling in `CoreLoopScheduler`

**V3 Pipeline Policy (`block_fmha_fwd_v3_pipeline_default_policy.hpp`):**
- Add FP8 QK warp gemm selection (SwizzleB variant for V tile
  distribution compatibility)

**Codegen (`fmha_fwd.py`):**
- Add gfx950 FP8BF16 V3 tile size (256x64x128x128x64x128)
- Add FP8BF16 V3 pipeline variants (mask: no/causal, qscale:
  no/pertensor)
- Extend `can_dispatch_v3` condition for fp8bf16 + pertensor

**Misc:**
- Add LLVM scheduler `TRANS` mask to `LLVMSchedGroupMask` enum
  (`arch.hpp`)
- Fix `mask_info` default initialization for `no_mask` case (`mask.hpp`)

V3 dispatch for FP8 is disabled by default (`F_is_v3_enabled=false`)
pending further validation.
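
To make the per-tensor descaling idea concrete, here is a minimal pure-Python sketch. It is not the CK kernel: the function names are hypothetical, integer rounding stands in for the real FP8 cast, the `qmax=448.0` default is an assumption (the largest finite value of the OCP E4M3 format), and the f32-to-fp8 conversion of the P matrix is not modelled. It only illustrates that one scalar per tensor can be folded into the softmax scale and the output scale instead of dequantizing Q, K and V element by element.

```python
import math


def quantize_per_tensor(x, qmax=448.0):
    """Scale a matrix so its largest magnitude maps to qmax.

    Returns (quantized matrix, descale), where descale = 1 / scale.
    Rounding to integers is only a stand-in for a real FP8 cast.
    """
    amax = max(abs(v) for row in x for v in row)
    scale = qmax / amax if amax > 0 else 1.0
    xq = [[float(round(v * scale)) for v in row] for row in x]
    return xq, 1.0 / scale


def matmul(a, b):
    """Return a @ b for row-major lists of lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def matmul_nt(a, b):
    """Return a @ b^T for row-major lists of lists."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


def softmax_rows(s):
    """Numerically stable row-wise softmax."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(v - m) for v in row]
        z = sum(e)
        out.append([v / z for v in e])
    return out


def attention_descaled(qq, kq, vq, descale_q, descale_k, descale_v):
    """Attention on quantized inputs with per-tensor descales folded in."""
    d = len(qq[0])
    # A single scalar carries both descales and the 1/sqrt(d) softmax scale.
    s_scale = descale_q * descale_k / math.sqrt(d)
    s = [[v * s_scale for v in row] for row in matmul_nt(qq, kq)]
    p = softmax_rows(s)
    # The V descale is applied once to the output tile.
    return [[v * descale_v for v in row] for row in matmul(p, vq)]


def attention_reference(q, k, v):
    """Plain softmax(Q K^T / sqrt(d)) V on already dequantized inputs."""
    d = len(q[0])
    s = [[x / math.sqrt(d) for x in row] for row in matmul_nt(q, k)]
    return matmul(softmax_rows(s), v)
```

Because every descale is a scalar, it commutes with the matrix products, so `attention_descaled` matches `attention_reference` evaluated on the dequantized tensors up to floating-point rounding.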

## Performance: fmha_fwd V3 FP8 (avg runs 2-6, stock ROCm 7.1.1, gfx950)

  | Problem | Regular (TFlops) | Varlen (TFlops) |
  |---|---:|---:|
  | batch=1 heads=6/1 seqlen=1024 causal | 48.9 | 47.6 |
  | batch=1 heads=6/1 seqlen=2048 causal | 119.8 | 117.4 |
  | batch=1 heads=6/1 seqlen=4096 causal | 263.7 | 259.2 |
  | batch=1 heads=6/1 seqlen=8192 causal | 548.9 | 543.6 |
  | batch=1 heads=6/1 seqlen=16384 causal | 1043.0 | 1063.7 |
  | batch=1 heads=6/1 seqlen=32768 causal | 1237.2 | 1279.6 |
  | batch=1 heads=6/1 seqlen=65536 causal | 1315.4 | 1382.7 |
  | batch=1 heads=6/1 seqlen=131072 causal | 1326.3 | 1402.2 |
  | batch=1 heads=16/1 seqlen=65536 causal | 1298.7 | 1388.4 |
  | batch=1 heads=40/40 seqlen=37200 non-causal | 1248.9 | 1326.1 |
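
The commit message does not say how the TFlops figures were derived. The sketch below shows one common convention for attention forward benchmarks, offered as an assumption rather than as the script that produced the table: count 2 FLOPs per multiply-accumulate over the two GEMMs (Q·Kᵀ and P·V) and halve the count for causal masks. The function names and the idea of passing head dimensions explicitly are hypothetical; the table does not state the head dimension used.

```python
def fmha_fwd_flops(batch, nhead_q, seqlen_q, seqlen_k, hdim_qk, hdim_v, causal):
    """Nominal FLOP count of one attention forward pass (assumed convention)."""
    # Q @ K^T costs 2 * sq * sk * hdim_qk, P @ V costs 2 * sq * sk * hdim_v.
    flops = 2 * batch * nhead_q * seqlen_q * seqlen_k * (hdim_qk + hdim_v)
    # A causal mask skips roughly half of the score matrix.
    return flops // 2 if causal else flops


def tflops(flops, elapsed_ms):
    """Convert a FLOP count and an elapsed time in milliseconds to TFLOP/s."""
    return flops / (elapsed_ms * 1e-3) / 1e12
```

Under this convention, throughput rises with sequence length until the kernel saturates the device, which is consistent with the trend in the table.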

## Test Plan

Tested with aiter's `test_mha_fp8.py` test suite (176 cases) covering
batch sizes (1-2), sequence lengths (113-4096), head counts (5/8/32/40),
GQA ratios (1:1, 1:8), and causal/non-causal modes. Verified all cases
dispatch to the V3 pipeline by enabling `F_is_v3_enabled` and confirming
kernel names contain `qr_async_trload_v3`.

## Test Result

176/176 tests passed with V3 enabled. All cases correctly dispatched to
V3 pipeline with `pertensor` quantization.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-07 22:19:28 +08:00


# Changelog for Composable Kernel
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
## (Unreleased) Composable Kernel 1.3.0
### Added
* Added an overload of `load_tile_transpose` that takes a reference to the output tensor as an output parameter.
* Use data type from LDS tensor view when determining tile distribution for transpose in the GEMM pipeline
* Added eightwarps support for abquant mode in blockscale GEMM.
* Added preshuffleB support for abquant mode in blockscale GEMM.
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
* Added TF32 convolution support on gfx942 and gfx950 in CK. It can be enabled or disabled via the "tf32" entry in `DTYPES`.
* Added StreamingLLM sink support for FMHA FWD, including the qr_ks_vs, qr_async, and splitkv pipelines.
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
* Added support for dynamic tensor-wise quantization in the FP8 FMHA forward kernel.
* Added FP8 KV cache support for FMHA batch prefill.
* Added support for gfx1153 target.
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
* Added gpt-oss sink support for FMHA FWD, including the qr_ks_vs, qr_async, qr_async_trload, and splitkv pipelines.
* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
* Added FP8 block scale quantization for FMHA forward kernel.
* Added gfx11 support for FMHA.
* Added microscaling (MX) FP8/FP4 support on gfx950 for FMHA forward kernel ("qr" pipeline only).
* Added FP8 per-tensor quantization support for FMHA forward V3 pipeline on gfx950.
### Changed
### Upcoming changes
## Composable Kernel 1.2.0 for ROCm 7.2.0
### Added
* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4
* Added CK-Tile dispatcher - a unified kernel dispatch, code generation, and architecture-based kernel filtering system with C++ and Python frontends, starting with GEMM support.
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.
* Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM.
* Added a compute async pipeline in the CK Tile universal GEMM on gfx950.
* Added support for B Tensor type `pk_int4_t` in the CK Tile weight preshuffle GEMM.
* Added a new API to load different memory sizes into SGPRs.
* Added support for B Tensor preshuffle in CK Tile grouped GEMM.
* Added a basic copy kernel example and supporting documentation for new CK Tile developers.
* Added support for grouped GEMM kernels to perform Multi D elementwise operation.
* Added support for multiple ABD GEMM.
* Added benchmarking support for tile engine GEMM Multi D.
* Added block scaling support in CK Tile GEMM, allowing flexible use of quantization matrices from either A or B operands.
* Added the row-wise column-wise quantization for CK Tile GEMM and CK Tile grouped GEMM.
* Added support for f32 to FMHA (forward and backward).
* Added tensor-wise quantization for CK Tile GEMM.
* Added support for batched contraction kernel.
* Added WMMA (gfx12) support for FMHA.
* Added pooling kernel in CK_TILE
* Added top-k sigmoid kernel in CK_TILE
* Added the blockscale 2D support for CK_TILE GEMM.
* Added Flatmm pipeline for microscaling (MX) FP8/FP4 data types
* Added reduce and multi reduction kernels
### Changed
* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK Tile (#2594)
* Added an optional template parameter `Arch` (`gfx9_t`, `gfx12_t` etc.) to `make_kernel` to support linking multiple object files that have the same kernel compiled for different architectures.
* FMHA examples and tests can be built for multiple architectures (gfx9, gfx950, gfx12) at the same time.
### Upcoming changes
* Composable Kernel will be adopting C++20 features in an upcoming ROCm release, updating the minimum compiler requirement to C++20. Ensure that your development environment complies with this requirement to facilitate a seamless transition.
## Composable Kernel 1.1.0 for ROCm 7.1.1
### Upcoming changes
* Composable Kernel will be adopting C++20 features in an upcoming ROCm release, updating the minimum compiler requirement to C++20. Ensure that your development environment complies with this requirement to facilitate a seamless transition.
## Composable Kernel 1.1.0 for ROCm 7.1.0
### Added
* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv/bwd)
* Added support for elementwise kernel.
### Upcoming changes
* Non-grouped convolutions are deprecated. Their functionality is supported by grouped convolution.
## Composable Kernel 1.1.0 for ROCm 7.0.0
### Added
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels.
* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW; the number of instances in the instance factory for NGCHW/GKYXC/NGKHW has been reduced).
* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW).
* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW).
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
* Added support for Stream-K version of mixed fp8/bf16 GEMM
* Added support for Multiple D GEMM
* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types
* Added support for FP16 2:4 structured sparsity to universal GEMM.
* Added support for Split K for grouped convolution backward data.
* Added logit soft-capping support for FMHA forward kernels.
* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv)
* Added benchmarking support for tile engine GEMM.
* Added Ping-pong scheduler support for GEMM operation along the K dimension.
* Added rotating buffer feature for CK_Tile GEMM.
* Added int8 support for CK_TILE GEMM.
* Added CK Tile Epilogue Chainer framework for composable epilogue sequences in GEMM operations
### Optimized
* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout.
* Added Vectorize Transpose optimization for CK Tile
* Added the asynchronous copy for gfx950
### Changed
* Removed support for gfx940 and gfx941 targets (#1944)
* Replaced the raw buffer load/store intrinsics with Clang20 built-ins (#1876)
* DL and DPP kernels are now enabled by default.
* Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced.
* Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced.
* Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced.
## Composable Kernel 1.1.0 for ROCm 6.1.0
### Additions
* Added generic instances for GEMM XDL operations (#1161)
* Added gamma and beta parameters for the layernorm and groupnorm bwd operations (#1133)
* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126)
* Added an option to vary the number of warm-up cycles and iterations for ckProfiler (#1124)
### Optimizations
* New performance optimizations for GEMM operations on MI200 and MI300 architectures (#1135)
### Fixes
* Reduced the build time for most GPU architectures (#1084)
* Fixed some conversion issues for fp8 data type (#1099)
### Changes
None
### Known issues
None
## Composable Kernel 1.1.0 for ROCm 6.0.0
### Fixes
* Fixed a hazard associated with inline v_dot (#808)
* Fixed two bugs in grouped convolution backward data without K padding (#848 #876)
### Optimizations
None
### Additions
* Added an image-to-column kernel (#867)
* Added a column-to-image kernel (#930)
* Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985)
* Grouped convolution support for small K and C (#822 #879 #897)
* Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
* Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
* Support for Batched GEMM DL (#732)
### Changes
* Changed the grouped convolution API to maintain consistency with other convolution kernels (#817)
## Composable Kernel 0.2.0 for ROCm 5.7.0
### Fixes
* Fixed a bug in 6-dimensional kernels (#555)
* Fixed a test case failure with grouped convolution backward weight (#524)
### Optimizations
* Improved the performance of the normalization kernel
### Additions
* New CMake flags:
  * "DL_KERNELS" -- Must be set to "ON" in order to build the GEMM DL and batched_gemm_multi_d_dl instances
  * "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types
  * "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler
* New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler
* Support for MI300A/MI300X
* Support for AMD RDNA 3
* New user tutorial (#563)
* Additional instances for irregular GEMM sizes (#560)
* New inter-wave consumer-producer programming model for GEMM kernels (#310)
* GEMM with support for multiple elementwise fusions (multi-D) (#534)
* Multi-embeddings support (#542)
* AMD RDNA 3 blockwise GEMM and real GEMM support (#541)
* AMD RDNA grouped convolution backward weight support (#505)
* MaxPool and AvgPool forward (#815); MaxPool backward (#750)
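
As an illustration of how the CMake flags listed in this release combine, the following sketch assembles `-D` arguments for a configure step. The helper name `ck_cmake_args` and its parameters are hypothetical and not part of CK. The accepted data types are only those named in this 0.2.0 entry, and later releases change some of this (for example, the 7.0.0 notes state that DL kernels are enabled by default).

```python
# Data types accepted by "DTYPES" according to the 0.2.0 entry above.
VALID_DTYPES = ("fp64", "fp32", "fp16", "fp8", "bf16", "int8")


def ck_cmake_args(dtypes=(), dl_kernels=False, instances_only=False, gpu_targets=()):
    """Build a list of -D arguments for a CK CMake configure step (sketch)."""
    unknown = [d for d in dtypes if d not in VALID_DTYPES]
    if unknown:
        raise ValueError("unsupported DTYPES entries: " + ", ".join(unknown))

    args = []
    if gpu_targets:
        # When GPU_TARGETS is omitted, CK builds for all compiler-supported targets.
        args.append("-DGPU_TARGETS=" + ";".join(gpu_targets))
    if dtypes:
        # CMake lists are semicolon-separated.
        args.append("-DDTYPES=" + ";".join(dtypes))
    if dl_kernels:
        args.append("-DDL_KERNELS=ON")
    if instances_only:
        args.append("-DINSTANCES_ONLY=ON")
    return args
```

For example, `ck_cmake_args(dtypes=("fp16", "fp8"), dl_kernels=True)` yields arguments that could be appended to a `cmake` invocation. Pass them through an argument list or quote them in a shell, since semicolons are otherwise treated as command separators.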
### Changes
None