composable_kernel/example/ck_tile
juuso-oskari d77f0bea63 CK-UA: collapse MHA/GQA variants -- one binary per (head_dim, kBlockM)
After moving kBlockQ to runtime in the previous commit, the static
NumQPerKV in `variant_config<V>` and the runtime-vs-static assert in
the kernel became the only things still tying a compiled binary to a
specific num_queries_per_kv. Drop both and the existing instances now
serve every num_qpkv that divides kBlockM evenly.

Concretely:
  * `variant_config<V>` -- remove the NumQPerKV field from every
    specialization.
  * `unified_attention_kernel_traits` -- remove the `num_queries_per_kv`
    / `kBlockQ = kBlockM / num_qpkv` derivation. The BlockTile's 2nd
    entry (the static `kBlockQ` exposed via UnifiedAttentionShape) is
    anchored at kBlockM so it describes the "num_qpkv == 1" fallback;
    the actual kBlockQ is always the runtime value.
  * `unified_attention_kernel_launch` -- recompute kBlockQ on the host
    from `args.num_queries_per_kv` for the total_num_q_blocks math.
  * `unified_attention_kernel.hpp` -- drop the
    `assert(kBlockQ_dyn == kBlockQ)` (it enforced the very coupling we
    just removed).
  * `unified_attention.cpp::select_config` -- collapse the two
    per-num_qpkv code paths into a single (head_dim, avg_rows,
    max_rows) ladder, where avg_rows = avg_q * num_qpkv.
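
A minimal sketch of the host-side arithmetic described above, assuming
kBlockQ = kBlockM / num_qpkv as stated in the traits bullet. The helper
names `runtime_block_q` and `count_q_blocks` are hypothetical, and the
per-sequence ceil-div is only one plausible way to count Q blocks; the
actual total_num_q_blocks formula in the launcher may differ.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: derive the runtime kBlockQ from the compiled kBlockM
// and the runtime num_queries_per_kv. Rejects values that do not divide
// kBlockM evenly, which is the condition the commit message states.
inline int runtime_block_q(int k_block_m, int num_queries_per_kv)
{
    if(num_queries_per_kv <= 0 || k_block_m % num_queries_per_kv != 0)
        throw std::invalid_argument("num_queries_per_kv must divide kBlockM evenly");
    return k_block_m / num_queries_per_kv;
}

// Hypothetical helper: number of Q blocks needed for a variable-length batch,
// counted as ceil(q_len / kBlockQ) per sequence. This is an assumption for
// illustration, not the launcher's actual formula.
inline std::int64_t count_q_blocks(const std::vector<int>& q_lens, int k_block_q)
{
    assert(k_block_q > 0);
    std::int64_t total = 0;
    for(int len : q_lens)
        total += (len + k_block_q - 1) / k_block_q; // ceil-div per sequence
    return total;
}
```

For example, a binary compiled with kBlockM = 128 would use kBlockQ = 16
for GQA-8 and kBlockQ = 128 for MHA under this derivation.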

Variant renames (8 variants):
  prefill_d128_mha       -> prefill_d128
  decode_d128_mha_m128   -> decode_d128_m128
  decode_d128_mha_m32    -> decode_d128_m32
  decode_d128_mha_m16    -> decode_d128_m16
  prefill_d64_gqa8       -> prefill_d64
  decode_d64_gqa8_m128   -> decode_d64_m128
  decode_d64_gqa8_m64    -> decode_d64_m64
  decode_d64_gqa8_m16    -> decode_d64_m16

The 16 d=64 instance files lose their `_gqa8` infix to match the
d=128 naming (file count unchanged: 16 dtype x mask combinations per
head_dim).

Validation:
  * Correctness suite: 241/245 passing (the same 4 pre-existing
    int32-overflow failures in the prefill rebased-pointer path).
  * d=128 GQA-8 (a NEW combo we never had a binary for) -- runs
    correctly on the existing decode_d128_m* binaries with num_qpkv=8
    at runtime. max abs diff <= 1e-2 vs the torch reference at ql in
    {1, 4, 16}.
  * d=64 MHA (also a new combo) -- runs correctly on the existing
    decode_d64_m* binaries with num_qpkv=1. Same tolerance.
  * Perf sweep (b=4..256, sk=120000, MI300):
      d=64  GQA-8: speedups 1.28x..1.84x vs Triton (within 0.6%
                   of baseline).
      d=128 MHA:   speedups 0.98x..1.14x vs Triton (within 0.3%
                   of baseline).
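
The "max abs diff <= 1e-2" criterion above can be sketched as follows.
This is an illustration only: the function names are hypothetical, and
the real suite compares against a torch reference rather than a
`std::vector`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: largest element-wise absolute difference between a
// kernel output and a reference of the same length.
inline float max_abs_diff(const std::vector<float>& out, const std::vector<float>& ref)
{
    assert(out.size() == ref.size());
    float worst = 0.0f;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const float d = std::fabs(out[i] - ref[i]);
        if(std::isnan(d))
            return d; // propagate NaN so it fails any tolerance check
        worst = std::max(worst, d);
    }
    return worst;
}

// Hypothetical helper: true when the output matches the reference within
// the absolute tolerance quoted in the validation notes (default 1e-2).
inline bool within_tolerance(const std::vector<float>& out,
                             const std::vector<float>& ref,
                             float tol = 1e-2f)
{
    return max_abs_diff(out, ref) <= tol;
}
```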

Unlocked: adding new (head_dim, num_qpkv) combos no longer requires
new kernel binaries -- just a host-side heuristic update mapping the
combo to the appropriate (kBlockM, BlockWarps) ladder.
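
A sketch of what such a host-side ladder might look like. The variant
names are taken from the rename list above; everything else is an
assumption for illustration: the function name `select_variant`, the
thresholds, and the definition max_rows = max_q * num_qpkv (mirroring
the stated avg_rows = avg_q * num_qpkv). The real `select_config`
thresholds are not given here.

```cpp
#include <cassert>
#include <string>

// Hypothetical ladder keyed on (head_dim, avg_rows, max_rows). Because rows
// already fold in num_queries_per_kv, MHA and GQA share one code path.
inline std::string select_variant(int head_dim, int avg_q, int max_q, int num_queries_per_kv)
{
    assert(head_dim == 64 || head_dim == 128);
    const int avg_rows = avg_q * num_queries_per_kv; // as defined in the commit message
    const int max_rows = max_q * num_queries_per_kv; // assumed analogous definition
    const std::string d = "d" + std::to_string(head_dim);

    // Placeholder thresholds: pick the smallest tile that covers the workload.
    if(max_rows <= 16)
        return "decode_" + d + "_m16";
    const int mid = (head_dim == 128) ? 32 : 64; // d=128 has an m32 variant, d=64 an m64 one
    if(avg_rows <= mid)
        return "decode_" + d + "_m" + std::to_string(mid);
    if(avg_rows <= 128)
        return "decode_" + d + "_m128";
    return "prefill_" + d;
}
```

Under these placeholder thresholds, a d=128 GQA-8 decode with one query
token per sequence gives avg_rows = max_rows = 8 and lands on
decode_d128_m16, without needing a GQA-specific binary.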

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 12:15:55 +00:00

CK Tile Example Suite

This directory contains a suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example implements a key deep learning or HPC operation using tile-based parallelism, modular pipelines, and data-movement policies.


What is CK Tile?

CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. The pipeline and policy together orchestrate data movement (global <-> LDS <-> registers), computation, and synchronization, which gives both high efficiency and flexibility.
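
To make the tile idea concrete, here is a plain C++ analogy that runs on the CPU. It is not the CK Tile API: the function name `tiled_gemm`, the tile size `kTile`, and the local staging buffers (standing in for LDS) are all assumptions chosen for illustration. It separates each step into a data-movement phase and a compute phase, which is the split the pipeline and policy manage on the GPU.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kTile = 4; // hypothetical tile edge length

// C (m x n) += A (m x k) * B (k x n), all row-major, processed tile by tile.
inline void tiled_gemm(const std::vector<float>& a,
                       const std::vector<float>& b,
                       std::vector<float>& c,
                       std::size_t m, std::size_t n, std::size_t k)
{
    assert(a.size() == m * k && b.size() == k * n && c.size() == m * n);
    float a_tile[kTile][kTile];
    float b_tile[kTile][kTile];
    for(std::size_t i0 = 0; i0 < m; i0 += kTile)
        for(std::size_t j0 = 0; j0 < n; j0 += kTile)
            for(std::size_t k0 = 0; k0 < k; k0 += kTile)
            {
                // Clamp tile extents at the matrix edges.
                const std::size_t mi = std::min(kTile, m - i0);
                const std::size_t nj = std::min(kTile, n - j0);
                const std::size_t kk = std::min(kTile, k - k0);
                // Data movement: stage one tile of A and B in local buffers.
                for(std::size_t i = 0; i < mi; ++i)
                    for(std::size_t p = 0; p < kk; ++p)
                        a_tile[i][p] = a[(i0 + i) * k + (k0 + p)];
                for(std::size_t p = 0; p < kk; ++p)
                    for(std::size_t j = 0; j < nj; ++j)
                        b_tile[p][j] = b[(k0 + p) * n + (j0 + j)];
                // Computation: accumulate the tile product into C.
                for(std::size_t i = 0; i < mi; ++i)
                    for(std::size_t j = 0; j < nj; ++j)
                        for(std::size_t p = 0; p < kk; ++p)
                            c[(i0 + i) * n + (j0 + j)] += a_tile[i][p] * b_tile[p][j];
            }
}
```

In CK Tile the staging, compute, and synchronization steps are supplied by reusable pipeline and policy components rather than written out as loops like this.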


Example Index

| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |

Technical Highlights


How to Build & Run

mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j

Each example produces its own executable in build/bin/.


Learning and Extending


References

