[rocm-libraries] ROCm/rocm-libraries#7500 (commit f5cd4fd)

[CK_TILE][FMHA] Optimize long-context decoding on gfx11/12
 (#7500)

## Motivation

Relevant issue: ROCM-22065

FMHA has less-than-optimal performance of long-context decoding (i.e.
when seqlen_q = 1) on gfx11/12.
This PR optimizes the splitkv pipeline and configs for such scenarios.

## Technical Details

Optimizations applied in this PR:
1. use tiles with smaller M0 (16 vs 64), these tiles are used when
seqlen_q <= 16
2. adapt qr_nwarp_sshuffle pipeline for gfx11, it allows to use more
warps even for M0 = 16 (the qr pipeline parallelizes work between warps
in M dim so with M0 = 16 it allows to use only 1 warp)
3. enable kMergeNumHeadGroupsSeqLenQ (an optimization that merges one
group of heads in GQA) for all hdim values, not only 128
4. increase the number of splits (multiply by the number of head groups)
if (3) is used
5. increase the number of splits for RDNAs (`multiProcessorCount` is the
number of WGPs on RDNAs, not CUs, so it should be doubled to have
meaning similar to CDNAs)

Performance on gfx1151:

| Case | develop (GB/s) | This PR (GB/s) |
|:-------|-------:|-------:|
| [fp16\|group\|bshd] b:1, h:32/32, s:1/45056, d:64/64 | 127.58 | 183.11
|
| [fp16\|group\|bhsd] b:1, h:32/32, s:1/45056, d:64/64 | 153.64 | 215.02
|
| [fp16\|group\|bshd] b:1, h:16/8, s:1/77184, d:128/128 | 120.51 |
225.76 |
| [fp16\|group\|bhsd] b:1, h:16/8, s:1/77184, d:128/128 | 130.62 |
223.84 |
| [fp16\|group\|bshd] b:1, h:32/32, s:1/9600, d:128/128 | 82.65 | 138.44
|
| [fp16\|group\|bhsd] b:1, h:32/32, s:1/9600, d:128/128 | 105.75 |
220.45 |
| [fp16\|group\|bshd] b:1, h:8/1, s:1/401024, d:256/256 | 16.27 | 187.89
|
| [fp16\|group\|bhsd] b:1, h:8/1, s:1/401024, d:256/256 | 16.28 | 188.19
|

## Test Plan

An additional test case is added to the exiting test. It uses seqlen_q =
1, GQA, no mask to trigger the changes
```
ninja test_ck_tile_fmha_fwd_fp16 && bin/test_ck_tile_fmha_fwd_fp16 --gtest_filter="*SplitKV*
ninja test_ck_tile_fmha_fwd_bf16 && bin/test_ck_tile_fmha_fwd_bf16 --gtest_filter="*SplitKV*
```

Manual testing can be done with these commands:
```
bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=32 -h_k=32 -d=64  -s=1 -s_k=$((352 * 128))  -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1
bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=16 -h_k=8  -d=128 -s=1 -s_k=$((603 * 128))  -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1
bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=32 -h_k=32 -d=128 -s=1 -s_k=$((75 * 128))   -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1
bin/tile_example_fmha_fwd -prec=fp16 -mode=1 -page_block_size=128 -b=1 -h=8  -h_k=1  -d=256 -s=1 -s_k=$((3133 * 128)) -lse=1 -mask=0 -num_splits=0 -kname=1 -v=1
```

## Test Result

All the tests must pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Anton Gorenko
2026-06-03 06:16:10 +00:00
committed by assistant-librarian[bot]
parent 01bd52bdb5
commit 7ecbf82708
5 changed files with 84 additions and 21 deletions

View File

@@ -128,7 +128,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
namespace {{
template <bool kHasUnevenSplits>
void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
if constexpr ({F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
@@ -283,7 +283,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
@@ -364,6 +364,14 @@ class FmhaFwdSplitKVApiTrait:
else:
assert False
def seqtune(self, max_bm0: int) -> str:
if self.bm0 == max_bm0:
return "true/*fall back to largest tile*/"
else:
if self.mode == "group":
return f"a.max_seqlen_q <= {self.bm0}"
return f"a.seqlen_q <= {self.bm0}"
@property
def skcheck(self) -> str:
if self.mode == "group":
@@ -561,6 +569,7 @@ class FmhaFwdSplitKVApiPool:
for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
per_hdim_case = str()
for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
max_bm0 = max((t.bm0 for t in pool_by_hdim), default=0)
inners = str()
for i_trait, trait in enumerate(pool_by_hdim):
inners += FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(
@@ -579,6 +588,7 @@ class FmhaFwdSplitKVApiPool:
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_sink=BOOL_MAP[trait.sink],
F_scheck=trait.scheck,
F_seqtune=trait.seqtune(max_bm0),
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
@@ -763,6 +773,7 @@ class KernelComponentFactoryBase:
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr_nwarp_sshuffle", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
@@ -846,11 +857,15 @@ class KernelComponentFactoryGfx11(KernelComponentFactoryBase):
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
# bm0, bn0, bk0, bn1, bk1,
"32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# bm0, bn0, bk0, bn1, bk1,
"32" : [FmhaFwdTileSize( 16, 64, 16, 32, 32, 32, 1, 2, 1, 1, 2, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"64" : [FmhaFwdTileSize( 16, 64, 32, 64, 32, 64, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"128": [FmhaFwdTileSize( 16, 64, 32, 128, 32, 128, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"256": [FmhaFwdTileSize( 16, 64, 32, 256, 32, 256, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
else:
return None
@@ -863,11 +878,15 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
# bm0, bn0, bk0, bn1, bk1,
"32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# bm0, bn0, bk0, bn1, bk1,
"32" : [FmhaFwdTileSize( 16, 64, 16, 32, 32, 32, 1, 2, 1, 1, 2, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"64" : [FmhaFwdTileSize( 16, 64, 32, 64, 32, 64, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"128": [FmhaFwdTileSize( 16, 128, 32, 128, 32, 128, 1, 8, 1, 1, 8, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"256": [FmhaFwdTileSize( 16, 128, 32, 256, 32, 256, 1, 8, 1, 1, 8, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
elif dtype in ["fp8", "bf8"]:
return {
@@ -930,11 +949,17 @@ def get_fwd_splitkv_blobs(
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
continue
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
tiles = d[hdim_str]
if not isinstance(tiles, list):
tiles = [tiles]
hdim = int(hdim_str)
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
for tile, pipeline in itertools.product(
tiles, factory.get_pipelines(dtype, hdim, mask_impl)
):
# Use qr_nwarp_sshuffle with multiple N warps and qr otherwise
if (tile.F_rn0 != 1) != (pipeline.tag == "qr_nwarp_sshuffle"):
continue
if mode == "group":
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not

View File

@@ -165,8 +165,10 @@ int override_num_splits_if_necessary(
if(num_splits < 1 && p_drop == 0.0f)
{
// props.multiProcessorCount for >=gfx10 is the number of WGPs (each has 2 CUs)
const int num_blocks_per_SM = props.warpSize == 32 ? 4 : 2;
return num_splits_heuristic(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128);
batch * nhead * num_m_blocks, props.multiProcessorCount * num_blocks_per_SM, 128);
}
return num_splits;
@@ -648,8 +650,18 @@ fwd_result fmha_fwd_run(mode_enum mode,
// legalize num_splits according to other options
if(num_splits < 1)
{
int nhead_merged = nhead;
int max_seqlen_q_merged = max_seqlen_q;
// When max_seqlen_q == 1 and multiple head groups are merged (kMergeNumHeadGroupsSeqLenQ)
// then more splits are required
if(bias.type == bias_enum::no_bias && mask.type == mask_enum::no_mask &&
max_seqlen_q == 1 && nhead_k < nhead)
{
nhead_merged = nhead_k;
max_seqlen_q_merged = max_seqlen_q * (nhead / nhead_k);
}
num_splits = override_num_splits_if_necessary(
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
batch, nhead_merged, max_seqlen_q_merged, hdim_v, p_drop, num_splits);
}
if(128 < num_splits)
{

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
@@ -257,7 +258,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
clear_tile(o_acc);
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
{
set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E});
set_tile(m, SMPLComputeDataType{sink_v * static_cast<float>(C_LOG2E)});
set_tile(l, SMPLComputeDataType{1.0f});
}
else
@@ -698,8 +699,15 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
#if defined(__gfx11__)
auto p = make_static_distributed_tensor<PDataType>(
decltype(gemm_1)::template MakeABlockTileDistribution<kM0, kN0>());
PermuteWarpGemmCToA(
p, cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute)));
#else
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
#endif
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();

View File

@@ -5,8 +5,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
namespace ck_tile {
@@ -163,6 +161,25 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kTileK = Problem::BlockFmhaShape::kN0;
#if defined(__gfx11__)
// Keep C distribution and replicate data for NWarp to prevent doubling registers
// PermuteWarpGemmCToA will convert C distribution to A for matrix P later
constexpr index_t K1 = kKPerBlock / WG::kM;
constexpr index_t K0 = kTileK / kKPerBlock;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / WG::kN;
constexpr auto s2_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<M0, M1>, sequence<K0, K1>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2, 2>,
sequence<0, 0, 1>>{};
constexpr auto s2_block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
s2_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
#else
// K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
@@ -179,7 +196,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
tuple<sequence<1, 0>, sequence<2, 2>>,
sequence<1, 2, 2, 2>,
sequence<0, 0, 1, 3>>{};
#endif
constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding);
return s2_block_dstr;

View File

@@ -735,6 +735,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
Values(3, 4),
Values(std::tuple{4, 3, 1, 200, 1024, "0"},
std::tuple{2, 2, -1, 512, 2000, "0"},
std::tuple{2, 8, 2, 1, 1024, "0"},
std::tuple{3, 2, -1, 230, 899, "t:128,128"})));
TEST_P(SplitKV, DataTypeConfig)