Merge branch 'develop' into users/yiding12/fmha-bwd-workspace

This commit is contained in:
Yi DING
2026-04-27 15:07:41 +08:00
committed by GitHub
50 changed files with 5216 additions and 1120 deletions

View File

@@ -22,8 +22,16 @@ from codegen.cpp_symbol_map import (
QSCALE_CHECK_MAP,
QSCALE_MAP,
)
from codegen.arch import ArchTrait
from codegen.utils import update_file
# Architecture trait for kernels requiring global_load_lds (CDNA3+).
# Only used for GLOBAL_LOAD_LDS variants; all other kernels are arch-agnostic.
# The preprocessor_check string replaces {F_arch_check} in the generated
# kernel's "#if !defined(__HIP_DEVICE_COMPILE__) || (...)" guard when the
# kernel is built with F_use_global_load; other kernels get "true" instead.
# NOTE(review): __gfx94__ is presumably a project-defined umbrella macro rather
# than a compiler builtin; confirm it is defined before the guard is evaluated.
CDNA3_PLUS_ARCH = ArchTrait(
"cdna3_plus",
preprocessor_check="defined(__gfx94__) || defined(__gfx950__)",
)
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
@@ -34,6 +42,10 @@ DTYPE_BITS = {
"bf8": 8,
}
# Element size in bytes per dtype, used by the auto-generated dispatcher to
# decide kv_load_mode per-arm (total KV cache bytes vs INT32_MAX).
# Derived from DTYPE_BITS by floor division and emitted into the generated API
# as the per-dtype kElementBytes constant.
# NOTE(review): floor division maps any entry narrower than 8 bits to 0 bytes;
# confirm DTYPE_BITS holds no sub-byte dtypes, otherwise the computed KV pool
# size for them would always be 0.
DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()}
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
@@ -47,6 +59,10 @@ KV_LOOKUP_TABLE_ENUM_MAP = {
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
}
# C++ enumerator spelling for each KV-cache load mode, keyed by the boolean
# use_global_load flag (False -> BUFFER_LOAD, True -> GLOBAL_LOAD_LDS).
# Looked up when filling {F_kv_load_mode} in the kernel and API templates.
KV_LOAD_MODE_ENUM_MAP = {
False: "ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD",
True: "ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS",
}
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
@@ -61,6 +77,8 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
"""
FMHA_FWD_KERNEL_BODY = """
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
@@ -87,7 +105,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
{F_sink},
{F_page_size},
{F_kv_memory_layout},
{F_kv_lookup_table}>;
{F_kv_lookup_table},
{F_kv_load_mode}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
@@ -125,7 +144,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
#include <iostream>
@@ -140,10 +159,13 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
"""
FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
FMHA_FWD_API = """
#include <cstdint>
#include <cstdio>
namespace {{
@@ -194,6 +216,7 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a,
"""
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
constexpr int kElementBytes = {F_element_bytes};
{F_hdim_case}
}}
"""
@@ -203,8 +226,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (fmha_batch_prefill_select_kv_load_mode(a.page_block_size, {F_bn0}, a.num_total_pages, a.batch_stride_k, kElementBytes) == {F_kv_load_mode})) {{
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@@ -253,12 +276,14 @@ class FmhaFwdApiTrait:
kv_memory_layout: str
kv_lookup_table: str
page_size: int = 1 # page block size
use_global_load: bool = False # use global_load_lds_* for >2GB KV cache
@property
def name(self) -> str:
# Builds a hyphen-separated identifier string from this trait's fields and
# appends "-gload" when use_global_load is set, "-bload" otherwise, so the
# two load-mode variants of an otherwise identical trait get distinct names.
# NOTE(review): self.bn0 is interpolated twice (2nd and 4th slots). The 4th
# slot sits where the templates in this file pass F_bn1, so it looks like it
# was meant to be bn1; confirm before relying on this name being unique.
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
+ ("-gload" if self.use_global_load else "-bload")
)
@property
@@ -481,6 +506,7 @@ class FmhaFwdApiPool:
],
F_page_size=trait.page_size,
F_sink=BOOL_MAP[trait.sink],
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[trait.use_global_load],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
@@ -488,7 +514,10 @@ class FmhaFwdApiPool:
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
F_if=if_i,
F_dtype=dtype,
F_element_bytes=DTYPE_BYTES[dtype],
F_hdim_case=per_hdim_case,
)
if not per_dtypes:
# per_dtypes is an empty string: add some ignores to suppress warnings in the API
@@ -539,6 +568,7 @@ class FmhaFwdKernel:
F_pipeline: FmhaFwdPipeline
mask_impl: str
F_page_size: int = 1 # page block size
F_use_global_load: bool = False # use global_load_lds_* for >2GB KV cache
@property
def template(self) -> str:
@@ -588,6 +618,10 @@ class FmhaFwdKernel:
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
F_page_size=self.F_page_size,
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[self.F_use_global_load],
F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check
if self.F_use_global_load
else "true",
)
@property
@@ -595,6 +629,7 @@ class FmhaFwdKernel:
# TODO: we don't encode idx here
return (
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
+ ("gload_" if self.F_use_global_load else "bload_")
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
@@ -632,6 +667,7 @@ class FmhaFwdKernel:
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
page_size=self.F_page_size,
use_global_load=self.F_use_global_load,
)
@@ -714,8 +750,11 @@ class CustomFactory(KernelComponentFactory):
def get_fwd_blobs(
kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
targets: Optional[List[str]] = None
kernel_filter: Optional[str],
receipt,
optdim_list,
mask_impl,
targets: Optional[List[str]] = None,
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
# (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
@@ -837,6 +876,25 @@ def get_fwd_blobs(
api_pool.register_traits(k.api_trait())
gen.append(k)
# For page_size < kN0 (tile.F_bn0), also generate a GLOBAL_LOAD_LDS
# variant for >2GB KV cache support. The default (BUFFER_LOAD) uses SRD
# buffer_load (fast, <2GB). GLOBAL_LOAD_LDS uses global_load_lds_*
# (slower, handles >2GB).
if page_size < tile.F_bn0:
k_global_load = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
F_page_size=page_size,
F_use_global_load=True,
)
api_pool.register_traits(k_global_load.api_trait())
gen.append(k_global_load)
return (api_pool, gen)
@@ -856,7 +914,9 @@ def write_blobs(
optdim_list,
mask_impl,
) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
api_pool, kernels = get_fwd_blobs(
kernel_filter, receipt, optdim_list, mask_impl, targets
)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
@@ -871,7 +931,9 @@ def list_blobs(
mask_impl,
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
_, kernels = get_fwd_blobs(
kernel_filter, receipt, optdim_list, mask_impl, targets
)
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")

View File

@@ -673,6 +673,33 @@ struct fmha_batch_prefill_args
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};
// Picks the KV-cache load mode that a batch-prefill dispatch arm must match.
//
// The result is GLOBAL_LOAD_LDS only when BOTH of the following hold:
//   1. page_block_size < kN0, i.e. a page is smaller than one K/V tile, so a
//      per-page SRD is impossible; and
//   2. the KV pool size in bytes
//      (num_total_pages * batch_stride_k * element_bytes) is greater than
//      INT32_MAX, so SRD's 32-bit byte offset cannot address it.
// In every other case the result is BUFFER_LOAD, the faster SRD-based path.
//
// All arguments are plain integers on purpose: the helper needs no template
// parameter, and each codegen-emitted dispatcher arm can call it with that
// arm's compile-time kN0 / element_bytes written in as constants.
inline ck_tile::BlockAttentionKVCacheLoadModeEnum
fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size,
                                       ck_tile::index_t kN0,
                                       ck_tile::index_t num_total_pages,
                                       ck_tile::index_t batch_stride_k,
                                       ck_tile::index_t element_bytes)
{
    using wide_index_t = ck_tile::long_index_t;
    using load_mode_t  = ck_tile::BlockAttentionKVCacheLoadModeEnum;

    // Pages that cover at least one full tile never take the fallback path.
    if(!(page_block_size < kN0))
    {
        return load_mode_t::BUFFER_LOAD;
    }

    // Widen each factor individually before multiplying, so no intermediate
    // product is ever formed in 32-bit arithmetic, whatever the operand order.
    wide_index_t pool_bytes = static_cast<wide_index_t>(num_total_pages);
    pool_bytes *= static_cast<wide_index_t>(batch_stride_k);
    pool_bytes *= static_cast<wide_index_t>(element_bytes);

    if(pool_bytes > INT32_MAX)
    {
        return load_mode_t::GLOBAL_LOAD_LDS;
    }
    return load_mode_t::BUFFER_LOAD;
}
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
{
@@ -1457,7 +1484,9 @@ template <ck_tile::index_t HDim_,
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
ck_tile::BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
DataType_,
kIsGroupMode_,
@@ -1486,6 +1515,7 @@ struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
static constexpr auto kKVLoadMode = kKVLoadMode_;
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
};