mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Merge branch 'develop' into users/yiding12/fmha-bwd-workspace
This commit is contained in:
@@ -108,28 +108,35 @@ bool run_grouped_conv_fwd(bool do_verification,
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<AccDataType> c_host(out_g_n_k_wos_desc);
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>();
|
||||
PassThrough>();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(in,
|
||||
wei,
|
||||
out_host,
|
||||
c_host,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
out_host.ForEach([&](auto&, auto idx)
|
||||
{
|
||||
out_element_op(out_host(idx), c_host(idx));
|
||||
});
|
||||
|
||||
out_device_buf.FromDevice(out_device.mData.data());
|
||||
|
||||
pass &=
|
||||
|
||||
@@ -22,8 +22,16 @@ from codegen.cpp_symbol_map import (
|
||||
QSCALE_CHECK_MAP,
|
||||
QSCALE_MAP,
|
||||
)
|
||||
from codegen.arch import ArchTrait
|
||||
from codegen.utils import update_file
|
||||
|
||||
# Arch trait gating kernels that rely on global_load_lds (CDNA3 and newer).
# It is applied only to the GLOBAL_LOAD_LDS kernel variants; every other
# kernel is emitted arch-agnostic.
CDNA3_PLUS_ARCH = ArchTrait(
    "cdna3_plus", preprocessor_check="defined(__gfx94__) || defined(__gfx950__)"
)
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
@@ -34,6 +42,10 @@ DTYPE_BITS = {
|
||||
"bf8": 8,
|
||||
}
|
||||
|
||||
# Per-dtype element size in bytes, derived from DTYPE_BITS by integer division.
# The auto-generated dispatcher uses it to choose kv_load_mode for each arm
# (total KV cache bytes compared against INT32_MAX).
DTYPE_BYTES = dict((dtype, bits // 8) for dtype, bits in DTYPE_BITS.items())
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
|
||||
@@ -47,6 +59,10 @@ KV_LOOKUP_TABLE_ENUM_MAP = {
|
||||
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
|
||||
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
|
||||
}
|
||||
# Maps the boolean "use global load" flag of a kernel/trait to the C++
# enumerator spelled into the generated source.
KV_LOAD_MODE_ENUM_MAP = dict(
    (
        (False, "ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD"),
        (True, "ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS"),
    )
)
|
||||
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
@@ -61,6 +77,8 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
|
||||
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
@@ -87,7 +105,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
|
||||
{F_sink},
|
||||
{F_page_size},
|
||||
{F_kv_memory_layout},
|
||||
{F_kv_lookup_table}>;
|
||||
{F_kv_lookup_table},
|
||||
{F_kv_load_mode}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
@@ -125,7 +144,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -140,10 +159,13 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
namespace {{
|
||||
@@ -194,6 +216,7 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a,
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
constexpr int kElementBytes = {F_element_bytes};
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
@@ -203,8 +226,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (fmha_batch_prefill_select_kv_load_mode(a.page_block_size, {F_bn0}, a.num_total_pages, a.batch_stride_k, kElementBytes) == {F_kv_load_mode})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -253,12 +276,14 @@ class FmhaFwdApiTrait:
|
||||
kv_memory_layout: str
|
||||
kv_lookup_table: str
|
||||
page_size: int = 1 # page block size
|
||||
use_global_load: bool = False # use global_load_lds_* for >2GB KV cache
|
||||
|
||||
@property
def name(self) -> str:
    """Return the dash-separated identifier string built from this trait's fields.

    The last two components are ``ps<page_size>`` and a load-mode tag:
    ``gload`` when ``use_global_load`` is set, ``bload`` otherwise.
    """
    # NOTE(review): bn0 is emitted twice (4th and 6th tile slots); the second
    # occurrence may have been meant to be bn1 -- confirm before changing,
    # since the resulting string would change.
    components = (
        self.hdim,
        self.dtype,
        self.mode,
        self.bm0,
        self.bn0,
        self.bk0,
        self.bn0,
        self.bk1,
        self.bk0max,
        self.vlayout,
        self.logits,
        self.mask,
        self.bias,
        self.lse,
        self.dropout,
        self.qscale,
        self.spad,
        self.skpad,
        self.dpad,
        self.dvpad,
        self.kv_memory_layout,
        self.kv_lookup_table,
        f"ps{self.page_size}",
        "gload" if self.use_global_load else "bload",
    )
    # f"{...}" keeps the same formatting path the original f-strings used.
    return "-".join(f"{component}" for component in components)
|
||||
|
||||
@property
|
||||
@@ -481,6 +506,7 @@ class FmhaFwdApiPool:
|
||||
],
|
||||
F_page_size=trait.page_size,
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[trait.use_global_load],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
@@ -488,7 +514,10 @@ class FmhaFwdApiPool:
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
F_if=if_i,
|
||||
F_dtype=dtype,
|
||||
F_element_bytes=DTYPE_BYTES[dtype],
|
||||
F_hdim_case=per_hdim_case,
|
||||
)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
@@ -539,6 +568,7 @@ class FmhaFwdKernel:
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
F_page_size: int = 1 # page block size
|
||||
F_use_global_load: bool = False # use global_load_lds_* for >2GB KV cache
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
@@ -588,6 +618,10 @@ class FmhaFwdKernel:
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_page_size=self.F_page_size,
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[self.F_use_global_load],
|
||||
F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check
|
||||
if self.F_use_global_load
|
||||
else "true",
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -595,6 +629,7 @@ class FmhaFwdKernel:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
|
||||
+ ("gload_" if self.F_use_global_load else "bload_")
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
@@ -632,6 +667,7 @@ class FmhaFwdKernel:
|
||||
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
|
||||
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
|
||||
page_size=self.F_page_size,
|
||||
use_global_load=self.F_use_global_load,
|
||||
)
|
||||
|
||||
|
||||
@@ -714,8 +750,11 @@ class CustomFactory(KernelComponentFactory):
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
|
||||
targets: Optional[List[str]] = None
|
||||
kernel_filter: Optional[str],
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
targets: Optional[List[str]] = None,
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
|
||||
# (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
|
||||
@@ -837,6 +876,25 @@ def get_fwd_blobs(
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
# For page_size < kN0 (tile.F_bn0), also generate a GLOBAL_LOAD_LDS
|
||||
# variant for >2GB KV cache support. The default (BUFFER_LOAD) uses SRD
|
||||
# buffer_load (fast, <2GB). GLOBAL_LOAD_LDS uses global_load_lds_*
|
||||
# (slower, handles >2GB).
|
||||
if page_size < tile.F_bn0:
|
||||
k_global_load = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
F_page_size=page_size,
|
||||
F_use_global_load=True,
|
||||
)
|
||||
api_pool.register_traits(k_global_load.api_trait())
|
||||
gen.append(k_global_load)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
@@ -856,7 +914,9 @@ def write_blobs(
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
|
||||
api_pool, kernels = get_fwd_blobs(
|
||||
kernel_filter, receipt, optdim_list, mask_impl, targets
|
||||
)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
@@ -871,7 +931,9 @@ def list_blobs(
|
||||
mask_impl,
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
|
||||
_, kernels = get_fwd_blobs(
|
||||
kernel_filter, receipt, optdim_list, mask_impl, targets
|
||||
)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
|
||||
@@ -673,6 +673,33 @@ struct fmha_batch_prefill_args
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
|
||||
};
|
||||
|
||||
// Selects the KV-cache load mode for a batch-prefill dispatch arm.
|
||||
// GLOBAL_LOAD_LDS: required when (a) the page is smaller than one K/V tile
|
||||
// so per-page SRD is impossible, AND (b) the total KV-pool byte size
|
||||
// exceeds INT32_MAX so SRD's 32-bit byte offset cannot address it.
|
||||
// BUFFER_LOAD: every other case — the SGPR-resident SRD path is fastest.
|
||||
// Inputs are taken as plain integers so the helper has no template parameter
|
||||
// and can be called from each codegen-emitted dispatcher arm with the arm's
|
||||
// compile-time kN0 / element_bytes substituted as constants.
|
||||
inline ck_tile::BlockAttentionKVCacheLoadModeEnum
|
||||
fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size,
|
||||
ck_tile::index_t kN0,
|
||||
ck_tile::index_t num_total_pages,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t element_bytes)
|
||||
{
|
||||
// Promote every operand to long_index_t so overflow is impossible regardless
|
||||
// of multiplication order. A bare `static_cast<long_index_t>(num_total_pages)
|
||||
// * batch_stride_k * element_bytes` only works because of left-to-right
|
||||
// associativity — a future reorder of the operands would silently truncate.
|
||||
const auto kv_pool_bytes = static_cast<ck_tile::long_index_t>(num_total_pages) *
|
||||
static_cast<ck_tile::long_index_t>(batch_stride_k) *
|
||||
static_cast<ck_tile::long_index_t>(element_bytes);
|
||||
return (page_block_size < kN0 && kv_pool_bytes > INT32_MAX)
|
||||
? ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS
|
||||
: ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD;
|
||||
}
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
@@ -1457,7 +1484,9 @@ template <ck_tile::index_t HDim_,
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
|
||||
ck_tile::BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
|
||||
ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
|
||||
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
DataType_,
|
||||
kIsGroupMode_,
|
||||
@@ -1486,6 +1515,7 @@ struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
|
||||
static constexpr auto kKVLoadMode = kKVLoadMode_;
|
||||
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user