[rocm-libraries] ROCm/rocm-libraries#5018 (commit b32e7e6)

[CK_TILE] Add LLC-aware FMHA head grouping and head-major
 scheduling on RDNA (#5018)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation
Long-sequence FMHA can become memory-bound when K/V working sets exceed
Infinity Cache (LLC), causing repeated DRAM traffic across heads.

This PR introduces LLC-aware launch ordering improvements for FMHA
forward, and it is currently enabled only on gfx11 and gfx12. The
approach is inspired by
[`Dao-AILab/flash-attention#2217`](https://github.com/Dao-AILab/flash-attention/pull/2217),
adapted to CK’s kernel/runner structure and layout handling.

In this context, `bshd` is the layout used in Flash-Attention, while
`bhsd` is the default layout used by the CK Tile FMHA example.

## Technical Details
This PR adds two complementary strategies:

- For `bshd` input layout (`i_perm/o_perm=0`), enable explicit LLC-aware
head grouping:
  - Estimate LLC size (env override, KFD sysfs, or arch default).
  - Compute group size from K/V bytes per head vs LLC target.
  - Launch FMHA forward repeatedly per head-group by slicing Q/K/V/O (and
    related tensors).

- For `bhsd` input layout (`i_perm/o_perm=1`), apply implicit
launch-order adjustment:
  - Keep a single kernel launch.
  - Reinterpret block linearization in `GetTileIndex` to make execution
    head-major, improving temporal locality of per-head K/V reuse.

Additional integration updates:
- Propagate `num_head_q_total` and `head_start` through FMHA args/kargs.
- Use global head indexing for dropout RNG stream mapping so grouped
  launches keep deterministic/consistent dropout behavior.
- Keep fallback behavior unchanged when grouping is not beneficial or
disabled.

## Test Plan
- `test_ck_tile_fmha`
- `tile_example_fmha_fwd`

## Test Result
- `test_ck_tile_fmha`: all tests passed.
- `tile_example_fmha_fwd`: tested this on gfx1100, gfx1151, and gfx1201,
and all of them show higher performance compared to the baseline. The
improvement is consistent, and performance is well maintained even at
long sequence lengths.

./build/bin/tile_example_fmha_fwd -prec=bf16 -mode=0 -b=1 -h=24 -d=128
-s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}
- TFLOPs by sequence length (target: gfx1100, layout: bhsd)

SeqLen | Before | After | Speedup
-- | -- | -- | --
1024 | 56.27 | 61.48 | 1.09x
4096 | 67.10 | 72.27 | 1.08x
8192 | 65.99 | 71.64 | 1.09x
12288 | 61.60 | 76.61 | 1.24x
16384 | 58.99 | 75.74 | 1.28x
20480 | 57.32 | 74.42 | 1.30x
24576 | 56.89 | 74.25 | 1.31x
27280 | 18.93 | 24.48 | 1.29x

- TFLOPs by sequence length (target: gfx1201, layout: bshd)

SeqLen | Before | After | Speedup
-- | -- | -- | --
1024 | 66.79 | 65.90 | 0.99x
4096 | 85.90 | 86.80 | 1.01x
8192 | 77.06 | 90.29 | 1.17x
12288 | 58.36 | 88.98 | 1.52x
16384 | 52.12 | 88.88 | 1.71x
20480 | 48.11 | 88.42 | 1.84x
24576 | 47.12 | 89.07 | 1.89x
27280 | 49.05 | 50.31 | 1.03x

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Hosang
2026-03-16 21:19:23 +00:00
committed by assistant-librarian[bot]
parent 9c414d2e59
commit 859acb5ae7
5 changed files with 632 additions and 30 deletions

View File

@@ -15,6 +15,15 @@
#include <variant>
#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
#if !defined(CK_TILE_FMHA_FORCE_HEAD_MAJOR)
#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx11__) || defined(__gfx12__))
#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 1
#else
#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 0
#endif
#endif
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
@@ -111,6 +120,10 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_o;
// Optional global head count and head offset (for grouped launches & RNG correctness)
ck_tile::index_t num_head_q_total = 0;
ck_tile::index_t head_start = 0;
};
struct FmhaFwdLogitsSoftCapKargs
@@ -410,9 +423,11 @@ struct FmhaFwdKernel
drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr,
ck_tile::index_t num_head_q_total = 0,
ck_tile::index_t head_start = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -448,6 +463,8 @@ struct FmhaFwdKernel
batch_stride_k,
batch_stride_v,
batch_stride_o};
kargs.num_head_q_total = num_head_q_total;
kargs.head_start = head_start;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -605,9 +622,11 @@ struct FmhaFwdKernel
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr,
ck_tile::index_t num_head_q_total = 0,
ck_tile::index_t head_start = 0)
{
return MakeKargsImpl(
q_ptr,
@@ -668,7 +687,9 @@ struct FmhaFwdKernel
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
sink_ptr,
num_head_q_total,
head_start);
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -730,9 +751,11 @@ struct FmhaFwdKernel
const std::tuple<const void*, const void*>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr,
ck_tile::index_t num_head_q_total = 0,
ck_tile::index_t head_start = 0)
{
return MakeKargsImpl(
q_ptr,
@@ -793,7 +816,9 @@ struct FmhaFwdKernel
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
sink_ptr,
num_head_q_total,
head_start);
}
template <bool Cond = kIsGroupMode>
@@ -851,9 +876,11 @@ struct FmhaFwdKernel
drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr,
ck_tile::index_t num_head_q_total = 0,
ck_tile::index_t head_start = 0)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -890,6 +917,8 @@ struct FmhaFwdKernel
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
kargs.num_head_q_total = num_head_q_total;
kargs.head_start = head_start;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -1042,9 +1071,11 @@ struct FmhaFwdKernel
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr,
ck_tile::index_t num_head_q_total = 0,
ck_tile::index_t head_start = 0)
{
return MakeKargsImpl(
q_ptr,
@@ -1100,7 +1131,9 @@ struct FmhaFwdKernel
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
sink_ptr,
num_head_q_total,
head_start);
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
@@ -1157,9 +1190,11 @@ struct FmhaFwdKernel
const std::tuple<const void*, const void*>& drop_seed_offset,
ck_tile::index_t block_scale_size_q,
ck_tile::index_t block_scale_size_kv,
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr)
const void* cu_seqlen_q_ptr = nullptr,
const void* cu_seqlen_k_ptr = nullptr,
const void* sink_ptr = nullptr,
ck_tile::index_t num_head_q_total = 0,
ck_tile::index_t head_start = 0)
{
return MakeKargsImpl(
q_ptr,
@@ -1215,7 +1250,9 @@ struct FmhaFwdKernel
block_scale_size_kv,
cu_seqlen_q_ptr,
cu_seqlen_k_ptr,
sink_ptr);
sink_ptr,
num_head_q_total,
head_start);
}
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
@@ -1250,6 +1287,54 @@ struct FmhaFwdKernel
if constexpr(kIsGroupMode)
has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
#if CK_TILE_FMHA_FORCE_HEAD_MAJOR
// compiler-workaround gate (ROCm 7.1 + gfx12).
// Keep head-major enabled for all unaffected kernels.
#if defined(__gfx12__) && (HIP_VERSION_MAJOR == 7) && (HIP_VERSION_MINOR == 1)
constexpr bool kSkipHeadMajor = kIsGroupMode && kHasMask && !kHasDropout &&
(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) &&
kPadHeadDimQ && kPadHeadDimV &&
(FmhaPipeline::kN1 == 256) &&
std::is_same_v<QDataType, ck_tile::fp16_t> &&
std::is_same_v<KDataType, ck_tile::fp16_t> &&
std::is_same_v<VDataType, ck_tile::fp16_t>;
#else
constexpr bool kSkipHeadMajor = false;
#endif
if constexpr(!kSkipHeadMajor)
{
// bhsd should satisfy stride_q == hdim_q and nhead_stride_q > hdim_q
// The extra nhead_stride_q guard prevents bshd false-positive when nhead == 1
const bool is_bhsd_layout =
(kargs.stride_q == kargs.hdim_q) && (kargs.nhead_stride_q > kargs.hdim_q);
if(is_bhsd_layout)
{
const index_t num_tile_n1 =
ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
const index_t num_tile_total = has_padded_seqlen_k ? gridDim.z : gridDim.y;
const index_t num_head = gridDim.x;
const index_t blocks_per_batch = num_head * num_tile_total;
const index_t linear_id =
blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z);
const index_t i_batch = linear_id / blocks_per_batch;
const index_t rem0 = linear_id - i_batch * blocks_per_batch;
const index_t i_nhead = rem0 / num_tile_total;
const index_t i_block = rem0 - i_nhead * num_tile_total;
index_t i_tile_m = i_block / num_tile_n1;
index_t i_tile_n = i_block - i_tile_m * num_tile_n1;
if constexpr(kHasMask)
{
const index_t num_tile_m = num_tile_total / num_tile_n1;
i_tile_m = num_tile_m - 1 - i_tile_m;
}
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
#endif
if(has_padded_seqlen_k)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
@@ -1271,7 +1356,8 @@ struct FmhaFwdKernel
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
return ck_tile::make_tuple(
static_cast<index_t>(gridDim.z) - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
@@ -1299,7 +1385,8 @@ struct FmhaFwdKernel
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
return ck_tile::make_tuple(
static_cast<index_t>(gridDim.y) - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
@@ -1677,9 +1764,12 @@ struct FmhaFwdKernel
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
if constexpr(kHasDropout)
{
const auto num_head_q_total =
(kargs.num_head_q_total > 0 ? kargs.num_head_q_total : kargs.num_head_q);
const auto i_head_global = kargs.head_start + i_nhead_;
return BlockDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
i_head_global,
num_head_q_total,
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
: *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host