mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 16:28:38 +00:00
[CK_TILE] Use gfx11 float buffer atomics in FMHA Bwd

## Motivation

FlashAttention CK backward on gfx11 can hit out-of-bounds/tail writes in the dQ accumulator atomic-add path when sequence rows are padded at the tile level but not marked invalid in the DQDKDV main tensor view. With the generic global atomic fallback, an incorrectly-valid tail element can issue an actual pointer-based `atomicAdd`. With the buffer atomic path, the write is issued through a buffer resource with bounds information and follows the same backend already used by gfx9/gfx12.

This fixes the gfx11 FMHA BWD failure without changing the gfx11 default for unrelated CK Tile kernels.

## Technical Details

This PR enables the existing CK Tile AMD buffer float atomic-add path only for generated FMHA BWD gfx11 translation units.

gfx11 normally uses the generic global atomic fallback for floating-point `buffer_view::atomic_add`. That fallback performs the atomic through a raw computed pointer and depends on the software validity predicate to avoid invalid elements. In FMHA BWD dQ accumulation, padded tail rows can reach this path, so using the buffer atomic backend is safer: it uses a buffer resource with base pointer, bounds information, and an element offset, matching the backend already used by gfx9/gfx12.

Enabling `CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT` globally for gfx11 is too broad and can break unrelated gfx11 CK builds such as GEMM. Instead, `config.hpp` now preserves an explicitly pre-defined `CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT`, while keeping the existing default disabled for gfx11.

## Test Plan

Validated the change with the FlashAttention CK full test suite with backward pass enabled on gfx11.

`pytest -q -s tests/test_flash_attn_ck.py`

## Test Result

FlashAttention CK gfx11 test result: 260680 passed, 152076 skipped

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
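The tail-write problem described in the commit message above can be pictured with a small host-side analogy. The sketch below is not the CK Tile API and not real GPU code: `DemoBufferResource`, `demo_buffer_add`, and `demo_accumulate_padded_tile` are hypothetical names, the addition is single-threaded and non-atomic, and it assumes that an out-of-range access through a buffer resource is simply dropped. It shows how a bounds-aware path still protects the accumulator when the software validity predicate wrongly marks padded rows as valid.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a buffer resource: a base pointer plus bounds.
struct DemoBufferResource {
    float*      base;
    std::size_t num_elements;
};

// Bounds-aware add: the offset is checked against the resource itself, so an
// element that the caller wrongly flags as valid is dropped instead of written.
// Returns true if the write happened. (Plain add, not a real atomic.)
inline bool demo_buffer_add(const DemoBufferResource& res,
                            std::size_t offset,
                            float value,
                            bool is_valid)
{
    if (!is_valid || offset >= res.num_elements) {
        return false;
    }
    res.base[offset] += value;
    return true;
}

// Accumulate one padded tile of `tile_rows` rows into `dq`, simulating a
// validity predicate that marks every row valid, including the padded tail.
// Returns the number of writes dropped by the bounds check.
inline std::size_t demo_accumulate_padded_tile(std::vector<float>& dq,
                                               std::size_t tile_rows,
                                               float value)
{
    DemoBufferResource res{dq.data(), dq.size()};
    std::size_t dropped = 0;
    for (std::size_t row = 0; row < tile_rows; ++row) {
        if (!demo_buffer_add(res, row, value, /*is_valid=*/true)) {
            ++dropped;
        }
    }
    return dropped;
}
```

With a 6-element accumulator and an 8-row tile, the two padded rows are dropped rather than written past the end of the buffer. A raw-pointer path with the same incorrect predicate would have written out of bounds.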
CK Tile Example Suite
This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data movement policy.
What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. The pipeline and policy together orchestrate data movement (global <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.
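To make the tile idea concrete, here is a minimal host-side sketch in plain C++. It does not use the CK Tile API. The name `demo_tiled_gemm` and the tile edge `kTile` are hypothetical, the local arrays only stand in for LDS and registers, and the loops run sequentially where a real kernel would run in parallel. It shows the stage, compute, and store phases that a pipeline would orchestrate, with tail tiles zero-padded on load and masked on store.

```cpp
#include <cstddef>
#include <vector>

constexpr std::size_t kTile = 4;  // hypothetical tile edge, chosen for illustration

// C (m x n) = A (m x k) * B (k x n), all row-major, computed tile by tile.
inline std::vector<float> demo_tiled_gemm(const std::vector<float>& a,
                                          const std::vector<float>& b,
                                          std::size_t m, std::size_t n, std::size_t k)
{
    std::vector<float> c(m * n, 0.0f);
    for (std::size_t i0 = 0; i0 < m; i0 += kTile) {
        for (std::size_t j0 = 0; j0 < n; j0 += kTile) {
            float acc[kTile][kTile] = {};  // stands in for register accumulators
            for (std::size_t k0 = 0; k0 < k; k0 += kTile) {
                float a_tile[kTile][kTile] = {};  // stands in for LDS staging
                float b_tile[kTile][kTile] = {};
                // Stage: global -> local, leaving out-of-range elements at zero.
                for (std::size_t i = 0; i < kTile; ++i) {
                    for (std::size_t kk = 0; kk < kTile; ++kk) {
                        if (i0 + i < m && k0 + kk < k) {
                            a_tile[i][kk] = a[(i0 + i) * k + (k0 + kk)];
                        }
                        if (k0 + i < k && j0 + kk < n) {
                            b_tile[i][kk] = b[(k0 + i) * n + (j0 + kk)];
                        }
                    }
                }
                // Compute: multiply the two staged tiles into the accumulator.
                for (std::size_t i = 0; i < kTile; ++i) {
                    for (std::size_t j = 0; j < kTile; ++j) {
                        for (std::size_t kk = 0; kk < kTile; ++kk) {
                            acc[i][j] += a_tile[i][kk] * b_tile[kk][j];
                        }
                    }
                }
            }
            // Store: local -> global, skipping padded rows and columns.
            for (std::size_t i = 0; i < kTile; ++i) {
                for (std::size_t j = 0; j < kTile; ++j) {
                    if (i0 + i < m && j0 + j < n) {
                        c[(i0 + i) * n + (j0 + j)] = acc[i][j];
                    }
                }
            }
        }
    }
    return c;
}
```

Changing `kTile` changes only how the work is partitioned, not the result, which is the property that lets tile sizes be tuned independently of the math.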
Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
Technical Highlights
- Tile Distribution: See include/ck_tile/tile_program/tile_distribution/ for mapping tiles to thread blocks.
- Block Tile Pipelines: See include/ck_tile/tile_program/block_tile_pipeline/ for memory/computation pipelines.
- Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
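As a rough illustration of what mapping tiles to thread blocks involves, the sketch below assigns each linear block index one output tile of an m x n problem. It is written in plain C++ and does not reflect the actual tile-distribution classes in CK Tile. `DemoTile`, `demo_grid_size`, and `demo_block_to_tile` are hypothetical names, and the row-major block ordering is an assumption made for the example.

```cpp
#include <algorithm>
#include <cstddef>

// Hypothetical description of the tile owned by one thread block.
struct DemoTile {
    std::size_t row0;  // first row covered by the tile
    std::size_t col0;  // first column covered by the tile
    std::size_t rows;  // valid rows (smaller than tile_m on the tail)
    std::size_t cols;  // valid columns (smaller than tile_n on the tail)
};

inline std::size_t demo_ceil_div(std::size_t a, std::size_t b)
{
    return (a + b - 1) / b;
}

// Number of blocks needed to cover an m x n problem with tile_m x tile_n tiles.
inline std::size_t demo_grid_size(std::size_t m, std::size_t n,
                                  std::size_t tile_m, std::size_t tile_n)
{
    return demo_ceil_div(m, tile_m) * demo_ceil_div(n, tile_n);
}

// Map a linear block index to its tile, assuming row-major block ordering.
inline DemoTile demo_block_to_tile(std::size_t block_id,
                                   std::size_t m, std::size_t n,
                                   std::size_t tile_m, std::size_t tile_n)
{
    const std::size_t tiles_n  = demo_ceil_div(n, tile_n);
    const std::size_t tile_row = block_id / tiles_n;
    const std::size_t tile_col = block_id % tiles_n;
    const std::size_t row0     = tile_row * tile_m;
    const std::size_t col0     = tile_col * tile_n;
    // Clamp the extent so that tail tiles do not reach past the problem size.
    return DemoTile{row0, col0,
                    std::min(tile_m, m - row0),
                    std::min(tile_n, n - col0)};
}
```

For a 10 x 7 problem with 4 x 4 tiles this yields a grid of 6 blocks, and the last block owns a 2 x 3 tail tile starting at row 8, column 4.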
How to Build & Run
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
Each example produces its own executable in build/bin/.
Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.