mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5018 (commit b32e7e6)
[CK_TILE] Add LLC-aware FMHA head grouping and head-major scheduling on RDNA (#5018) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Long-sequence FMHA can become memory-bound when K/V working sets exceed Infinity Cache (LLC), causing repeated DRAM traffic across heads. This PR introduces LLC-aware launch ordering improvements for FMHA forward, and it is currently enabled only on gfx11 and gfx12. The approach is inspired by [`Dao-AILab/flash-attention#2217`](https://github.com/Dao-AILab/flash-attention/pull/2217), adapted to CK’s kernel/runner structure and layout handling. In this context, `bshd` is the layout used in Flash-Attention, while `bhsd` is the default layout used by the CK Tile FMHA example. ## Technical Details This PR adds two complementary strategies: - For `bshd` input layout (`i_perm/o_perm=0`), enable explicit LLC-aware head grouping: - Estimate LLC size (env override, KFD sysfs, or arch default). - Compute group size from K/V bytes per head vs LLC target. - Launch FMHA forward repeatedly per head-group by slicing Q/K/V/O (and related tensors). - For `bhsd` input layout (`i_perm/o_perm=1`), apply implicit launch-order adjustment: - Keep a single kernel launch. - Reinterpret block linearization in `GetTileIndex` to make execution head-major, improving temporal locality of per-head K/V reuse. Additional integration updates: - Propagate `num_head_q_total` and `head_start` through FMHA args/kargs. - Use global head indexing for dropout RNG stream mapping so grouped launches keep deterministic/consistent dropout behavior. - Keep fallback behavior unchanged when grouping is not beneficial or disabled. ## Test Plan - `test_ck_tile_fmha` - `tile_example_fmha_fwd` ## Test Result - `test_ck_tile_fmha`: all tests passed. - `tile_example_fmha_fwd`: tested this on gfx1100, gfx1151, and gfx1201, and all of them show higher performance compared to the baseline. The improvement is consistent, and performance is well maintained even at long sequence lengths. ./build/bin/tile_example_fmha_fwd -prec=bf16 -mode=0 -b=1 -h=24 -d=128 -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1} - TFLOPs by sequence length target: gfx1100 layout: bhsd SeqLen | Before | After | Speedup -- | -- | -- | -- 1024 | 56.27 | 61.48 | 1.09x 4096 | 67.10 | 72.27 | 1.08x 8192 | 65.99 | 71.64 | 1.09x 12288 | 61.60 | 76.61 | 1.24x 16384 | 58.99 | 75.74 | 1.28x 20480 | 57.32 | 74.42 | 1.30x 24576 | 56.89 | 74.25 | 1.31x 27280 | 18.93 | 24.48 | 1.29x - TFLOPs by sequence length target: gfx1201 layout: bshd SeqLen | Before | After | Speedup -- | -- | -- | -- 1024 | 66.79 | 65.90 | 0.99x 4096 | 85.90 | 86.80 | 1.01x 8192 | 77.06 | 90.29 | 1.17x 12288 | 58.36 | 88.98 | 1.52x 16384 | 52.12 | 88.88 | 1.71x 20480 | 48.11 | 88.42 | 1.84x 24576 | 47.12 | 89.07 | 1.89x 27280 | 49.05 | 50.31 | 1.03x ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
9c414d2e59
commit
859acb5ae7
@@ -15,6 +15,15 @@
|
||||
#include <variant>
|
||||
|
||||
#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
|
||||
|
||||
#if !defined(CK_TILE_FMHA_FORCE_HEAD_MAJOR)
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx11__) || defined(__gfx12__))
|
||||
#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 1
|
||||
#else
|
||||
#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
|
||||
@@ -111,6 +120,10 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
// Optional global head count and head offset (for grouped launches & RNG correctness)
|
||||
ck_tile::index_t num_head_q_total = 0;
|
||||
ck_tile::index_t head_start = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdLogitsSoftCapKargs
|
||||
@@ -410,9 +423,11 @@ struct FmhaFwdKernel
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t num_head_q_total = 0,
|
||||
ck_tile::index_t head_start = 0)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -448,6 +463,8 @@ struct FmhaFwdKernel
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_o};
|
||||
kargs.num_head_q_total = num_head_q_total;
|
||||
kargs.head_start = head_start;
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -605,9 +622,11 @@ struct FmhaFwdKernel
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t num_head_q_total = 0,
|
||||
ck_tile::index_t head_start = 0)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -668,7 +687,9 @@ struct FmhaFwdKernel
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
sink_ptr,
|
||||
num_head_q_total,
|
||||
head_start);
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
@@ -730,9 +751,11 @@ struct FmhaFwdKernel
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t num_head_q_total = 0,
|
||||
ck_tile::index_t head_start = 0)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -793,7 +816,9 @@ struct FmhaFwdKernel
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
sink_ptr,
|
||||
num_head_q_total,
|
||||
head_start);
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
@@ -851,9 +876,11 @@ struct FmhaFwdKernel
|
||||
drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t num_head_q_total = 0,
|
||||
ck_tile::index_t head_start = 0)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -890,6 +917,8 @@ struct FmhaFwdKernel
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
kargs.num_head_q_total = num_head_q_total;
|
||||
kargs.head_start = head_start;
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -1042,9 +1071,11 @@ struct FmhaFwdKernel
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t num_head_q_total = 0,
|
||||
ck_tile::index_t head_start = 0)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -1100,7 +1131,9 @@ struct FmhaFwdKernel
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
sink_ptr,
|
||||
num_head_q_total,
|
||||
head_start);
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
@@ -1157,9 +1190,11 @@ struct FmhaFwdKernel
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
ck_tile::index_t block_scale_size_q,
|
||||
ck_tile::index_t block_scale_size_kv,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr,
|
||||
ck_tile::index_t num_head_q_total = 0,
|
||||
ck_tile::index_t head_start = 0)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -1215,7 +1250,9 @@ struct FmhaFwdKernel
|
||||
block_scale_size_kv,
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
sink_ptr,
|
||||
num_head_q_total,
|
||||
head_start);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
@@ -1250,6 +1287,54 @@ struct FmhaFwdKernel
|
||||
if constexpr(kIsGroupMode)
|
||||
has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
|
||||
|
||||
#if CK_TILE_FMHA_FORCE_HEAD_MAJOR
|
||||
// compiler-workaround gate (ROCm 7.1 + gfx12).
|
||||
// Keep head-major enabled for all unaffected kernels.
|
||||
#if defined(__gfx12__) && (HIP_VERSION_MAJOR == 7) && (HIP_VERSION_MINOR == 1)
|
||||
constexpr bool kSkipHeadMajor = kIsGroupMode && kHasMask && !kHasDropout &&
|
||||
(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) &&
|
||||
kPadHeadDimQ && kPadHeadDimV &&
|
||||
(FmhaPipeline::kN1 == 256) &&
|
||||
std::is_same_v<QDataType, ck_tile::fp16_t> &&
|
||||
std::is_same_v<KDataType, ck_tile::fp16_t> &&
|
||||
std::is_same_v<VDataType, ck_tile::fp16_t>;
|
||||
#else
|
||||
constexpr bool kSkipHeadMajor = false;
|
||||
#endif
|
||||
if constexpr(!kSkipHeadMajor)
|
||||
{
|
||||
// bhsd should satisfy stride_q == hdim_q and nhead_stride_q > hdim_q
|
||||
// The extra nhead_stride_q guard prevents bshd false-positive when nhead == 1
|
||||
const bool is_bhsd_layout =
|
||||
(kargs.stride_q == kargs.hdim_q) && (kargs.nhead_stride_q > kargs.hdim_q);
|
||||
if(is_bhsd_layout)
|
||||
{
|
||||
const index_t num_tile_n1 =
|
||||
ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
|
||||
const index_t num_tile_total = has_padded_seqlen_k ? gridDim.z : gridDim.y;
|
||||
const index_t num_head = gridDim.x;
|
||||
const index_t blocks_per_batch = num_head * num_tile_total;
|
||||
const index_t linear_id =
|
||||
blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z);
|
||||
|
||||
const index_t i_batch = linear_id / blocks_per_batch;
|
||||
const index_t rem0 = linear_id - i_batch * blocks_per_batch;
|
||||
const index_t i_nhead = rem0 / num_tile_total;
|
||||
const index_t i_block = rem0 - i_nhead * num_tile_total;
|
||||
|
||||
index_t i_tile_m = i_block / num_tile_n1;
|
||||
index_t i_tile_n = i_block - i_tile_m * num_tile_n1;
|
||||
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
const index_t num_tile_m = num_tile_total / num_tile_n1;
|
||||
i_tile_m = num_tile_m - 1 - i_tile_m;
|
||||
}
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if(has_padded_seqlen_k)
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
@@ -1271,7 +1356,8 @@ struct FmhaFwdKernel
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
return ck_tile::make_tuple(
|
||||
static_cast<index_t>(gridDim.z) - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1299,7 +1385,8 @@ struct FmhaFwdKernel
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
return ck_tile::make_tuple(
|
||||
static_cast<index_t>(gridDim.y) - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1677,9 +1764,12 @@ struct FmhaFwdKernel
|
||||
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
const auto num_head_q_total =
|
||||
(kargs.num_head_q_total > 0 ? kargs.num_head_q_total : kargs.num_head_q);
|
||||
const auto i_head_global = kargs.head_start + i_nhead_;
|
||||
return BlockDropout{i_batch_,
|
||||
i_nhead_,
|
||||
kargs.num_head_q,
|
||||
i_head_global,
|
||||
num_head_q_total,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
|
||||
: *kargs.drop_seed.ptr,
|
||||
kargs.is_drop_seed_offset_from_host
|
||||
|
||||
Reference in New Issue
Block a user