[CK_TILE][FMHA] Optimize long-context decoding on gfx11/12 (#7500) ## Motivation Relevant issue: ROCM-22065 FMHA has less-than-optimal performance of long-context decoding (i.e. when seqlen_q = 1) on gfx11/12. This PR optimizes the splitkv pipeline and configs for such scenarios. ## Technical Details Optimizations applied in this PR: 1. use tiles with smaller M0 (16 vs 64), these tiles are used when seqlen_q <= 16 2. adapt qr_nwarp_sshuffle pipeline for gfx11, it allows to use more warps even for M0 = 16 (the qr pipeline parallelizes work between warps in M dim so with M0 = 16 it allows to use only 1 warp) 3. enable kMergeNumHeadGroupsSeqLenQ (an optimization that merges one group of heads in GQA) for all hdim values, not only 128 4. increase the number of splits (multiply by the number of head groups) if (3) is used 5. increase the number of splits for RDNAs (`multiProcessorCount` is the number of WGPs on RDNAs, not CUs, so it should be doubled to have meaning similar to CDNAs) Performance on gfx1151: | Case | develop (GB/s) | This PR (GB/s) | |:-------|-------:|-------:| | [fp16\|group\|bshd] b:1, h:32/32, s:1/45056, d:64/64 | 127.58 | 183.11 | | [fp16\|group\|bhsd] b:1, h:32/32, s:1/45056, d:64/64 | 153.64 | 215.02 | | [fp16\|group\|bshd] b:1, h:16/8, s:1/77184, d:128/128 | 120.51 | 225.76 | | [fp16\|group\|bhsd] b:1, h:16/8, s:1/77184, d:128/128 | 130.62 | 223.84 | | [fp16\|group\|bshd] b:1, h:32/32, s:1/9600, d:128/128 | 82.65 | 138.44 | | [fp16\|group\|bhsd] b:1, h:32/32, s:1/9600, d:128/128 | 105.75 | 220.45 | | [fp16\|group\|bshd] b:1, h:8/1, s:1/401024, d:256/256 | 16.27 | 187.89 | | [fp16\|group\|bhsd] b:1, h:8/1, s:1/401024, d:256/256 | 16.28 | 188.19 | ## Test Plan An additional test case is added to the exiting test. It uses seqlen_q = 1, GQA, no mask to trigger the changes ``` ninja test_ck_tile_fmha_fwd_fp16 && bin/test_ck_tile_fmha_fwd_fp16 --gtest_filter="*SplitKV* ninja test_ck_tile_fmha_fwd_bf16 && bin/test_ck_tile_fmha_fwd_bf16 --gtest_filter="*SplitKV* ``` Manual testing can be done with these commands: ``` bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=32 -h_k=32 -d=64 -s=1 -s_k=$((352 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=16 -h_k=8 -d=128 -s=1 -s_k=$((603 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=32 -h_k=32 -d=128 -s=1 -s_k=$((75 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=8 -h_k=1 -d=256 -s=1 -s_k=$((3133 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1 ``` ## Test Result All the tests must pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
CK Tile Testing Guide
This document describes the test organization and available test targets for CK Tile operations.
Overview
CK Tile tests are organized with multiple levels of granularity to support different development workflows:
- Global test labels - Run tests across all operations
- Operation-specific umbrella targets - Run all tests for a specific operation
- Individual test executables - Run specific tests
Global Test Labels
These targets run tests across all CK operations (not just CK Tile):
ninja smoke
Run fast smoke tests (tests that complete within ~30 seconds on gfx90a).
ninja smoke
ninja regression
Run slower, more comprehensive regression tests.
ninja regression
ninja check
Run ALL available tests in the entire codebase.
ninja check
Operation-Specific Umbrella Targets
These targets allow you to run all tests for a specific CK Tile operation. This is useful when making changes to a particular operation and wanting to validate all related tests without running the entire test suite.
GEMM Operations
ck_tile_gemm_tests
Run all basic GEMM pipeline tests (memory, compute variants, persistent, etc.)
ninja ck_tile_gemm_tests
Test executables included:
test_ck_tile_gemm_pipeline_memtest_ck_tile_gemm_pipeline_compv3test_ck_tile_gemm_pipeline_compv4test_ck_tile_gemm_pipeline_persistenttest_ck_tile_gemm_pipeline_compv6test_ck_tile_gemm_pipeline_comp_async(gfx95 only)test_ck_tile_gemm_pipeline_*_wmmavariants (gfx11/gfx12 only)
ck_tile_gemm_block_scale_tests
Run all GEMM tests with block-scale quantization (AQuant, BQuant, ABQuant, etc.)
ninja ck_tile_gemm_block_scale_tests
Test executables included: 29 test executables covering:
- AQuant tests (memory pipelines, base layouts, prefill, preshuffle, transpose)
- ABQuant tests (base, padding, preshuffle)
- BQuant tests (1D/2D variants, transpose)
- BQuant with PreshuffleB (decode/prefill, 1D/2D)
- BQuant with PreshuffleQuant (decode/prefill, 1D/2D)
- RowColQuant and TensorQuant tests
ck_tile_gemm_streamk_tests
Run all GEMM StreamK tests (tile partitioner, reduction, smoke, extended)
ninja ck_tile_gemm_streamk_tests
Test executables included:
test_ck_tile_streamk_tile_partitionertest_ck_tile_streamk_reductiontest_ck_tile_streamk_smoketest_ck_tile_streamk_extended
ck_tile_grouped_gemm_quant_tests
Run all grouped GEMM quantization tests
ninja ck_tile_grouped_gemm_quant_tests
Test executables included:
test_ck_tile_grouped_gemm_quant_rowcoltest_ck_tile_grouped_gemm_quant_tensortest_ck_tile_grouped_gemm_quant_aquanttest_ck_tile_grouped_gemm_quant_bquanttest_ck_tile_grouped_gemm_quant_bquant_preshuffleb
Other Operations
ck_tile_fmha_tests
Run all FMHA (Flash Multi-Head Attention) tests
ninja ck_tile_fmha_tests
Test executables included: Forward and backward tests for fp16, bf16, fp8bf16, fp32
ck_tile_reduce_tests
Run all reduce operation tests
ninja ck_tile_reduce_tests
Test executables included:
test_ck_tile_reduce2dtest_ck_tile_multi_reduce2d_threadwisetest_ck_tile_multi_reduce2d_multiblock
Individual Test Executables
You can also build and run individual test executables:
Build a specific test
ninja test_ck_tile_gemm_pipeline_mem
Run a specific test directly
./build/bin/test_ck_tile_gemm_pipeline_mem
Run a specific test through ctest
ctest -R test_ck_tile_gemm_pipeline_mem --output-on-failure