Files
composable_kernel/test/ck_tile
Chao 320a813d67 [rocm-libraries] ROCm/rocm-libraries#6533 (commit 5dcaa45)
[CK_TILE] Add host-side Pack-GQA optimization for FMHA
 forward (#6533)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

[CK_TILE] Add host-side Pack-GQA optimization for FMHA forward

## Motivation

Host-side Pack-GQA optimization for CK-Tile FMHA forward. Reshapes Q
tensor
from `[b, nhead_q, seqlen_q, d]` to `[b, nhead_kv, nhead_ratio *
seqlen_q, d]`
by adjusting strides, so grouped Q-heads sharing the same KV data are
processed
in a single tile. Zero kernel changes — runner-only.

Phase 1: non-causal attention with GQA ratio packing.
Phase 2: extends to dropout and split-kv paths, fixes stride edge cases.

## Technical Details

Modified files (2):
- `example/ck_tile/01_fmha/example_fmha_fwd.cpp` — Pack-GQA flag
plumbing
- `example/ck_tile/01_fmha/fmha_fwd_runner.hpp` — Q tensor reshape
logic,
  stride adjustment for GQA ratio packing

New files (1):
- `example/ck_tile/01_fmha/test_pack_gqa_phase2.sh` — 53 test cases
covering
  non-causal, dropout, split-kv, various GQA ratios

## Dependencies

None — this PR is standalone.

## Test Plan

- GPU validation on MI300X (gfx942, ROCm 6.4.1):
- Command: `./build/bin/tile_example_fmha_fwd -b=2 -h=32 -h_k=8 -s=2048
-d=128 -prec=bf16 -mode=group -v=1 -warmup=1 -repeat=3`
- GPU validation on MI350X (gfx950, ROCm 7.0), 53 parameterized test
cases:
- Command (GQA 4:1): `./build/bin/tile_example_fmha_fwd -b=2 -h=32
-h_k=8 -s=2048 -d=128 -prec=bf16 -mode=group -v=1 -warmup=1 -repeat=3`
- Command (GQA 8:1): `./build/bin/tile_example_fmha_fwd -b=2 -h=64
-h_k=8 -s=2048 -d=128 -prec=bf16 -mode=group -v=1 -warmup=1 -repeat=3`
- Command (decode): `./build/bin/tile_example_fmha_fwd -b=64 -h=32
-h_k=8 -s=1 -s_k=4096 -d=128 -prec=bf16 -mode=group -v=1 -warmup=1
-repeat=3`

## Test Result

Benchmark results (MI350X, gfx950, ROCm 7.0):

| Config | Without Pack | With Pack | Improvement |
|--------|-------------|-----------|-------------|
| GQA 4:1 prefill b=2 h=32 hk=8 s=2048 d=128 bf16 | 690.05 TFlops (0.199
ms) | 695.61 TFlops (0.198 ms) | +0.8% |
| GQA 8:1 prefill b=2 h=64 hk=8 s=2048 d=128 bf16 | 706.25 TFlops (0.389
ms) | 729.35 TFlops (0.377 ms) | +3.3% |
| GQA 8:1 decode b=64 h=32 hk=4 s_k=4096 d=128 bf16 | 305.20 GB/s (1.763
ms) | 1813.41 GB/s (0.297 ms) | **+5.9x** |
| LLaMA-70B decode b=32 h=64 hk=8 s_k=4096 d=128 bf16 | 591.70 GB/s
(0.909 ms) | 1820.65 GB/s (0.295 ms) | **+3.1x** |
| MHA ratio=1 b=2 h=8 s=4096 d=128 bf16 | 695.16 TFlops | 702.72 TFlops
| no regression |

Benchmark results (MI300X, gfx942, ROCm 6.4.1):

No regression on MI300X. Pack-GQA is a runner-only optimization (zero
kernel changes), performance impact is within noise on MI300X.

| Config | TFlops / GB/s | Time (ms) | Delta vs baseline |
|--------|-------------|-----------|-------------------|
| MHA bf16 b=2 h=8 s=4096 d=128 | 336.52 TFlops | 0.408 | -1.7% |
| GQA 4:1 bf16 b=2 h=32 hk=8 s=2048 d=128 | 322.52 TFlops | 0.426 |
-0.7% |
| GQA 8:1 bf16 b=2 h=64 hk=8 s=2048 d=128 | 349.85 TFlops | 0.786 |
+0.5% |
| LLaMA-70B prefill b=1 h=64 hk=8 s=4096 d=128 bf16 | 381.29 TFlops |
1.442 | +1.2% |
| Decode b=64 h=32 hk=8 s_k=4096 d=128 bf16 | 697.32 GB/s | 1.541 |
+0.8% |

All validation tests pass (`valid:y`) on both MI300X and MI350X.

Additional validation:
- 53 parameterized test cases pass (23 phase 1 + 30 phase 2)
- GQA ratios tested: 1:1, 2:1, 4:1, 8:1, 32:1
- No regression on MHA (ratio=1) workloads
- fp16 and bf16 validated
2026-06-10 01:56:44 +00:00
..

CK Tile Testing Guide

This document describes the test organization and available test targets for CK Tile operations.

Overview

CK Tile tests are organized with multiple levels of granularity to support different development workflows:

  1. Global test labels - Run tests across all operations
  2. Operation-specific umbrella targets - Run all tests for a specific operation
  3. Individual test executables - Run specific tests

Global Test Labels

These targets run tests across all CK operations (not just CK Tile):

ninja smoke

Run fast smoke tests (tests that complete within ~30 seconds on gfx90a).

ninja smoke

ninja regression

Run slower, more comprehensive regression tests.

ninja regression

ninja check

Run ALL available tests in the entire codebase.

ninja check

Operation-Specific Umbrella Targets

These targets allow you to run all tests for a specific CK Tile operation. This is useful when making changes to a particular operation and wanting to validate all related tests without running the entire test suite.

GEMM Operations

ck_tile_gemm_tests

Run all basic GEMM pipeline tests (memory, compute variants, persistent, etc.)

ninja ck_tile_gemm_tests

Test executables included:

  • test_ck_tile_gemm_pipeline_mem
  • test_ck_tile_gemm_pipeline_compv3
  • test_ck_tile_gemm_pipeline_compv4
  • test_ck_tile_gemm_pipeline_persistent
  • test_ck_tile_gemm_pipeline_compv6
  • test_ck_tile_gemm_pipeline_comp_async (gfx95 only)
  • test_ck_tile_gemm_pipeline_*_wmma variants (gfx11/gfx12 only)

ck_tile_gemm_block_scale_tests

Run all GEMM tests with block-scale quantization (AQuant, BQuant, ABQuant, etc.)

ninja ck_tile_gemm_block_scale_tests

Test executables included: 29 test executables covering:

  • AQuant tests (memory pipelines, base layouts, prefill, preshuffle, transpose)
  • ABQuant tests (base, padding, preshuffle)
  • BQuant tests (1D/2D variants, transpose)
  • BQuant with PreshuffleB (decode/prefill, 1D/2D)
  • BQuant with PreshuffleQuant (decode/prefill, 1D/2D)
  • RowColQuant and TensorQuant tests

ck_tile_gemm_streamk_tests

Run all GEMM StreamK tests (tile partitioner, reduction, smoke, extended)

ninja ck_tile_gemm_streamk_tests

Test executables included:

  • test_ck_tile_streamk_tile_partitioner
  • test_ck_tile_streamk_reduction
  • test_ck_tile_streamk_smoke
  • test_ck_tile_streamk_extended

ck_tile_grouped_gemm_quant_tests

Run all grouped GEMM quantization tests

ninja ck_tile_grouped_gemm_quant_tests

Test executables included:

  • test_ck_tile_grouped_gemm_quant_rowcol
  • test_ck_tile_grouped_gemm_quant_tensor
  • test_ck_tile_grouped_gemm_quant_aquant
  • test_ck_tile_grouped_gemm_quant_bquant
  • test_ck_tile_grouped_gemm_quant_bquant_preshuffleb

Other Operations

ck_tile_fmha_tests

Run all FMHA (Flash Multi-Head Attention) tests

ninja ck_tile_fmha_tests

Test executables included: Forward and backward tests for fp16, bf16, fp8bf16, fp32

ck_tile_reduce_tests

Run all reduce operation tests

ninja ck_tile_reduce_tests

Test executables included:

  • test_ck_tile_reduce2d
  • test_ck_tile_multi_reduce2d_threadwise
  • test_ck_tile_multi_reduce2d_multiblock

Individual Test Executables

You can also build and run individual test executables:

Build a specific test

ninja test_ck_tile_gemm_pipeline_mem

Run a specific test directly

./build/bin/test_ck_tile_gemm_pipeline_mem

Run a specific test through ctest

ctest -R test_ck_tile_gemm_pipeline_mem --output-on-failure