[rocm-libraries] ROCm/rocm-libraries#5018 (commit b32e7e6)

[CK_TILE] Add LLC-aware FMHA head grouping and head-major
 scheduling on RDNA (#5018)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation
Long-sequence FMHA can become memory-bound when K/V working sets exceed
Infinity Cache (LLC), causing repeated DRAM traffic across heads.

This PR introduces LLC-aware launch ordering improvements for FMHA
forward, and it is currently enabled only on gfx11 and gfx12. The
approach is inspired by
[`Dao-AILab/flash-attention#2217`](https://github.com/Dao-AILab/flash-attention/pull/2217),
adapted to CK’s kernel/runner structure and layout handling.

In this context, `bshd` is the layout used in Flash-Attention, while
`bhsd` is the default layout used by the CK Tile FMHA example.

## Technical Details
This PR adds two complementary strategies:

- For `bshd` input layout (`i_perm/o_perm=0`), enable explicit LLC-aware
head grouping:
  - Estimate LLC size (env override, KFD sysfs, or arch default).
  - Compute group size from K/V bytes per head vs LLC target.
- Launch FMHA forward repeatedly per head-group by slicing Q/K/V/O (and
related tensors).

- For `bhsd` input layout (`i_perm/o_perm=1`), apply implicit
launch-order adjustment:
  - Keep a single kernel launch.
- Reinterpret block linearization in `GetTileIndex` to make execution
head-major,
     improving temporal locality of per-head K/V reuse.

Additional integration updates:
- Propagate `num_head_q_total` and `head_start` through FMHA args/kargs.
- Use global head indexing for dropout RNG stream mapping so grouped
launches keep
    deterministic/consistent dropout behavior.
- Keep fallback behavior unchanged when grouping is not beneficial or
disabled.

## Test Plan
- `test_ck_tile_fmha`
- `tile_example_fmha_fwd`

## Test Result
- `test_ck_tile_fmha`: all tests passed.
- `tile_example_fmha_fwd`: tested this on gfx1100, gfx1151, and gfx1201,
and all of them show higher performance compared to the baseline. The
improvement is consistent, and performance is well maintained even at
long sequence lengths.

./build/bin/tile_example_fmha_fwd -prec=bf16 -mode=0 -b=1 -h=24 -d=128
-s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}
- TFLOPs by sequence length target: gfx1100 layout: bhsd

SeqLen | Before | After | Speedup
-- | -- | -- | --
1024 | 56.27 | 61.48 | 1.09x
4096 | 67.10 | 72.27 | 1.08x
8192 | 65.99 | 71.64 | 1.09x
12288 | 61.60 | 76.61 | 1.24x
16384 | 58.99 | 75.74 | 1.28x
20480 | 57.32 | 74.42 | 1.30x
24576 | 56.89 | 74.25 | 1.31x
27280 | 18.93 | 24.48 | 1.29x

- TFLOPs by sequence length target: gfx1201 layout: bshd

SeqLen | Before | After | Speedup
-- | -- | -- | --
1024 | 66.79 | 65.90 | 0.99x
4096 | 85.90 | 86.80 | 1.01x
8192 | 77.06 | 90.29 | 1.17x
12288 | 58.36 | 88.98 | 1.52x
16384 | 52.12 | 88.88 | 1.71x
20480 | 48.11 | 88.42 | 1.84x
24576 | 47.12 | 89.07 | 1.89x
27280 | 49.05 | 50.31 | 1.03x

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Hosang
2026-03-16 21:19:23 +00:00
committed by assistant-librarian[bot]
parent 9c414d2e59
commit 859acb5ae7
5 changed files with 632 additions and 30 deletions

View File

@@ -808,7 +808,7 @@ class CompatibilityRuleFactory:
kernel_ctx.pipeline.F_bias != "no"
or kernel_ctx.pipeline.F_dropout == "t"
):
False
return False
return True
def check_feature(

View File

@@ -299,6 +299,8 @@ struct fmha_fwd_args
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
ck_tile::index_t num_head_q_total = 0;
ck_tile::index_t head_start = 0;
float scale_s;
float logits_soft_cap;
@@ -733,7 +735,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.block_scale_size_kv,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.sink_ptr);
args.sink_ptr,
args.num_head_q_total,
args.head_start);
}
else
{ // create batch mode kernel arguments
@@ -795,7 +799,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.block_scale_size_kv,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.sink_ptr);
args.sink_ptr,
args.num_head_q_total,
args.head_start);
}
}();

View File

