composable_kernel/example/ck_tile
juuso-oskari 9d7cc3ee9e CK-UA: extend FP8 to the 16x16x32 _m16 decode tier via LDS roundtrip
The 32x32x16 tiers (prefill_d{64,128}, decode_d{64,128}_m{32,64,128}) keep
the cheap in-register `ds_bpermute_b32` cross-lane swap that fixes the
QK-C / PV-A per-thread alias for the union'd `sp_compute` / `p`.

The 16x16x32 m16 tiers (decode_d{64,128}_m16) cannot use the swap -- the
MFMA puts the paired-lane bit at a different position and the
sub=0/sub=1 4-fp8 chunks no longer map onto each other. We add a
layout-agnostic LDS roundtrip as the `else` branch, gated by the same
`PVWarpTile` constexpr:

  - Hoist two distribution-bound windows over the existing `p_lds`
    region (one bound to the QK-C output distribution, one to the PV-A
    input distribution). Done once per kernel invocation.
  - In `fmha_alu1`, after the cvt_pk_fp8_f32 packing chain, view the
    union's bytes as a `static_distributed_tensor<fp8>` in the QK-C
    distribution, `store_tile` it through `p_lds` in canonical (M, N)
    order, `s_barrier`, then `load_tile` back with the PV-A
    distribution and copy into `sp(idx).p`.
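
The two steps above can be pictured with a small host-side sketch. This is not the CK pipeline code: the tile shape, lane count, and both ownership maps (`producer_coord`, `consumer_coord`) are hypothetical stand-ins for the QK-C and PV-A distributions, and a plain array plays the role of `p_lds`. It only illustrates why a store in canonical (M, N) order followed by a load under a different mapping works for any pair of layouts.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical sizes: a 4x8 tile of 1-byte values spread over 8 lanes, 4 values each.
constexpr int kM = 4, kN = 8, kLanes = 8, kPerLane = 4;
static_assert(kLanes * kPerLane == kM * kN, "every element has exactly one owner");

struct Coord { int row; int col; };

// Hypothetical producer layout (stand-in for QK-C): lane owns one column, slot walks the rows.
constexpr Coord producer_coord(int lane, int slot) { return {slot, lane}; }

// Hypothetical consumer layout (stand-in for PV-A): lane owns 4 contiguous values of one row.
constexpr Coord consumer_coord(int lane, int slot)
{
    return {lane / 2, (lane % 2) * kPerLane + slot};
}

using LaneRegs   = std::array<std::array<std::uint8_t, kPerLane>, kLanes>;
using SharedTile = std::array<std::uint8_t, kM * kN>; // stand-in for the p_lds region

LaneRegs roundtrip(const LaneRegs& produced)
{
    SharedTile lds{};

    // "store_tile": every lane writes its values at their canonical (row, col) position.
    for(int lane = 0; lane < kLanes; ++lane)
        for(int slot = 0; slot < kPerLane; ++slot)
        {
            const Coord c = producer_coord(lane, slot);
            assert(c.row < kM && c.col < kN);
            lds[c.row * kN + c.col] = produced[lane][slot];
        }

    // On the GPU a block-wide barrier would sit here, between the store and the load.

    // "load_tile": every lane reads the values that the consumer layout assigns to it.
    LaneRegs consumed{};
    for(int lane = 0; lane < kLanes; ++lane)
        for(int slot = 0; slot < kPerLane; ++slot)
        {
            const Coord c = consumer_coord(lane, slot);
            assert(c.row < kM && c.col < kN);
            consumed[lane][slot] = lds[c.row * kN + c.col];
        }
    return consumed;
}
```

Because the shared buffer is indexed only by (row, col), neither mapping needs to know about the other; the cost is the extra store, barrier, and load.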

A/B'd a uniform LDS-roundtrip (no fast-path) vs the split: pure LDS
regressed decode_m128 by ~1.5x end-to-end (CK FP8 dropped from
~0.39x of Triton FP8 to ~0.16x), driven by the extra block-wide
barrier on the 4-warp decode path. Keeping the swap for 32x32x16
preserves the previously-tuned perf.

Dispatcher (`unified_attention.cpp`) now FP8-enables every UA variant
including decode_d{64,128}_m16. Four new instance .cpp files
(`unified_attention_d{64,128}_fp8_{mask,nmask}_decode_t.cpp`)
instantiate the m16 FP8 kernels.

Pytest (`test_unified_attention_ck_correctness.py`):
  - 245 BF16/FP16: pass (no regression from the pipeline edit).
  - 160 FP8: pass (was 112 before m16 enablement).
  - 80 skipped: block_size<32 or query_len>kv_len -- pre-existing.

Single-shape m16 dispatches verified on gfx950:
  b=128 sq=1 hq=hk=8 d=128 fp8 PASS  (CK 0.109 ms / Triton 0.043 ms)
  b=128 sq=1 hq=hk=8 d=64  fp8 PASS  (CK 0.077 ms / Triton 0.039 ms)

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-15 20:00:35 +00:00

CK Tile Example Suite

This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data movement policy.


What is CK Tile?

CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. The pipeline and policy together orchestrate data movement (global <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.
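
As a rough illustration of the tile idea, the sketch below runs a tiled GEMM on the CPU in plain C++. It does not use the CK Tile API: the tile sizes (`kTileM`, `kTileN`, `kTileK`) and the function name `tiled_gemm` are made up for this example, the small staging arrays only stand in for LDS, and the accumulator array stands in for registers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tile sizes, chosen for illustration only.
constexpr std::size_t kTileM = 4;
constexpr std::size_t kTileN = 4;
constexpr std::size_t kTileK = 8;

// C (MxN) += A (MxK) * B (KxN), all row-major, computed one output tile at a time.
void tiled_gemm(const std::vector<float>& a,
                const std::vector<float>& b,
                std::vector<float>& c,
                std::size_t M, std::size_t N, std::size_t K)
{
    assert(a.size() == M * K && b.size() == K * N && c.size() == M * N);

    for(std::size_t m0 = 0; m0 < M; m0 += kTileM)
        for(std::size_t n0 = 0; n0 < N; n0 += kTileN)
        {
            // Accumulator tile (stand-in for registers) for output block (m0, n0).
            float acc[kTileM][kTileN] = {};

            for(std::size_t k0 = 0; k0 < K; k0 += kTileK)
            {
                // Stage one A tile and one B tile (stand-in for global -> LDS).
                // Out-of-range elements stay zero, so partial tiles need no special case.
                float a_tile[kTileM][kTileK] = {};
                float b_tile[kTileK][kTileN] = {};
                for(std::size_t mi = 0; mi < kTileM && m0 + mi < M; ++mi)
                    for(std::size_t ki = 0; ki < kTileK && k0 + ki < K; ++ki)
                        a_tile[mi][ki] = a[(m0 + mi) * K + (k0 + ki)];
                for(std::size_t ki = 0; ki < kTileK && k0 + ki < K; ++ki)
                    for(std::size_t ni = 0; ni < kTileN && n0 + ni < N; ++ni)
                        b_tile[ki][ni] = b[(k0 + ki) * N + (n0 + ni)];

                // Compute on the staged tiles.
                for(std::size_t mi = 0; mi < kTileM; ++mi)
                    for(std::size_t ni = 0; ni < kTileN; ++ni)
                        for(std::size_t ki = 0; ki < kTileK; ++ki)
                            acc[mi][ni] += a_tile[mi][ki] * b_tile[ki][ni];
            }

            // Write the finished tile back (stand-in for registers -> global).
            for(std::size_t mi = 0; mi < kTileM && m0 + mi < M; ++mi)
                for(std::size_t ni = 0; ni < kTileN && n0 + ni < N; ++ni)
                    c[(m0 + mi) * N + (n0 + ni)] += acc[mi][ni];
        }
}
```

On a GPU, each (m0, n0) iteration would typically map to a workgroup, and the staging and compute steps would be overlapped by the pipeline; here they simply run in sequence.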


Example Index

| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |

Technical Highlights


How to Build & Run

mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j

Replace <arch> with the target GPU architecture (for example, gfx950). Each example produces its own executable in build/bin/.


Learning and Extending


References

