diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 18e0022cf5..1849068161 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -808,7 +808,7 @@ class CompatibilityRuleFactory: kernel_ctx.pipeline.F_bias != "no" or kernel_ctx.pipeline.F_dropout == "t" ): - False + return False return True def check_feature( diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 4adb159b31..521f1e4738 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -299,6 +299,8 @@ struct fmha_fwd_args ck_tile::index_t hdim_v; ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; + ck_tile::index_t num_head_q_total = 0; + ck_tile::index_t head_start = 0; float scale_s; float logits_soft_cap; @@ -733,7 +735,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, - args.sink_ptr); + args.sink_ptr, + args.num_head_q_total, + args.head_start); } else { // create batch mode kernel arguments @@ -795,7 +799,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, - args.sink_ptr); + args.sink_ptr, + args.num_head_q_total, + args.head_start); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp b/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp new file mode 100644 index 0000000000..9cd1fb9cdc --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp @@ -0,0 +1,418 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/host.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef CK_TILE_FMHA_ENABLE_HEAD_GROUPING +#define CK_TILE_FMHA_ENABLE_HEAD_GROUPING 1 +#endif + +#if CK_TILE_FMHA_ENABLE_HEAD_GROUPING +CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_HEAD_GROUP_LOG) +CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_DISABLE_HEAD_GROUPING) +CK_TILE_DECLARE_ENV_VAR_UINT64(CK_TILE_FMHA_LLC_CACHE_MB) + +namespace fmha_fwd_head_grouping { + +inline bool log_enabled() +{ + return ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_HEAD_GROUP_LOG)); +} + +inline bool disabled_by_env() +{ + return ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_DISABLE_HEAD_GROUPING)); +} + +inline bool is_decimal_string(const std::string& s) +{ + if(s.empty()) + return false; + return std::all_of(s.begin(), s.end(), [](unsigned char c) { return std::isdigit(c) != 0; }); +} + +inline std::optional read_property_value(const std::string& filepath, + const std::string& key) +{ + std::ifstream fs(filepath); + if(!fs.is_open()) + return std::nullopt; + + std::string k, v; + while(fs >> k >> v) + { + if(k == key) + { + try + { + return std::stoll(v, nullptr, 0); + } + catch(...) + { + return std::nullopt; + } + } + std::string rest; + std::getline(fs, rest); + } + return std::nullopt; +} + +struct kfd_device_location +{ + int domain = 0; + int location_id = 0; +}; + +inline std::optional get_current_kfd_location() +{ + int device = 0; + if(hipGetDevice(&device) != hipSuccess) + return std::nullopt; + + char bdf[64] = {}; + if(hipDeviceGetPCIBusId(bdf, sizeof(bdf), device) == hipSuccess) + { + unsigned int domain = 0, bus = 0, dev = 0, fn = 0; + if(std::sscanf(bdf, "%x:%x:%x.%x", &domain, &bus, &dev, &fn) == 4) + { + return kfd_device_location{ + static_cast(domain), + static_cast(((bus & 0xff) << 8) | ((dev & 0x1f) << 3) | (fn & 0x7))}; + } + } + + hipDeviceProp_t props{}; + if(hipGetDeviceProperties(&props, device) != hipSuccess) + return std::nullopt; + + return kfd_device_location{props.pciDomainID, + ((props.pciBusID & 0xff) << 8) | ((props.pciDeviceID & 0x1f) << 3)}; +} + +inline std::optional find_matching_kfd_node(const kfd_device_location& loc) +{ + constexpr const char* kKfdNodesDir = "/sys/class/kfd/kfd/topology/nodes"; + DIR* dir = opendir(kKfdNodesDir); + if(dir == nullptr) + return std::nullopt; + + std::optional matched; + while(auto* ent = readdir(dir)) + { + const std::string node_name(ent->d_name); + if(!is_decimal_string(node_name)) + continue; + + const std::string prop_path = std::string(kKfdNodesDir) + "/" + node_name + "/properties"; + const auto location_val = read_property_value(prop_path, "location_id"); + if(!location_val.has_value() || static_cast(*location_val) != loc.location_id) + continue; + + const auto domain_val = read_property_value(prop_path, "domain"); + if(domain_val.has_value() && static_cast(*domain_val) != loc.domain) + continue; + + matched = node_name; + break; + } + + closedir(dir); + return matched; +} + +inline size_t read_kfd_node_l3_bytes(const std::string& node_name) +{ + const std::string caches_dir = "/sys/class/kfd/kfd/topology/nodes/" + node_name + "/caches"; + DIR* dir = opendir(caches_dir.c_str()); + if(dir == nullptr) + return 0; + + size_t l3_kb = 0; + while(auto* ent = readdir(dir)) + { + const std::string cache_name(ent->d_name); + if(!is_decimal_string(cache_name)) + continue; + + const std::string prop_path = caches_dir + "/" + cache_name + "/properties"; + const auto level_val = read_property_value(prop_path, "level"); + if(!level_val.has_value() || *level_val != 3) + continue; + + const auto size_val = read_property_value(prop_path, "size"); + if(!size_val.has_value() || *size_val <= 0) + continue; + + l3_kb = std::max(l3_kb, static_cast(*size_val)); + } + + closedir(dir); + return l3_kb * 1024ull; +} + +inline size_t get_kfd_sysfs_llc_cache_bytes() +{ + const auto loc = get_current_kfd_location(); + if(!loc.has_value()) + return 0; + + const auto node = find_matching_kfd_node(*loc); + if(!node.has_value()) + return 0; + + return read_kfd_node_l3_bytes(*node); +} + +inline size_t get_default_llc_cache_bytes_for_arch(const std::string& arch); + +inline size_t resolve_llc_cache_bytes_uncached(const std::string& arch) +{ + // If parsed LLC looks invalidly tiny, ignore it and fallback. + constexpr size_t kMinValidKfdLlcBytes = 32ull * 1024ull; + + const size_t kfd_llc_bytes = get_kfd_sysfs_llc_cache_bytes(); + if(kfd_llc_bytes >= kMinValidKfdLlcBytes) + return kfd_llc_bytes; + + const size_t default_cache_bytes = get_default_llc_cache_bytes_for_arch(arch); + if(default_cache_bytes > 0) + return default_cache_bytes; + + // No default configured -> no grouping. + return 0; +} + +inline bool ck_tile_is_rdna_arch(const std::string& arch) +{ + return arch.rfind("gfx11", 0) == 0 || arch.rfind("gfx12", 0) == 0; +} + +inline size_t get_default_llc_cache_bytes_for_arch(const std::string& arch) +{ + if(arch == "gfx1100") + return 96ull * 1024ull * 1024ull; + if(arch == "gfx1101") + return 64ull * 1024ull * 1024ull; + if(arch == "gfx1102") + return 32ull * 1024ull * 1024ull; + if(arch == "gfx1151") + return 32ull * 1024ull * 1024ull; + if(arch == "gfx1200") + return 32ull * 1024ull * 1024ull; + if(arch == "gfx1201") + return 64ull * 1024ull * 1024ull; + return 0; +} + +inline size_t get_llc_cache_bytes(const std::string& arch) +{ + // resolve once and reuse. + static const size_t resolved_llc_bytes = [&]() -> size_t { + const uint64_t llc_mb = ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_FMHA_LLC_CACHE_MB)); + if(llc_mb > 0) + { + constexpr uint64_t kBytesPerMb = 1024ull * 1024ull; + const uint64_t max_mb_for_size_t = static_cast( + std::numeric_limits::max() / static_cast(kBytesPerMb)); + + if(llc_mb <= max_mb_for_size_t) + return static_cast(llc_mb * kBytesPerMb); + } + + return resolve_llc_cache_bytes_uncached(arch); + }(); + + return resolved_llc_bytes; +} + +inline std::optional get_head_group_size(ck_tile::index_t nhead_q, + ck_tile::index_t nhead_k, + ck_tile::index_t batch, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + size_t elem_bytes_k, + size_t elem_bytes_v) +{ + if(disabled_by_env()) + return std::nullopt; + + const std::string arch = ck_tile::get_device_name(); + if(arch.empty() || !ck_tile_is_rdna_arch(arch)) + return std::nullopt; + + const size_t llc_bytes = get_llc_cache_bytes(arch); + if(llc_bytes == 0) + return std::nullopt; + + if(nhead_k <= 0 || nhead_q <= 0 || (nhead_q % nhead_k) != 0) + return std::nullopt; + if(seqlen_k <= 0 || hdim_q <= 0 || hdim_v <= 0 || batch <= 0) + return std::nullopt; + + const size_t kv_bytes_per_head = + static_cast(seqlen_k) * + (static_cast(hdim_q) * elem_bytes_k + static_cast(hdim_v) * elem_bytes_v); + if(kv_bytes_per_head == 0) + return std::nullopt; + + // large LLC GPUs (>= 64MB): slightly more cache-resident grouping + constexpr size_t kLargeLlcThresholdBytes = 64ull * 1024ull * 1024ull; + const bool is_large_llc = llc_bytes >= kLargeLlcThresholdBytes; + const long double llc_utilization = is_large_llc ? 0.85L : 1.0L; + const long double threshold_ratio = is_large_llc ? 1.3L : 1.5L; + const size_t target_llc_bytes = + static_cast(static_cast(llc_bytes) * llc_utilization); + if(target_llc_bytes == 0) + return std::nullopt; + + const size_t total_kv_bytes = static_cast(nhead_q) * kv_bytes_per_head; + if(static_cast(total_kv_bytes) < + static_cast(target_llc_bytes) * threshold_ratio) + return std::nullopt; + + ck_tile::index_t group = static_cast(target_llc_bytes / kv_bytes_per_head); + if(group < 1) + group = 1; + + const ck_tile::index_t min_group_size = std::max(1, nhead_q / 16); + if(group < min_group_size) + group = min_group_size; + + // Cap the number of groups to avoid excessive launch overhead. + constexpr ck_tile::index_t kMaxGroups = 8; + const ck_tile::index_t min_group_for_max_groups = + ck_tile::integer_divide_ceil(nhead_q, kMaxGroups); + if(group < min_group_for_max_groups) + group = min_group_for_max_groups; + + const ck_tile::index_t gqa_ratio = nhead_q / nhead_k; + if(gqa_ratio > 1) + { + group = ((group + gqa_ratio - 1) / gqa_ratio) * gqa_ratio; + } + + group = std::min(group, nhead_q); + if(group >= nhead_q) + return std::nullopt; + + return group; +} + +template +inline const void* ptr_offset(const void* base, ck_tile::index_t offset_elems) +{ + if(base == nullptr) + return nullptr; + return static_cast(reinterpret_cast(base) + offset_elems); +} + +template +inline void* ptr_offset(void* base, ck_tile::index_t offset_elems) +{ + if(base == nullptr) + return nullptr; + return static_cast(reinterpret_cast(base) + offset_elems); +} + +template +float run_fwd_head_grouped(const ck_tile::stream_config& sc, + const FmhaFwdTraits& fmha_traits, + const FmhaFwdArgs& base_args_in, + ck_tile::index_t nhead, + ck_tile::index_t nhead_k, + ck_tile::index_t group_size_q, + bool use_blockscale_qscale, + RunKernelFn&& run_kernel_fn) +{ + auto base_args = base_args_in; + base_args.num_head_q_total = nhead; + const ck_tile::index_t gqa_ratio = (nhead_k > 0 ? (nhead / nhead_k) : 1); + const ck_tile::index_t group_sz = std::min(group_size_q, nhead); + const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(nhead, group_sz); + + float total_time = 0.0f; + for(ck_tile::index_t head_start = 0; head_start < nhead; head_start += group_sz) + { + const ck_tile::index_t q_heads = std::min(group_sz, nhead - head_start); + const ck_tile::index_t k_head_start = + (gqa_ratio >= 1 ? head_start / gqa_ratio : head_start); + const ck_tile::index_t k_heads = (gqa_ratio >= 1 ? q_heads / gqa_ratio : q_heads); + + auto args = base_args; + args.nhead_q = q_heads; + args.nhead_k = k_heads; + args.head_start = head_start; + + args.q_ptr = ptr_offset(base_args.q_ptr, head_start * base_args.nhead_stride_q); + args.k_ptr = + ptr_offset(base_args.k_ptr, k_head_start * base_args.nhead_stride_k); + args.v_ptr = + ptr_offset(base_args.v_ptr, k_head_start * base_args.nhead_stride_v); + args.o_ptr = ptr_offset(base_args.o_ptr, head_start * base_args.nhead_stride_o); + + args.bias_ptr = + ptr_offset(base_args.bias_ptr, head_start * base_args.nhead_stride_bias); + args.lse_ptr = + ptr_offset(base_args.lse_ptr, head_start * base_args.nhead_stride_lse); + args.rand_val_ptr = ptr_offset( + base_args.rand_val_ptr, head_start * base_args.nhead_stride_randval); + + if(use_blockscale_qscale) + { + args.q_descale_ptr = ptr_offset(base_args.q_descale_ptr, + head_start * base_args.nhead_stride_q_descale); + args.k_descale_ptr = ptr_offset(base_args.k_descale_ptr, + k_head_start * base_args.nhead_stride_k_descale); + args.v_descale_ptr = ptr_offset(base_args.v_descale_ptr, + k_head_start * base_args.nhead_stride_v_descale); + } + else + { + args.q_descale_ptr = base_args.q_descale_ptr; + args.k_descale_ptr = base_args.k_descale_ptr; + args.v_descale_ptr = base_args.v_descale_ptr; + } + + args.sink_ptr = ptr_offset(base_args.sink_ptr, head_start); + + if(log_enabled()) + { + const ck_tile::index_t head_end = head_start + q_heads; + std::cout << "[LLC Head Grouping] group " << (head_start / group_sz) << "/" << n_groups + << " heads_q=[" << head_start << ", " << head_end << ") heads_k=[" + << k_head_start << ", " << (k_head_start + k_heads) << ")" << std::endl; + } + + const float t = run_kernel_fn(fmha_traits, args, sc); + if(t < 0.0f) + return t; + total_time += t; + } + return total_time; +} + +} // namespace fmha_fwd_head_grouping +#endif // CK_TILE_FMHA_ENABLE_HEAD_GROUPING diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 17d53a4e6d..40b8006381 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -6,14 +6,17 @@ #include "ck_tile/host.hpp" #include "ck_tile/ref/naive_attention.hpp" #include "fmha_fwd.hpp" +#include "fmha_fwd_head_grouping.hpp" #include "utils.hpp" #include "ck_tile/utility/json_dump.hpp" #include +#include #include #include #include #include +#include #include #include #include @@ -1238,6 +1241,11 @@ fwd_result fmha_fwd_run(mode_enum mode, args.hdim_v = hdim_v; args.nhead_q = nhead; args.nhead_k = nhead_k; + if constexpr(std::is_same_v>) + { + args.num_head_q_total = nhead; + args.head_start = 0; + } args.stride_q = stride_q; args.stride_k = stride_k; @@ -1555,7 +1563,87 @@ fwd_result fmha_fwd_run(mode_enum mode, return fmha_fwd(fmha_traits, fmha_args, sc); }; - const float fwd_ave_time = run_fwd(stream_config); + + float fwd_ave_time = -1.0f; +#if CK_TILE_FMHA_ENABLE_HEAD_GROUPING + const bool allow_head_grouping = !i_perm && !use_kvcache && (num_splits <= 1) && + !need_append_kvcache && + (mode == mode_enum::batch || mode == mode_enum::group); + + if(allow_head_grouping) + { + if(fmha_fwd_head_grouping::disabled_by_env()) + { + if(fmha_fwd_head_grouping::log_enabled()) + std::cout << "[LLC Head Grouping] disabled by env" << std::endl; + } + else + { + const auto group_size_opt = + fmha_fwd_head_grouping::get_head_group_size(nhead, + nhead_k, + batch, + max_seqlen_k, + hdim_q, + hdim_v, + sizeof(KDataType), + sizeof(VDataType)); + + if(group_size_opt.has_value() && group_size_opt.value() < nhead) + { + if(fmha_fwd_head_grouping::log_enabled()) + { + const std::string arch = ck_tile::get_device_name(); + const size_t llc_bytes = fmha_fwd_head_grouping::get_llc_cache_bytes(arch); + const ck_tile::index_t gqa_ratio = (nhead_k > 0 ? (nhead / nhead_k) : 1); + const ck_tile::index_t group_sz = group_size_opt.value(); + const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(nhead, group_sz); + std::cout << "[LLC Head Grouping] enabled" << std::endl; + std::cout << "[LLC Head Grouping] arch=" << (arch.empty() ? "unknown" : arch) + << " llc_mb=" << (llc_bytes / (1024ull * 1024ull)) + << " nhead_q=" << nhead << " nhead_k=" << nhead_k + << " gqa_ratio=" << gqa_ratio << " group_size=" << group_sz + << " groups=" << n_groups << std::endl; + } + fmha_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_fwd_args fmha_args; + init_args(fmha_args); + + fwd_ave_time = fmha_fwd_head_grouping::run_fwd_head_grouped( + stream_config, + fmha_traits, + fmha_args, + nhead, + nhead_k, + group_size_opt.value(), + qscale.type == quant_scale_enum::blockscale, + [&](const auto& traits, auto& args, const auto& sc) { + return fmha_fwd(traits, args, sc); + }); + } + else if(fmha_fwd_head_grouping::log_enabled()) + { + std::cout << "[LLC Head Grouping] skipped (group_size not set or >= nhead)" + << std::endl; + } + } + } + else if(fmha_fwd_head_grouping::log_enabled()) + { + std::cout << "[LLC Head Grouping] disabled by conditions/layout" << std::endl; + } +#endif + + if(fwd_ave_time < 0.0f) + fwd_ave_time = run_fwd(stream_config); if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index bd09453dbb..16f5b00bb1 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -15,6 +15,15 @@ #include #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 + +#if !defined(CK_TILE_FMHA_FORCE_HEAD_MAJOR) +#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx11__) || defined(__gfx12__)) +#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 1 +#else +#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 0 +#endif +#endif + // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] @@ -111,6 +120,10 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_o; + + // Optional global head count and head offset (for grouped launches & RNG correctness) + ck_tile::index_t num_head_q_total = 0; + ck_tile::index_t head_start = 0; }; struct FmhaFwdLogitsSoftCapKargs @@ -410,9 +423,11 @@ struct FmhaFwdKernel drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -448,6 +463,8 @@ struct FmhaFwdKernel batch_stride_k, batch_stride_v, batch_stride_o}; + kargs.num_head_q_total = num_head_q_total; + kargs.head_start = head_start; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -605,9 +622,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -668,7 +687,9 @@ struct FmhaFwdKernel block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, - sink_ptr); + sink_ptr, + num_head_q_total, + head_start); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -730,9 +751,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -793,7 +816,9 @@ struct FmhaFwdKernel block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, - sink_ptr); + sink_ptr, + num_head_q_total, + head_start); } template @@ -851,9 +876,11 @@ struct FmhaFwdKernel drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -890,6 +917,8 @@ struct FmhaFwdKernel reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; + kargs.num_head_q_total = num_head_q_total; + kargs.head_start = head_start; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -1042,9 +1071,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -1100,7 +1131,9 @@ struct FmhaFwdKernel block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, - sink_ptr); + sink_ptr, + num_head_q_total, + head_start); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -1157,9 +1190,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -1215,7 +1250,9 @@ struct FmhaFwdKernel block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, - sink_ptr); + sink_ptr, + num_head_q_total, + head_start); } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, @@ -1250,6 +1287,54 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); +#if CK_TILE_FMHA_FORCE_HEAD_MAJOR + // compiler-workaround gate (ROCm 7.1 + gfx12). + // Keep head-major enabled for all unaffected kernels. +#if defined(__gfx12__) && (HIP_VERSION_MAJOR == 7) && (HIP_VERSION_MINOR == 1) + constexpr bool kSkipHeadMajor = kIsGroupMode && kHasMask && !kHasDropout && + (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) && + kPadHeadDimQ && kPadHeadDimV && + (FmhaPipeline::kN1 == 256) && + std::is_same_v && + std::is_same_v && + std::is_same_v; +#else + constexpr bool kSkipHeadMajor = false; +#endif + if constexpr(!kSkipHeadMajor) + { + // bhsd should satisfy stride_q == hdim_q and nhead_stride_q > hdim_q + // The extra nhead_stride_q guard prevents bshd false-positive when nhead == 1 + const bool is_bhsd_layout = + (kargs.stride_q == kargs.hdim_q) && (kargs.nhead_stride_q > kargs.hdim_q); + if(is_bhsd_layout) + { + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + const index_t num_tile_total = has_padded_seqlen_k ? gridDim.z : gridDim.y; + const index_t num_head = gridDim.x; + const index_t blocks_per_batch = num_head * num_tile_total; + const index_t linear_id = + blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z); + + const index_t i_batch = linear_id / blocks_per_batch; + const index_t rem0 = linear_id - i_batch * blocks_per_batch; + const index_t i_nhead = rem0 / num_tile_total; + const index_t i_block = rem0 - i_nhead * num_tile_total; + + index_t i_tile_m = i_block / num_tile_n1; + index_t i_tile_n = i_block - i_tile_m * num_tile_n1; + + if constexpr(kHasMask) + { + const index_t num_tile_m = num_tile_total / num_tile_n1; + i_tile_m = num_tile_m - 1 - i_tile_m; + } + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } +#endif + if(has_padded_seqlen_k) { // const index_t num_tile_m0 = seqlen_q / kM0; @@ -1271,7 +1356,8 @@ struct FmhaFwdKernel if constexpr(kHasMask) { // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + return ck_tile::make_tuple( + static_cast(gridDim.z) - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); } else { @@ -1299,7 +1385,8 @@ struct FmhaFwdKernel if constexpr(kHasMask) { // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + return ck_tile::make_tuple( + static_cast(gridDim.y) - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); } else { @@ -1677,9 +1764,12 @@ struct FmhaFwdKernel auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { if constexpr(kHasDropout) { + const auto num_head_q_total = + (kargs.num_head_q_total > 0 ? kargs.num_head_q_total : kargs.num_head_q); + const auto i_head_global = kargs.head_start + i_nhead_; return BlockDropout{i_batch_, - i_nhead_, - kargs.num_head_q, + i_head_global, + num_head_q_total, kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val : *kargs.drop_seed.ptr, kargs.is_drop_seed_offset_from_host