@@ -0,0 +1,418 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/host.hpp"
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <dirent.h>
#include <fstream>
#include <iostream>
#include <limits>
#include <optional>
#include <string>
#ifndef CK_TILE_FMHA_ENABLE_HEAD_GROUPING
#define CK_TILE_FMHA_ENABLE_HEAD_GROUPING 1
#endif
#if CK_TILE_FMHA_ENABLE_HEAD_GROUPING
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_HEAD_GROUP_LOG)
CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_DISABLE_HEAD_GROUPING)
CK_TILE_DECLARE_ENV_VAR_UINT64(CK_TILE_FMHA_LLC_CACHE_MB)
namespace fmha_fwd_head_grouping {
inline bool log_enabled()
{
return ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_HEAD_GROUP_LOG));
}
inline bool disabled_by_env()
{
return ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_DISABLE_HEAD_GROUPING));
}
inline bool is_decimal_string(const std::string& s)
{
if(s.empty())
return false;
return std::all_of(s.begin(), s.end(), [](unsigned char c) { return std::isdigit(c) != 0; });
}
inline std::optional<long long> read_property_value(const std::string& filepath,
const std::string& key)
{
std::ifstream fs(filepath);
if(!fs.is_open())
return std::nullopt;
std::string k, v;
while(fs >> k >> v)
{
if(k == key)
{
try
{
return std::stoll(v, nullptr, 0);
}
catch(...)
{
return std::nullopt;
}
}
std::string rest;
std::getline(fs, rest);
}
return std::nullopt;
}
struct kfd_device_location
{
int domain = 0;
int location_id = 0;
};
inline std::optional<kfd_device_location> get_current_kfd_location()
{
int device = 0;
if(hipGetDevice(&device) != hipSuccess)
return std::nullopt;
char bdf[64] = {};
if(hipDeviceGetPCIBusId(bdf, sizeof(bdf), device) == hipSuccess)
{
unsigned int domain = 0, bus = 0, dev = 0, fn = 0;
if(std::sscanf(bdf, "%x:%x:%x.%x", &domain, &bus, &dev, &fn) == 4)
{
return kfd_device_location{
static_cast<int>(domain),
static_cast<int>(((bus & 0xff) << 8) | ((dev & 0x1f) << 3) | (fn & 0x7))};
}
}
hipDeviceProp_t props{};
if(hipGetDeviceProperties(&props, device) != hipSuccess)
return std::nullopt;
return kfd_device_location{props.pciDomainID,
((props.pciBusID & 0xff) << 8) | ((props.pciDeviceID & 0x1f) << 3)};
}
inline std::optional<std::string> find_matching_kfd_node(const kfd_device_location& loc)
{
constexpr const char* kKfdNodesDir = "/sys/class/kfd/kfd/topology/nodes";
DIR* dir = opendir(kKfdNodesDir);
if(dir == nullptr)
return std::nullopt;
std::optional<std::string> matched;
while(auto* ent = readdir(dir))
{
const std::string node_name(ent->d_name);
if(!is_decimal_string(node_name))
continue;
const std::string prop_path = std::string(kKfdNodesDir) + "/" + node_name + "/properties";
const auto location_val = read_property_value(prop_path, "location_id");
if(!location_val.has_value() || static_cast<int>(*location_val) != loc.location_id)
continue;
const auto domain_val = read_property_value(prop_path, "domain");
if(domain_val.has_value() && static_cast<int>(*domain_val) != loc.domain)
continue;
matched = node_name;
break;
}
closedir(dir);
return matched;
}
inline size_t read_kfd_node_l3_bytes(const std::string& node_name)
{
const std::string caches_dir = "/sys/class/kfd/kfd/topology/nodes/" + node_name + "/caches";
DIR* dir = opendir(caches_dir.c_str());
if(dir == nullptr)
return 0;
size_t l3_kb = 0;
while(auto* ent = readdir(dir))
{
const std::string cache_name(ent->d_name);
if(!is_decimal_string(cache_name))
continue;
const std::string prop_path = caches_dir + "/" + cache_name + "/properties";
const auto level_val = read_property_value(prop_path, "level");
if(!level_val.has_value() || *level_val != 3)
continue;
const auto size_val = read_property_value(prop_path, "size");
if(!size_val.has_value() || *size_val <= 0)
continue;
l3_kb = std::max(l3_kb, static_cast<size_t>(*size_val));
}
closedir(dir);
return l3_kb * 1024ull;
}
inline size_t get_kfd_sysfs_llc_cache_bytes()
{
const auto loc = get_current_kfd_location();
if(!loc.has_value())
return 0;
const auto node = find_matching_kfd_node(*loc);
if(!node.has_value())
return 0;
return read_kfd_node_l3_bytes(*node);
}
inline size_t get_default_llc_cache_bytes_for_arch(const std::string& arch);
inline size_t resolve_llc_cache_bytes_uncached(const std::string& arch)
{
// If parsed LLC looks invalidly tiny, ignore it and fallback.
constexpr size_t kMinValidKfdLlcBytes = 32ull * 1024ull;
const size_t kfd_llc_bytes = get_kfd_sysfs_llc_cache_bytes();
if(kfd_llc_bytes >= kMinValidKfdLlcBytes)
return kfd_llc_bytes;
const size_t default_cache_bytes = get_default_llc_cache_bytes_for_arch(arch);
if(default_cache_bytes > 0)
return default_cache_bytes;
// No default configured -> no grouping.
return 0;
}
inline bool ck_tile_is_rdna_arch(const std::string& arch)
{
return arch.rfind("gfx11", 0) == 0 || arch.rfind("gfx12", 0) == 0;
}
inline size_t get_default_llc_cache_bytes_for_arch(const std::string& arch)
{
if(arch == "gfx1100")
return 96ull * 1024ull * 1024ull;
if(arch == "gfx1101")
return 64ull * 1024ull * 1024ull;
if(arch == "gfx1102")
return 32ull * 1024ull * 1024ull;
if(arch == "gfx1151")
return 32ull * 1024ull * 1024ull;
if(arch == "gfx1200")
return 32ull * 1024ull * 1024ull;
if(arch == "gfx1201")
return 64ull * 1024ull * 1024ull;
return 0;
}
inline size_t get_llc_cache_bytes(const std::string& arch)
{
// resolve once and reuse.
static const size_t resolved_llc_bytes = [&]() -> size_t {
const uint64_t llc_mb = ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_FMHA_LLC_CACHE_MB));
if(llc_mb > 0)
{
constexpr uint64_t kBytesPerMb = 1024ull * 1024ull;
const uint64_t max_mb_for_size_t = static_cast<uint64_t>(
std::numeric_limits<size_t>::max() / static_cast<size_t>(kBytesPerMb));
if(llc_mb <= max_mb_for_size_t)
return static_cast<size_t>(llc_mb * kBytesPerMb);
}
return resolve_llc_cache_bytes_uncached(arch);
}();
return resolved_llc_bytes;
}
inline std::optional<ck_tile::index_t> get_head_group_size(ck_tile::index_t nhead_q,
ck_tile::index_t nhead_k,
ck_tile::index_t batch,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
size_t elem_bytes_k,
size_t elem_bytes_v)
{
if(disabled_by_env())
return std::nullopt;
const std::string arch = ck_tile::get_device_name();
if(arch.empty() || !ck_tile_is_rdna_arch(arch))
return std::nullopt;
const size_t llc_bytes = get_llc_cache_bytes(arch);
if(llc_bytes == 0)
return std::nullopt;
if(nhead_k <= 0 || nhead_q <= 0 || (nhead_q % nhead_k) != 0)
return std::nullopt;
if(seqlen_k <= 0 || hdim_q <= 0 || hdim_v <= 0 || batch <= 0)
return std::nullopt;
const size_t kv_bytes_per_head =
static_cast<size_t>(seqlen_k) *
(static_cast<size_t>(hdim_q) * elem_bytes_k + static_cast<size_t>(hdim_v) * elem_bytes_v);
if(kv_bytes_per_head == 0)
return std::nullopt;
// large LLC GPUs (>= 64MB): slightly more cache-resident grouping
constexpr size_t kLargeLlcThresholdBytes = 64ull * 1024ull * 1024ull;
const bool is_large_llc = llc_bytes >= kLargeLlcThresholdBytes;
const long double llc_utilization = is_large_llc ? 0.85L : 1.0L;
const long double threshold_ratio = is_large_llc ? 1.3L : 1.5L;
const size_t target_llc_bytes =
static_cast<size_t>(static_cast<long double>(llc_bytes) * llc_utilization);
if(target_llc_bytes == 0)
return std::nullopt;
const size_t total_kv_bytes = static_cast<size_t>(nhead_q) * kv_bytes_per_head;
if(static_cast<long double>(total_kv_bytes) <
static_cast<long double>(target_llc_bytes) * threshold_ratio)
return std::nullopt;
ck_tile::index_t group = static_cast<ck_tile::index_t>(target_llc_bytes / kv_bytes_per_head);
if(group < 1)
group = 1;
const ck_tile::index_t min_group_size = std::max<ck_tile::index_t>(1, nhead_q / 16);
if(group < min_group_size)
group = min_group_size;
// Cap the number of groups to avoid excessive launch overhead.
constexpr ck_tile::index_t kMaxGroups = 8;
const ck_tile::index_t min_group_for_max_groups =
ck_tile::integer_divide_ceil(nhead_q, kMaxGroups);
if(group < min_group_for_max_groups)
group = min_group_for_max_groups;
const ck_tile::index_t gqa_ratio = nhead_q / nhead_k;
if(gqa_ratio > 1)
{
group = ((group + gqa_ratio - 1) / gqa_ratio) * gqa_ratio;
}
group = std::min(group, nhead_q);
if(group >= nhead_q)
return std::nullopt;
return group;
}
template <typename T>
inline const void* ptr_offset(const void* base, ck_tile::index_t offset_elems)
{
if(base == nullptr)
return nullptr;
return static_cast<const void*>(reinterpret_cast<const T*>(base) + offset_elems);
}
template <typename T>
inline void* ptr_offset(void* base, ck_tile::index_t offset_elems)
{
if(base == nullptr)
return nullptr;
return static_cast<void*>(reinterpret_cast<T*>(base) + offset_elems);
}
template <typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename BiasDataType,
typename LSEDataType,
typename RandValOutputDataType,
typename FmhaFwdTraits,
typename FmhaFwdArgs,
typename RunKernelFn>
float run_fwd_head_grouped(const ck_tile::stream_config& sc,
const FmhaFwdTraits& fmha_traits,
const FmhaFwdArgs& base_args_in,
ck_tile::index_t nhead,
ck_tile::index_t nhead_k,
ck_tile::index_t group_size_q,
bool use_blockscale_qscale,
RunKernelFn&& run_kernel_fn)
{
auto base_args = base_args_in;
base_args.num_head_q_total = nhead;
const ck_tile::index_t gqa_ratio = (nhead_k > 0 ? (nhead / nhead_k) : 1);
const ck_tile::index_t group_sz = std::min(group_size_q, nhead);
const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(nhead, group_sz);
float total_time = 0.0f;
for(ck_tile::index_t head_start = 0; head_start < nhead; head_start += group_sz)
{
const ck_tile::index_t q_heads = std::min(group_sz, nhead - head_start);
const ck_tile::index_t k_head_start =
(gqa_ratio >= 1 ? head_start / gqa_ratio : head_start);
const ck_tile::index_t k_heads = (gqa_ratio >= 1 ? q_heads / gqa_ratio : q_heads);
auto args = base_args;
args.nhead_q = q_heads;
args.nhead_k = k_heads;
args.head_start = head_start;
args.q_ptr = ptr_offset<QDataType>(base_args.q_ptr, head_start * base_args.nhead_stride_q);
args.k_ptr =
ptr_offset<KDataType>(base_args.k_ptr, k_head_start * base_args.nhead_stride_k);
args.v_ptr =
ptr_offset<VDataType>(base_args.v_ptr, k_head_start * base_args.nhead_stride_v);
args.o_ptr = ptr_offset<ODataType>(base_args.o_ptr, head_start * base_args.nhead_stride_o);
args.bias_ptr =
ptr_offset<BiasDataType>(base_args.bias_ptr, head_start * base_args.nhead_stride_bias);
args.lse_ptr =
ptr_offset<LSEDataType>(base_args.lse_ptr, head_start * base_args.nhead_stride_lse);
args.rand_val_ptr = ptr_offset<RandValOutputDataType>(
base_args.rand_val_ptr, head_start * base_args.nhead_stride_randval);
if(use_blockscale_qscale)
{
args.q_descale_ptr = ptr_offset<float>(base_args.q_descale_ptr,
head_start * base_args.nhead_stride_q_descale);
args.k_descale_ptr = ptr_offset<float>(base_args.k_descale_ptr,
k_head_start * base_args.nhead_stride_k_descale);
args.v_descale_ptr = ptr_offset<float>(base_args.v_descale_ptr,
k_head_start * base_args.nhead_stride_v_descale);
}
else
{
args.q_descale_ptr = base_args.q_descale_ptr;
args.k_descale_ptr = base_args.k_descale_ptr;
args.v_descale_ptr = base_args.v_descale_ptr;
}
args.sink_ptr = ptr_offset<float>(base_args.sink_ptr, head_start);
if(log_enabled())
{
const ck_tile::index_t head_end = head_start + q_heads;
std::cout << "[LLC Head Grouping] group " << (head_start / group_sz) << "/" << n_groups
<< " heads_q=[" << head_start << ", " << head_end << ") heads_k=["
<< k_head_start << ", " << (k_head_start + k_heads) << ")" << std::endl;
}
const float t = run_kernel_fn(fmha_traits, args, sc);
if(t < 0.0f)
return t;
total_time += t;
}
return total_time;
}
} // namespace fmha_fwd_head_grouping
#endif // CK_TILE_FMHA_ENABLE_HEAD_GROUPING

View File

@@ -6,14 +6,17 @@
#include "ck_tile/host.hpp"
#include "ck_tile/ref/naive_attention.hpp"
#include "fmha_fwd.hpp"
#include "fmha_fwd_head_grouping.hpp"
#include "utils.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <array>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <cmath>
#include <numeric>
#include <optional>
#include <ostream>
#include <string>
#include <tuple>
@@ -1238,6 +1241,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.hdim_v = hdim_v;
args.nhead_q = nhead;
args.nhead_k = nhead_k;
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
{
args.num_head_q_total = nhead;
args.head_start = 0;
}
args.stride_q = stride_q;
args.stride_k = stride_k;
@@ -1555,7 +1563,87 @@ fwd_result fmha_fwd_run(mode_enum mode,
return fmha_fwd(fmha_traits, fmha_args, sc);
};
const float fwd_ave_time = run_fwd(stream_config);
float fwd_ave_time = -1.0f;
#if CK_TILE_FMHA_ENABLE_HEAD_GROUPING
const bool allow_head_grouping = !i_perm && !use_kvcache && (num_splits <= 1) &&
!need_append_kvcache &&
(mode == mode_enum::batch || mode == mode_enum::group);
if(allow_head_grouping)
{
if(fmha_fwd_head_grouping::disabled_by_env())
{
if(fmha_fwd_head_grouping::log_enabled())
std::cout << "[LLC Head Grouping] disabled by env" << std::endl;
}
else
{
const auto group_size_opt =
fmha_fwd_head_grouping::get_head_group_size(nhead,
nhead_k,
batch,
max_seqlen_k,
hdim_q,
hdim_v,
sizeof(KDataType),
sizeof(VDataType));
if(group_size_opt.has_value() && group_size_opt.value() < nhead)
{
if(fmha_fwd_head_grouping::log_enabled())
{
const std::string arch = ck_tile::get_device_name();
const size_t llc_bytes = fmha_fwd_head_grouping::get_llc_cache_bytes(arch);
const ck_tile::index_t gqa_ratio = (nhead_k > 0 ? (nhead / nhead_k) : 1);
const ck_tile::index_t group_sz = group_size_opt.value();
const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(nhead, group_sz);
std::cout << "[LLC Head Grouping] enabled" << std::endl;
std::cout << "[LLC Head Grouping] arch=" << (arch.empty() ? "unknown" : arch)
<< " llc_mb=" << (llc_bytes / (1024ull * 1024ull))
<< " nhead_q=" << nhead << " nhead_k=" << nhead_k
<< " gqa_ratio=" << gqa_ratio << " group_size=" << group_sz
<< " groups=" << n_groups << std::endl;
}
fmha_fwd_traits fmha_traits;
init_traits(fmha_traits);
fmha_fwd_args fmha_args;
init_args(fmha_args);
fwd_ave_time = fmha_fwd_head_grouping::run_fwd_head_grouped<QDataType,
KDataType,
VDataType,
ODataType,
BiasDataType,
LSEDataType,
RandValOutputDataType>(
stream_config,
fmha_traits,
fmha_args,
nhead,
nhead_k,
group_size_opt.value(),
qscale.type == quant_scale_enum::blockscale,
[&](const auto& traits, auto& args, const auto& sc) {
return fmha_fwd(traits, args, sc);
});
}
else if(fmha_fwd_head_grouping::log_enabled())
{
std::cout << "[LLC Head Grouping] skipped (group_size not set or >= nhead)"
<< std::endl;
}
}
}
else if(fmha_fwd_head_grouping::log_enabled())
{
std::cout << "[LLC Head Grouping] disabled by conditions/layout" << std::endl;
}
#endif
if(fwd_ave_time < 0.0f)
fwd_ave_time = run_fwd(stream_config);
if(fwd_ave_time < 0.0f)
{
std::cout << ", not supported yet" << std::flush << std::endl;