diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
index 7c3efb9c18..8c006c09db 100644
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
@@ -22,8 +22,16 @@ from codegen.cpp_symbol_map import (
     QSCALE_CHECK_MAP,
     QSCALE_MAP,
 )
+from codegen.arch import ArchTrait
 from codegen.utils import update_file
 
+# Architecture trait for kernels requiring global_load_lds (CDNA3+).
+# Only used for GLOBAL_LOAD_LDS variants; all other kernels are arch-agnostic.
+CDNA3_PLUS_ARCH = ArchTrait(
+    "cdna3_plus",
+    preprocessor_check="defined(__gfx94__) || defined(__gfx950__)",
+)
+
 DTYPE_BITS = {
     "fp32": 32,
     "fp16": 16,
@@ -34,6 +42,10 @@ DTYPE_BITS = {
     "bf8": 8,
 }
 
+# Element size in bytes per dtype, used by the auto-generated dispatcher to
+# decide kv_load_mode per arm (total KV cache bytes vs INT32_MAX).
+DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()}
+
 K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
 
 SUPPORTED_PAGE_SIZE = [1, 16, 1024]
@@ -47,6 +59,10 @@ KV_LOOKUP_TABLE_ENUM_MAP = {
     "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
     "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
 }
+KV_LOAD_MODE_ENUM_MAP = {
+    False: "ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD",
+    True: "ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS",
+}
 
 FMHA_BATCH_PREFILL_PIPELINE_MAP = {
@@ -61,6 +77,8 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
 """
 
 FMHA_FWD_KERNEL_BODY = """
+#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
+
 using fmha_dtype_{F_idx} = {F_dtype};
 
 using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
@@ -87,7 +105,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
     {F_sink},
     {F_page_size},
     {F_kv_memory_layout},
-    {F_kv_lookup_table}>;
+    {F_kv_lookup_table},
+    {F_kv_load_mode}>;
 
 using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
@@ -125,7 +144,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel;
 
 using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
-    {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
+    {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
 
 #include
@@ -140,10 +159,13 @@ float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_b
     constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
     return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs));
 }}
+
+#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
 """
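For orientation, here is a hand-expanded sketch of what the guarded kernel body above generates for one GLOBAL_LOAD_LDS variant. The tile sizes and dtype are hypothetical placeholders, not actual codegen output:

```cpp
// Hypothetical expansion of FMHA_FWD_KERNEL_BODY with
// F_arch_check = CDNA3_PLUS_ARCH.preprocessor_check.
// The host pass (no __HIP_DEVICE_COMPILE__) always sees the symbols, so the
// dispatcher translation unit links on every target; the device pass only
// compiles the body on gfx94x/gfx950, where global_load_lds exists.
#if !defined(__HIP_DEVICE_COMPILE__) || (defined(__gfx94__) || defined(__gfx950__))

using fmha_dtype_0 = ck_tile::fp16_t;                                   // {F_dtype}
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 32>; // tile shape
// ... trait / variant / kernel aliases and the launch wrapper elided ...

#endif // !defined(__HIP_DEVICE_COMPILE__) || (defined(__gfx94__) || defined(__gfx950__))
```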
 
 FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
 
 FMHA_FWD_API = """
+#include
 #include
 
 namespace {{
@@ -194,6 +216,7 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a,
 """
 
 FMHA_FWD_API_PER_DTYPE = """
     {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
+        constexpr int kElementBytes = {F_element_bytes};
{F_hdim_case}
     }}
"""
@@ -203,8 +226,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """
     {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
"""
 
 FMHA_FWD_API_INNER_DISPATCH = """
     {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
-        ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
-        using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
+        ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (fmha_batch_prefill_select_kv_load_mode(a.page_block_size, {F_bn0}, a.num_total_pages, a.batch_stride_k, kElementBytes) == {F_kv_load_mode})) {{
+        using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
         return fmha_batch_prefill_<trait_>(s, a);
     }}
"""
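Concretely, one expanded arm of the generated dispatcher would look roughly like the following. This is a hand-substituted sketch with hypothetical values; only the checks relevant to the new load-mode selection are spelled out:

```cpp
// Sketch of FMHA_FWD_API_PER_DTYPE + FMHA_FWD_API_INNER_DISPATCH after
// substitution for fp16, hdim 128, kN0 = 128, page_size = 16 (all made up).
else if(t.data_type.compare("fp16") == 0) {
    constexpr int kElementBytes = 2; // DTYPE_BYTES["fp16"]
    if(t.hdim_q <= 128 && t.hdim_v <= 128) {
        if((t.is_group_mode == false) && /* ... remaining trait checks ... */
           (t.page_size == 16) &&
           (fmha_batch_prefill_select_kv_load_mode(a.page_block_size,
                                                   /*kN0=*/128,
                                                   a.num_total_pages,
                                                   a.batch_stride_k,
                                                   kElementBytes) ==
            ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD))
        {
            using trait_ = fmha_fwd_batch_prefill_traits_</* ...tile/flag list... */>;
            return fmha_batch_prefill_<trait_>(s, a);
        }
        // ... a sibling arm generated with {F_kv_load_mode} = GLOBAL_LOAD_LDS
        // follows for the same tile shape ...
    }
}
```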
 
@@ -253,12 +276,14 @@ class FmhaFwdApiTrait:
     kv_memory_layout: str
     kv_lookup_table: str
     page_size: int = 1  # page block size
+    use_global_load: bool = False  # use global_load_lds_* for >2GB KV cache
 
     @property
     def name(self) -> str:
         return (
             f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
             + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
+            + ("-gload" if self.use_global_load else "-bload")
         )
 
     @property
@@ -481,6 +506,7 @@ class FmhaFwdApiPool:
                     ],
                     F_page_size=trait.page_size,
                     F_sink=BOOL_MAP[trait.sink],
+                    F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[trait.use_global_load],
                 )
             if_j = "if" if j == 0 else "else if"
             per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
             )
             if_i = "if" if i == 0 else "else if"
             per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
-                F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
+                F_if=if_i,
+                F_dtype=dtype,
+                F_element_bytes=DTYPE_BYTES[dtype],
+                F_hdim_case=per_hdim_case,
             )
         if not per_dtypes:
             # empty string we add some ignore to suppress warning in api
@@ -539,6 +568,7 @@ class FmhaFwdKernel:
     F_pipeline: FmhaFwdPipeline
     mask_impl: str
     F_page_size: int = 1  # page block size
+    F_use_global_load: bool = False  # use global_load_lds_* for >2GB KV cache
 
     @property
     def template(self) -> str:
@@ -588,6 +618,10 @@ class FmhaFwdKernel:
             F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
             F_page_size=self.F_page_size,
             F_sink=BOOL_MAP[self.F_pipeline.F_sink],
+            F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[self.F_use_global_load],
+            F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check
+            if self.F_use_global_load
+            else "true",
         )
 
     @property
@@ -595,6 +629,7 @@ class FmhaFwdKernel:
         # TODO: we don't encode idx here
         return (
             f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
+            + ("gload_" if self.F_use_global_load else "bload_")
             + self.F_tile.name
             + "_"
             + self.F_pipeline.name
@@ -632,6 +667,7 @@ class FmhaFwdKernel:
             kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
             kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
             page_size=self.F_page_size,
+            use_global_load=self.F_use_global_load,
         )
 
 
@@ -714,8 +750,11 @@ class CustomFactory(KernelComponentFactory):
 
 
 def get_fwd_blobs(
-    kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
-    targets: Optional[List[str]] = None
+    kernel_filter: Optional[str],
+    receipt,
+    optdim_list,
+    mask_impl,
+    targets: Optional[List[str]] = None,
 ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
     # batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
     # (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
@@ -837,6 +876,25 @@ def get_fwd_blobs(
                 api_pool.register_traits(k.api_trait())
                 gen.append(k)
 
+                # For page_size < kN0 (tile.F_bn0), also generate a GLOBAL_LOAD_LDS
+                # variant for >2GB KV cache support. The default (BUFFER_LOAD) uses SRD
+                # buffer_load (fast, <2GB). GLOBAL_LOAD_LDS uses global_load_lds_*
+                # (slower, handles >2GB).
+                if page_size < tile.F_bn0:
+                    k_global_load = FmhaFwdKernel(
+                        F_idx=0,
+                        F_hdim=hdim,
+                        F_dtype=dtype,
+                        F_mode=mode,
+                        F_tile=tile,
+                        F_pipeline=pipeline,
+                        mask_impl=mask_impl,
+                        F_page_size=page_size,
+                        F_use_global_load=True,
+                    )
+                    api_pool.register_traits(k_global_load.api_trait())
+                    gen.append(k_global_load)
+
     return (api_pool, gen)
 
 
@@ -856,7 +914,9 @@ def write_blobs(
     optdim_list,
     mask_impl,
 ) -> None:
-    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
+    api_pool, kernels = get_fwd_blobs(
+        kernel_filter, receipt, optdim_list, mask_impl, targets
+    )
     for kernel in kernels:
         write_single_fwd_kernel(kernel, output_dir)
     write_fwd_api(api_pool, output_dir)
@@ -871,7 +931,9 @@ def list_blobs(
     mask_impl,
 ) -> None:
     with file_path.open("a") as f:
-        _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
+        _, kernels = get_fwd_blobs(
+            kernel_filter, receipt, optdim_list, mask_impl, targets
+        )
         for kernel in kernels:
             f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
         f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp
index 6c842def58..98e2df2e1e 100644
--- a/example/ck_tile/01_fmha/fmha_fwd.hpp
+++ b/example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -673,6 +673,33 @@ struct fmha_batch_prefill_args
     ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
 };
 
+// Selects the KV-cache load mode for a batch-prefill dispatch arm.
+// GLOBAL_LOAD_LDS: required when (a) the page is smaller than one K/V tile,
+//   so per-page SRD rebasing is impossible, AND (b) the total KV-pool byte
+//   size exceeds INT32_MAX, so the SRD's 32-bit byte offset cannot address it.
+// BUFFER_LOAD: every other case — the SGPR-resident SRD path is fastest.
+// Inputs are taken as plain integers so the helper has no template parameter
+// and can be called from each codegen-emitted dispatcher arm with the arm's
+// compile-time kN0 / element_bytes substituted as constants.
+inline ck_tile::BlockAttentionKVCacheLoadModeEnum
+fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size,
+                                       ck_tile::index_t kN0,
+                                       ck_tile::index_t num_total_pages,
+                                       ck_tile::index_t batch_stride_k,
+                                       ck_tile::index_t element_bytes)
+{
+    // Promote every operand to long_index_t so overflow is impossible regardless
+    // of multiplication order. A bare `static_cast<ck_tile::long_index_t>(num_total_pages)
+    // * batch_stride_k * element_bytes` only works because of left-to-right
+    // associativity — a future reorder of the operands would silently truncate.
+    const auto kv_pool_bytes = static_cast<ck_tile::long_index_t>(num_total_pages) *
+                               static_cast<ck_tile::long_index_t>(batch_stride_k) *
+                               static_cast<ck_tile::long_index_t>(element_bytes);
+    return (page_block_size < kN0 && kv_pool_bytes > INT32_MAX)
+               ? ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS
+               : ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD;
+}
+
 template <typename FmhaKernel>
 auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
 {
@@ -1457,7 +1484,9 @@ template
-              ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
+              ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
+          ck_tile::BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
+              ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
 struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_
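To make the INT32_MAX boundary tangible, here is a self-contained host-side check with hypothetical cache dimensions; none of these numbers come from the patch:

```cpp
// Worked example: fp16 KV cache, 16-token pages, 8 kv-heads, hdim 128
// (all illustrative). batch_stride_k = 16 * 8 * 128 = 16384 elements/page.
#include <cassert>
#include <cstdint>

int main()
{
    const int64_t num_total_pages = 70000;
    const int64_t batch_stride_k  = 16384; // elements per page
    const int64_t element_bytes   = 2;     // fp16

    const int64_t kv_pool_bytes = num_total_pages * batch_stride_k * element_bytes;
    // 70000 * 16384 * 2 = 2'293'760'000 > INT32_MAX (2'147'483'647): a 32-bit
    // SRD byte offset cannot reach the top of this pool, so when the page is
    // also smaller than one K/V tile the dispatcher must pick GLOBAL_LOAD_LDS.
    assert(kv_pool_bytes > INT32_MAX);
    return 0;
}
```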
 
+template <index_t num_dwords, bool pre_nop = false>
+CK_TILE_DEVICE void
+async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant<pre_nop> = {})
+{
+#if !defined(__gfx94__) && !defined(__gfx950__)
+    static_assert(always_false_v<bool_constant<pre_nop>>,
+                  "global_load_lds requires CDNA3+ (gfx940/gfx950). "
+                  "Ensure kKVLoadMode is BUFFER_LOAD on this architecture.");
+#endif
+
+    static_assert(num_dwords == 1 || num_dwords == 4,
+                  "global_load_lds supports num_dwords == 1 or 4 only "
+                  "(2 dwords does not exist on any supported arch; "
+                  "3 dwords only on CDNA4 and unused in FMHA pipeline)");
+
+// Inline asm: only the global address is an explicit operand. The LDS
+// destination is implicit via M0 (see contract above). `"=r"(smem)` is an
+// SSA scheduling anchor only — `smem` is NOT written by this asm; the
+// load goes to LDS at `M0[17:2]*4 + offset:0 + ThreadID*4`.
+#define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr)                                 \
+    if constexpr(pre_nop)                                                    \
+        asm volatile("s_nop 4\n" instr " %1, off offset:0"                   \
+                     : "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
+                     : "v"(global_addr)                                      \
+                     : "memory" /*prevents reorder across m0_{set,inc}*/);   \
+    else                                                                     \
+        asm volatile(instr " %1, off offset:0"                               \
+                     : "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
+                     : "v"(global_addr)                                      \
+                     : "memory" /*prevents reorder across m0_{set,inc}*/);
+
+    if constexpr(num_dwords == 1)
+    {
+        CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dword");
+    }
+    else if constexpr(num_dwords == 4)
+    {
+        CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dwordx4");
+    }
+#undef CK_TILE_GLOBAL_LOAD_LDS_INSTR
+}
+
 template <typename T, index_t N>
 CK_TILE_DEVICE thread_buffer<T, N>
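Call sites derive the dword-count template argument from the vector type, as the pipeline code further down does with `sizeof(vector_t) / sizeof(uint32_t)`. A minimal host-side sketch of that mapping, with invented vector type aliases:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for the pipeline's vector_t instantiations.
using dword_vec_t   = uint32_t;    //  4 bytes -> global_load_lds_dword
using dwordx4_vec_t = uint32_t[4]; // 16 bytes -> global_load_lds_dwordx4

template <typename VecT>
constexpr int dwords_per_vector()
{
    static_assert(sizeof(VecT) % sizeof(uint32_t) == 0, "vector must be dword-sized");
    constexpr int n = sizeof(VecT) / sizeof(uint32_t);
    static_assert(n == 1 || n == 4, "only dword / dwordx4 forms are emitted");
    return n;
}

int main()
{
    std::printf("%d %d\n", dwords_per_vector<dword_vec_t>(), dwords_per_vector<dwordx4_vec_t>());
    return 0;
}
```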
diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp
index aa29345892..45131abb97 100644
--- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp
+++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp
@@ -45,9 +45,29 @@ template
-          typename YsGatherDims = sequence<0>>
+          typename YsGatherDims = sequence<0>,
+          bool kUseGlobalLoad_ = false>
 struct tile_scatter_gather
 {
+    static constexpr bool kUseGlobalLoad = kUseGlobalLoad_;
+
+#if !defined(__gfx94__) && !defined(__gfx950__)
+    // global_load_lds instruction is only available on CDNA3+ (gfx940/gfx950).
+    // On other architectures, kUseGlobalLoad must be false.
+    static_assert(!kUseGlobalLoad_,
+                  "kUseGlobalLoad requires global_load_lds (CDNA3+: gfx940/gfx950). "
+                  "This kernel should not be instantiated on this architecture.");
+#endif
+
+    // Empty placeholder used by the SRD instantiation so physical_pages_ and
+    // page_stride_elements_ occupy zero bytes there (combined with
+    // [[no_unique_address]] on the member declarations). Access sites are all
+    // inside `if constexpr(kUseGlobalLoad_)` arms, which compile out in SRD
+    // mode, so no caller needs to change.
+    struct gl_field_empty_t
+    {
+    };
+
     using BottomTensorView = remove_reference_t<BottomTensorView_>;
     using WindowLengths    = remove_cvref_t<WindowLengths_>;
     using TileDstr         = remove_cvref_t<TileDstr_>;
@@ -233,15 +253,22 @@ struct tile_scatter_gather
                         const BottomTensorIndex& window_origin,
                         const TileDstr& tile_distribution,
                         const PageIdxArray& page_idx,
-                        const ValidArray& valids)
+                        const ValidArray& valids,
+                        index_t page_stride_elements = 0)
         : bottom_tensor_view_{bottom_tensor_view},
           window_lengths_{window_lengths},
           window_origin_{window_origin},
           tile_dstr_{tile_distribution},
           page_idx_{page_idx},
+          physical_pages_{},
+          page_stride_elements_{},
           valids_{valids},
           pre_computed_coords_{}
     {
+        if constexpr(kUseGlobalLoad_)
+        {
+            page_stride_elements_ = page_stride_elements;
+        }
 #if 0 // debug
       // TODO: this use more register for FA, but less register for GEMM
       // need investigation
@@ -357,6 +384,34 @@ struct tile_scatter_gather
         bottom_tensor_view_.buf_.p_data_ = data;
     }
 
+    // Override buffer size (input in RAW elements, NOT pre-divided by PackedSize) for
+    // SRD num_records control. Use it to set the max range when the SRD is rebased
+    // per-tile (page_size >= kN0 path): each rebased SRD only needs to cover one page;
+    // without this, the SRD claims validity for memory beyond the allocated buffer,
+    // which can fault on gfx950 page-table validation.
+    //
+    // Matches the buffer_view ctor convention (buffer_view.hpp:245): the input is a
+    // raw element count and is divided by PackedSize before being stored. For
+    // PackedSize=1 (fp16/bf16/fp8) the division is a no-op; for PackedSize=2
+    // (FP4 / packed int4) skipping it would over-report num_records by 2x and
+    // silently mask OOB on SRD reads. batch_prefill currently does not exercise the
+    // packed-type path, but this setter is generic infrastructure (it lives in
+    // tile_scatter_gather.hpp), so it must honor the same invariant the ctor enforces.
+    CK_TILE_DEVICE constexpr void set_bottom_tensor_view_buffer_size(index_t size)
+    {
+        // Hint the optimizer that size is positive without inserting a runtime
+        // branch. Using assert() here corrupted gfx950 batch_prefill
+        // output: the __assert_fail handler's SGPR pressure forced the K-SRD
+        // register window to be reused as scratch and scattered the SRD writes
+        // across two conditional branches, which gfx950's packed
+        // buffer_load_dwordx4 issue window doesn't tolerate (gfx942 absorbs it
+        // via per-tile single-dword loads). __builtin_assume is hint-only —
+        // no branch, no scratch SGPRs, no codegen impact.
+        __builtin_assume(size > 0);
+        using BufType = remove_cvref_t<decltype(bottom_tensor_view_.buf_)>;
+        bottom_tensor_view_.buf_.buffer_size_ = size / BufType::PackedSize;
+    }
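A tiny self-contained illustration of the PackedSize invariant described above; the buffer type and PackedSize values are toys, not CK's buffer_view:

```cpp
#include <cassert>

// Illustrative only: a buffer that stores its size in packed records, the way
// the comment above describes buffer_view's ctor convention.
template <int PackedSize>
struct toy_buffer
{
    int buffer_size_ = 0; // stored in packed units

    void set_size_raw(int raw_elements)
    {
        buffer_size_ = raw_elements / PackedSize; // the division the setter must keep
    }
};

int main()
{
    toy_buffer<1> fp16_buf; // fp16/bf16/fp8: PackedSize == 1, division is a no-op
    fp16_buf.set_size_raw(4096);
    assert(fp16_buf.buffer_size_ == 4096);

    toy_buffer<2> fp4_buf; // packed 4-bit types: two elements per record
    fp4_buf.set_size_raw(4096);
    // Skipping the division would report 4096 records and mask OOB by 2x.
    assert(fp4_buf.buffer_size_ == 2048);
    return 0;
}
```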
 
     // move thread's window adaptor coordinate and bottom tensor coordinate
     // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
     template
@@ -458,7 +513,21 @@ struct tile_scatter_gather
                 // read from bottom tensor
                 const vector_t vec_value = [&]() {
-                    if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
+                    if constexpr(kUseGlobalLoad_)
+                    {
+                        // Global load mode: 64-bit typed pointer arithmetic
+                        const auto* base_ptr     = get_bottom_tensor_view().buf_.p_data_;
+                        const auto physical_page = physical_pages_[idx_gather];
+                        const auto coord_offset  = bottom_tensor_thread_coord.get_offset();
+                        const long_index_t total_offset =
+                            static_cast<long_index_t>(physical_page) * page_stride_elements_ +
+                            coord_offset + page_offset;
+                        const auto* addr = base_ptr + total_offset;
+                        vector_t v;
+                        __builtin_memcpy(&v, addr, sizeof(vector_t));
+                        return v;
+                    }
+                    else if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
                     {
                         return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                             bottom_tensor_thread_coord,
@@ -680,7 +749,23 @@ struct tile_scatter_gather
                 const auto page_offset = page_idx_[idx_gather];
 
                 // read from bottom tensor
-                if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
+                if constexpr(kUseGlobalLoad_)
+                {
+                    // Global load mode: global_load_lds with 64-bit address
+                    constexpr index_t vector_size =
+                        sizeof(vector_t) / sizeof(uint32_t); // dwords per vector
+                    const auto* base_ptr     = get_bottom_tensor_view().buf_.p_data_;
+                    const auto physical_page = physical_pages_[idx_gather];
+                    const auto coord_offset  = bottom_tensor_thread_coord.get_offset();
+                    const long_index_t total_offset =
+                        static_cast<long_index_t>(physical_page) * page_stride_elements_ +
+                        coord_offset + page_offset;
+                    const auto* addr = base_ptr + total_offset;
+                    // global_load_lds takes a byte address; addr (const DataType*)
+                    // converts implicitly to const void*, no explicit cast needed.
+                    async_global_load_lds_dwordxn<vector_size>(smem, addr, pre_nop_);
+                }
+                else if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
                 {
                     get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
                         smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
@@ -1046,6 +1131,13 @@ struct tile_scatter_gather
 
     CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
 
+    CK_TILE_DEVICE void update_physical_pages(const PageIdxArray& pages)
+    {
+        static_assert(kUseGlobalLoad_,
+                      "global-load mode only; physical_pages_ is unused in SRD mode.");
+        physical_pages_ = pages;
+    }
+
     CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
     {
         if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
@@ -1139,7 +1231,29 @@ struct tile_scatter_gather
     // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
     TileDstr tile_dstr_;
 
+    // Scatter/gather offsets for each element, set by update_page_idx().
+    // SRD mode (kUseGlobalLoad=false): buffer_load(SRD, page_idx_[i] + coord).
+    //   page_idx_[i] = within-page offset when kPageBlockSize >= kN0 (SRD rebased to page base)
+    //   page_idx_[i] = page base + within-page offset when kPageBlockSize < kN0 (full voffset)
+    // Global load mode (kUseGlobalLoad=true): page_idx_[i] = within-page offset only.
+    //   Full address = base + physical_pages_[i] * page_stride_elements_ + page_idx_[i] + coord
     PageIdxArray page_idx_;
+
+    // Physical page indices for global load mode (kUseGlobalLoad=true only).
+    // Maps each gather element to its physical page in a paged memory pool.
+    // Updated via update_physical_pages() before each load call.
+    // SRD mode: collapsed to gl_field_empty_t so the storage disappears.
+    [[no_unique_address]] std::conditional_t<kUseGlobalLoad_, PageIdxArray, gl_field_empty_t>
+        physical_pages_;
+
+    // Page stride in elements for global load mode (kUseGlobalLoad=true only).
+    // physical_pages_[i] * page_stride_elements_ gives the page base offset in elements.
+    // Set at construction time via the make_tile_scatter_gather overload that
+    // takes bool_constant<true>; immutable thereafter.
+    // SRD mode: collapsed to gl_field_empty_t so the storage disappears.
+    [[no_unique_address]] std::conditional_t<kUseGlobalLoad_, index_t, gl_field_empty_t>
+        page_stride_elements_;
+
     ValidArray valids_;
 
     // this contains:
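The conditional-member trick above is a general zero-cost-field pattern; here is a compilable sketch with toy names, not the CK types:

```cpp
#include <type_traits>

struct empty_t {};

template <bool kEnable>
struct window_like
{
    int page_idx_[4]; // always present

    // Real storage only when kEnable; otherwise an empty tag that
    // [[no_unique_address]] lets the compiler overlap with its neighbors.
    [[no_unique_address]] std::conditional_t<kEnable, int[4], empty_t> physical_pages_;
    [[no_unique_address]] std::conditional_t<kEnable, int, empty_t>    page_stride_elements_;
};

// The disabled instantiation pays (almost) nothing for the unused fields,
// and every access site sits behind `if constexpr(kEnable)`, so it never
// touches the empty members.
static_assert(sizeof(window_like<false>) <= sizeof(window_like<true>), "");
```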
 
@@ -1178,7 +1292,8 @@ template
-          index_t... YsGatherDims>
+          index_t... YsGatherDims,
+          bool UseGlobalLoad = false>
 CK_TILE_DEVICE constexpr auto
 make_tile_scatter_gather(const TensorView_& tensor_view,
                          const WindowLengths_& window_lengths,
@@ -1187,7 +1302,9 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
                          const StaticPageIndexArray_& page_idx,
                          number<HsGatherDim>,
                          number<NumCoord>,
-                         sequence<YsGatherDims...>)
+                         sequence<YsGatherDims...>,
+                         bool_constant<UseGlobalLoad> = {},
+                         index_t page_stride_elements = 0)
 {
     return tile_scatter_gather<remove_cvref_t<TensorView_>,
                                remove_cvref_t<WindowLengths_>,
                                remove_cvref_t<StaticTileDistribution_>,
                                remove_cvref_t<StaticPageIndexArray_>,
                                std::nullptr_t,
                                HsGatherDim,
                                NumCoord,
-                               sequence<YsGatherDims...>>{
-        tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
+                               sequence<YsGatherDims...>,
+                               UseGlobalLoad>{tensor_view,
+                                              window_lengths,
+                                              origin,
+                                              tile_distribution,
+                                              page_idx,
+                                              nullptr,
+                                              page_stride_elements};
 }
 
-// Legacy overload (compatible with original API)
+// Legacy overload (compatible with original API, kUseGlobalLoad=false)
 template
 
+template <typename TensorView_,
+          typename WindowLengths_,
+          typename StaticTileDistribution_,
+          typename StaticPageIndexArray_,
+          bool UseGlobalLoad>
+CK_TILE_DEVICE constexpr auto
+make_tile_scatter_gather(const TensorView_& tensor_view,
+                         const WindowLengths_& window_lengths,
+                         const multi_index<TensorView_::get_num_of_dimension()>& origin,
+                         const StaticTileDistribution_& tile_distribution,
+                         const StaticPageIndexArray_& page_idx,
+                         bool_constant<UseGlobalLoad>,
+                         index_t page_stride_elements = 0)
+{
+    return tile_scatter_gather<remove_cvref_t<TensorView_>,
+                               remove_cvref_t<WindowLengths_>,
+                               remove_cvref_t<StaticTileDistribution_>,
+                               remove_cvref_t<StaticPageIndexArray_>,
+                               std::nullptr_t,
+                               0,
+                               1,
+                               sequence<0>,
+                               UseGlobalLoad>{tensor_view,
+                                              window_lengths,
+                                              origin,
+                                              tile_distribution,
+                                              page_idx,
+                                              nullptr,
+                                              page_stride_elements};
+}
 
+// `always_false_v<T>` — a value-template that is always `false` but whose
+// evaluation is deferred until template instantiation. The canonical use is
+// inside the `else` arm of an `if constexpr` chain or under an arch-gated
+// `#if` to fire a `static_assert` ONLY when the offending instantiation is
+// actually requested, e.g.:
+//
+//   if constexpr (...) { ... }
+//   else { static_assert(always_false_v<T>, "unsupported T"); }
+//
+// A bare `static_assert(false, ...)` would fire at template-definition
+// parse time on conforming compilers, breaking the whole TU.
+template <typename T>
+inline constexpr bool always_false_v = false;
+
 // remove_cvref_t
 template <typename T>
 using remove_reference_t = typename std::remove_reference<T>::type;
diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp
index 8a5d77bf46..59e868f678 100644
--- a/include/ck_tile/ops/fmha.hpp
+++ b/include/ck_tile/ops/fmha.hpp
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_dropout.hpp"
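A minimal, standalone demonstration of the deferred static_assert idiom the comment describes; the function and types are toys, not CK code:

```cpp
#include <type_traits>

template <typename T>
inline constexpr bool always_false_v = false;

template <typename T>
constexpr int bits_of()
{
    if constexpr(std::is_same_v<T, int>)
        return 32;
    else if constexpr(std::is_same_v<T, long long>)
        return 64;
    else
        // Fires only if bits_of<T>() is instantiated with an unsupported T;
        // a bare static_assert(false, ...) here would reject the whole TU.
        static_assert(always_false_v<T>, "unsupported T");
    return 0; // unreachable; keeps -Wreturn-type quiet
}

int main() { return bits_of<int>() == 32 ? 0 : 1; }
```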
diff --git a/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp
new file mode 100644
index 0000000000..826cd106f1
--- /dev/null
+++ b/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp
@@ -0,0 +1,17 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+namespace ck_tile {
+
+// KV cache load addressing mode selector for batch_prefill / paged-attention pipelines.
+// - BUFFER_LOAD: SGPR-based SRD via buffer_load_* (default; 32-bit byte addressing, <2GB pool)
+// - GLOBAL_LOAD_LDS: direct global_load_lds_* (64-bit addressing, required for >2GB KV cache)
+enum class BlockAttentionKVCacheLoadModeEnum
+{
+    BUFFER_LOAD     = 0,
+    GLOBAL_LOAD_LDS = 1,
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
index 4f2d3d58c2..8aa6d17dc3 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
@@ -6,6 +6,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_dropout.hpp"
@@ -134,7 +135,8 @@ template
-          index_t kVectorSize>
+          index_t kVectorSize,
+          bool kUseGlobalLoad_ = false>
 CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
                                                    const index_t& stride_token,
                                                    const index_t& stride_page_block,
@@ -156,81 +158,65 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica
     const index_t& thread_coord_start   = coord_vec[kCoordAxis];
     constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
 
-    if constexpr(kIsKcache)
-    {
-        // K cache: per-token lookup
-        // Each token may be on a different page, so we use physical_pages[k0] for each.
-        static_for<0, kLoopCount, 1>{}([&](auto k0) {
-            const index_t global_token_idx =
-                global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
-            const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
-
-            if constexpr(kPageBlockSize >= kN0)
-            {
-                // SRD rebasing mode: within-page offset only.
-                // The full page base is handled by rebasing the SRD pointer.
-                kv_offset_vec[k0] = token_idx_in_page * stride_token;
-            }
-            else
-            {
-                // Full global offset (original code path for ps1, ps16, etc.)
-                const index_t physical_page = physical_pages[k0];
-                kv_offset_vec[k0] =
-                    physical_page * stride_page_block + token_idx_in_page * stride_token;
-            }
-        });
-    }
-    else // V cache
-    {
-        // V cache: use physical_pages[k0] for each token
-        // physical_pages was already populated correctly by load_physical_pages(), handling:
-        // - page_size=1: page_idx maps token_idx -> physical_page directly
-        // - V tile crosses pages: per-token page lookup
-        // - V tile in single page: lane0 lookup with broadcast to all lanes
-        static_for<0, kLoopCount, 1>{}([&](auto k0) {
-            const index_t global_token_idx =
-                global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
-            const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
-
-            if constexpr(kPageBlockSize >= kN0)
-            {
-                // SRD rebasing mode: within-page offset only.
-                // The full page base is handled by rebasing the SRD pointer.
-                if constexpr(kKVMemoryLayout ==
-                             BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
-                {
-                    const index_t token_offset =
-                        (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
-                        (token_idx_in_page % kVectorSize);
-                    kv_offset_vec[k0] = token_offset;
-                }
-                else
-                {
-                    kv_offset_vec[k0] = token_idx_in_page * stride_token;
-                }
-            }
-            else
-            {
-                // Full global offset (original code path for ps1, ps16, etc.)
-                const index_t physical_page = physical_pages[k0];
-                const long_index_t page_base_offset =
-                    static_cast<long_index_t>(physical_page) * stride_page_block;
-
-                if constexpr(kKVMemoryLayout ==
-                             BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
-                {
-                    const index_t token_offset =
-                        (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
-                        (token_idx_in_page % kVectorSize);
-                    kv_offset_vec[k0] = page_base_offset + token_offset;
-                }
-                else
-                {
-                    kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token;
-                }
-            }
-        });
-    }
+    // Addressing strategy — four cases controlled by (kPageBlockSize vs kN0, kUseGlobalLoad_):
+    //
+    // Case 1: kPageBlockSize >= kN0
+    //   SRD is rebased per-tile to the page base (rebase_{k,v}_window in caller).
+    //   Page base is absorbed into the SRD's 48-bit base pointer (SGPR-resident).
+    //   This function writes the within-page offset only.
+    //
+    // Case 2: kPageBlockSize < kN0 && kUseGlobalLoad_
+    //   SRD cannot be rebased (multi-page wave). Loads use global_load_lds_*; the full
+    //   64-bit address is computed by tile_scatter_gather::load() in
+    //   include/ck_tile/core/tensor/tile_scatter_gather.hpp from physical_pages_ +
+    //   page_stride_elements_. This function writes the within-page offset only.
+    //
+    // Case 3: kPageBlockSize < kN0 && !kUseGlobalLoad_ (kNeedFullOffset == true)
+    //   SRD base is the entire KV buffer; the only place to encode page identity
+    //   is the voffset itself. This function writes the FULL offset:
+    //     page * stride_page_block + within_page
+    //   Limited to <2GB total KV bytes by the 32-bit voffset hardware width.
+    //
+    // Case 4: kPageBlockSize >= kN0 && kUseGlobalLoad_
+    //   Not emitted by codegen. Backstop static_assert in
+    //   BlockFmhaBatchPrefillPipelineQRKSVSAsync.
+    constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseGlobalLoad_;
+
+    static_for<0, kLoopCount, 1>{}([&](auto k0) {
+        const index_t global_token_idx =
+            global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
+        const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
+
+        // Within-page offset (layout-dependent for V cache with VECTORIZED_LAYOUT)
+        const index_t within_page = [&]() {
+            if constexpr(!kIsKcache &&
+                         kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
+            {
+                return (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
+                       (token_idx_in_page % kVectorSize);
+            }
+            else
+            {
+                return token_idx_in_page * stride_token;
+            }
+        }();
+
+        // SRD + page_size < kN0: add the page base to form the complete voffset for buffer_load.
+        //
+        // 32-bit by hardware: SRD buffer_load voffset is fundamentally 32-bit (CDNA3 MUBUF
+        // microcode format), so this branch is only reachable when total KV bytes fit in
+        // INT32_MAX. The kUseGlobalLoad_ template path handles the >2GB case via 64-bit
+        // global_load_lds_*; widening kv_offset_vec here would not lift the 2GB ceiling
+        // because the hardware truncates voffset regardless.
+        if constexpr(kNeedFullOffset)
+        {
+            kv_offset_vec[k0] = physical_pages[k0] * stride_page_block + within_page;
+        }
+        else
+        {
+            kv_offset_vec[k0] = within_page;
+        }
+    });
 }
 
 // a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
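The within-page arithmetic is easy to sanity-check on the host; a standalone sketch with hypothetical sizes (16-token pages, vector size 8, stride 64 elements, none taken from the patch):

```cpp
#include <cassert>

int main()
{
    // Hypothetical: kLog2PageSize = 4 (16-token pages), stride_token = 64
    // elements, kVectorSize = 8 for the V-cache VECTORIZED_LAYOUT variant.
    const int kLog2PageSize     = 4;
    const int kInPageOffsetMask = (1 << kLog2PageSize) - 1;
    const int stride_token      = 64;
    const int kVectorSize       = 8;

    const int global_token_idx  = 37; // page 2, slot 5
    const int token_idx_in_page = global_token_idx & kInPageOffsetMask;
    assert(token_idx_in_page == 5);

    // Standard layout: tokens are laid out row-major within the page.
    assert(token_idx_in_page * stride_token == 320);

    // VECTORIZED_LAYOUT: tokens grouped kVectorSize at a time, so slot 5 sits
    // in group 0 at lane 5: (5 / 8) * (64 * 8) + (5 % 8) = 5.
    const int within_page_vec =
        (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
        (token_idx_in_page % kVectorSize);
    assert(within_page_vec == 5);
    return 0;
}
```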
@@ -270,10 +256,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
     static constexpr index_t kSubQKHeaddim  = BlockFmhaShape::kSubQKHeaddim;
     static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
     static constexpr index_t kVectorSize    = Problem::kVectorSize;
-    static constexpr auto I0 = number<0>{};
-    static constexpr auto I1 = number<1>{};
-    static constexpr auto I2 = number<2>{};
-    static constexpr auto I3 = number<3>{};
+    // Single load-mode selector for the whole pipeline. GLOBAL_LOAD_LDS routes K/V
+    // tiles through global_load_lds_* (handles >2GB KV cache); BUFFER_LOAD uses SRD
+    // buffer_load_*. The enum is named at the trait/Problem level; internally we
+    // derive a bool helper to keep `if constexpr` sites narrow. Codegen only emits
+    // GLOBAL_LOAD_LDS arms when page_size < kN0; the static_assert is a backstop.
+    static constexpr auto kKVLoadMode = Problem::kKVLoadMode;
+    static constexpr bool kUseGlobalLoad =
+        (kKVLoadMode == BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS);
+    static_assert(!kUseGlobalLoad || (kPageBlockSize < kN0),
+                  "GLOBAL_LOAD_LDS load mode is only valid when kPageBlockSize < kN0; "
+                  "codegen should not emit this instantiation otherwise.");
+
+    static constexpr auto I0 = number<0>{};
+    static constexpr auto I1 = number<1>{};
+    static constexpr auto I2 = number<2>{};
+    static constexpr auto I3 = number<3>{};
 
     static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
@@ -626,19 +623,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                                  kKVMemoryLayout,
                                  true,
                                  kN0,
-                                 kVectorSize>(
+                                 kVectorSize,
+                                 kUseGlobalLoad>(
             k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
 
         auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
                                                       k_dram_block_window.get_window_lengths(),
                                                       k_dram_block_window.get_window_origin(),
                                                       k_dist,
-                                                      k_offsets); // K DRAM tile window for
+                                                      k_offsets,
+                                                      bool_constant<kUseGlobalLoad>{},
+                                                      page_stride_k);
+        if constexpr(kUseGlobalLoad)
+        {
+            k_dram_window.update_physical_pages(k_physical_pages);
+        }
         k_dram_window.init_raw();
 
-        // SRD rebasing: move the buffer descriptor base pointer to each page's start address
-        // using 48-bit pointer arithmetic, so voffset only needs the small within-page offset.
-        // Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page).
+        // SRD rebasing for K: only for page_size >= kN0 (all threads on same page).
+        // For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
+        // addressing.
        auto rebase_k_window = [&](auto& window, index_t physical_page) {
            if constexpr(kPageBlockSize >= kN0)
            {
@@ -649,24 +653,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                 const auto* page_ptr =
                     base_ptr + static_cast<long_index_t>(physical_page) * page_stride_k;
                 window.set_bottom_tensor_view_data_ptr(page_ptr);
+                // Limit SRD num_records to one page worth of elements.
+                // Without this, the SRD claims validity for [page_ptr, page_ptr +
+                // full_buffer_size), which extends far beyond the allocated buffer when rebased to
+                // high pages. On gfx950, the hardware may validate the full SRD range against page
+                // table permissions, causing faults on freed/protected memory beyond the buffer.
+                window.set_bottom_tensor_view_buffer_size(page_stride_k);
                 window.init_raw();
             }
         };
 
+        // SRD rebasing for V: only for page_size >= kN0 (all threads on same page).
+        // For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
+        // addressing.
         auto rebase_v_window = [&](auto& window, index_t physical_page) {
             if constexpr(kPageBlockSize >= kN0)
             {
+                // readfirstlane: make physical_page provably wave-uniform so the
+                // resulting SRD lands in SGPRs (required by buffer load instructions).
                 physical_page = __builtin_amdgcn_readfirstlane(physical_page);
                 const auto* base_ptr =
                     v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_;
                 const auto* page_ptr =
                     base_ptr + static_cast<long_index_t>(physical_page) * page_stride_v;
                 window.set_bottom_tensor_view_data_ptr(page_ptr);
+                window.set_bottom_tensor_view_buffer_size(page_stride_v);
                 window.init_raw();
             }
         };
 
-        // Initial K SRD rebase
+        // Initial K SRD rebase (no-op for page_size < kN0, uses flat loads instead)
         rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
 
         constexpr auto k_oob_ck = bool_constant<true>{};
@@ -874,12 +890,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                                      kKVMemoryLayout,
                                      false,
                                      kN0,
-                                     kVectorSize>(v_physical_pages_k2,
-                                                  stride_v,
-                                                  page_stride_v,
-                                                  v_coord,
-                                                  v_offsets_k2,
-                                                  current_seq_k);
+                                     kVectorSize,
+                                     kUseGlobalLoad>(v_physical_pages_k2,
+                                                     stride_v,
+                                                     page_stride_v,
+                                                     v_coord,
+                                                     v_offsets_k2,
+                                                     current_seq_k);
 
             static_for<0, V_KIterInner, 1>{}([&](auto k1) {
                 constexpr auto idx = number{};
@@ -899,9 +916,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                                      kKVMemoryLayout,
                                      false,
                                      kN0,
-                                     kVectorSize>(
+                                     kVectorSize,
+                                     kUseGlobalLoad>(
                     v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
             }
+
+            // v_offsets semantics — see the four-case addressing-strategy block above
+            // kNeedFullOffset in kv_offset_array_transform. Three cases reach this lambda:
+            //   Case 1 (kPageBlockSize >= kN0): within-page offset; page base in SRD.
+            //   Case 2 (page_size < kN0, kUseGlobalLoad): within-page offset; page base
+            //          computed by tile_scatter_gather::load() from physical_pages_.
+            //   Case 3 (page_size < kN0, !kUseGlobalLoad, == kNeedFullOffset):
+            //          FULL offset (page * stride + within), carried in the 32-bit
+            //          voffset (<2GB cap).
         };
 
         // Prefetch V physical pages early to hide buffer load latency
@@ -915,11 +943,32 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                                                            v_offsets,
                                                            number<1>{}, // HsGatherDim
                                                            number<1>{}, // NumCoord
-                                                           VPageIndexYDims);
+                                                           VPageIndexYDims,
+                                                           bool_constant<kUseGlobalLoad>{},
+                                                           page_stride_v);
+        if constexpr(kUseGlobalLoad)
+        {
+            v_dram_window.update_physical_pages(v_physical_pages);
+        }
 
-        // Initial V SRD rebase
+        // Initial V SRD rebase. Single source of truth: rebase_v_window's own
+        // `if constexpr(kPageBlockSize >= kN0)` makes this a no-op for case 2/3.
+        // Do not re-add an outer guard here — it would duplicate the inner check
+        // and drift if the lambda's gating condition ever changes.
         rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
 
+        // Save the *current* tile's V physical pages into v_dram_window before
+        // prefetch_v_physical_pages overwrites the v_physical_pages buffer with the
+        // *next* tile's pages. Case-2 only (kUseGlobalLoad); case-1/3 don't read
+        // physical_pages_ from the window. Encapsulating the save+prefetch pair
+        // here makes the ordering invariant unmissable when a fourth prefetch site
+        // is added later.
+        auto save_and_prefetch_v_pages = [&](auto k_loop_start) {
+            if constexpr(kUseGlobalLoad)
+                v_dram_window.update_physical_pages(v_physical_pages);
+            prefetch_v_physical_pages(k_loop_start);
+        };
+
         // prefetch K tile
         async_load_tile_raw(
             k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
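`__builtin_amdgcn_readfirstlane` is the standard way to make wave-uniformity provable to the compiler; a minimal device-side sketch of the pattern the rebase lambdas use (the function and variable names here are illustrative):

```cpp
// HIP sketch: force a wave-uniform value toward SGPRs before deriving a
// buffer-descriptor base from it. If physical_page stayed in VGPRs, the
// compiler could not keep the descriptor SGPR-resident as buffer loads require.
#include <hip/hip_runtime.h>

__device__ const void* page_base_ptr(const char* pool_base,
                                     int physical_page, // same value in every lane
                                     long long page_stride_bytes)
{
    // readfirstlane: broadcast lane 0's value, making uniformity explicit.
    physical_page = __builtin_amdgcn_readfirstlane(physical_page);
    return pool_base + static_cast<long long>(physical_page) * page_stride_bytes;
}
```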
@@ -972,7 +1021,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
             }
 
             // Prefetch V physical pages early - overlaps with GEMM0 computation
-            prefetch_v_physical_pages(number<kK1>{});
+            save_and_prefetch_v_pages(number<kK1>{});
 
             // STAGE 1, QK gemm
             clear_tile(s_acc); // initialize C
@@ -1166,7 +1215,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
             // Prefetch V physical pages early - overlaps with softmax computation
             if constexpr(k1_loops > 1)
             {
-                prefetch_v_physical_pages(number<2 * kK1>{});
+                save_and_prefetch_v_pages(number<2 * kK1>{});
             }
 
             auto m_local = block_tile_reduce(
@@ -1220,8 +1269,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                     v_dram_window,
                     {0, kK1});
                 // will have scratch if move this right after load_tile(v_dram)...
-                v_buf = load_tile(
-                    v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
+                v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
                 update_v_offsets(number<2 * kK1>{});
                 v_dram_window.update_page_idx(v_offsets);
                 rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
@@ -1390,8 +1438,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
             static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
                 if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
                 {
-                    v_buf = load_tile(
-                        v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
+                    v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
                     // Update V offsets using previously prefetched physical pages
                     update_v_offsets(number<(2 + i_k1.value) * kK1>{});
                     v_dram_window.update_page_idx(v_offsets);
@@ -1401,7 +1448,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                 // Prefetch V physical pages for NEXT iteration - overlaps with GEMM1
                 if constexpr(i_k1 + 1 < k1_loops - 1)
                 {
-                    prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{});
+                    save_and_prefetch_v_pages(number<(2 + i_k1.value + 1) * kK1>{});
                 }
 
                 block_sync_lds();
@@ -1481,9 +1528,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
                                      kKVMemoryLayout,
                                      true,
                                      kN0,
-                                     kVectorSize>(
+                                     kVectorSize,
+                                     kUseGlobalLoad>(
                     k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
                 k_dram_window.update_page_idx(k_offsets);
+                if constexpr(kUseGlobalLoad)
+                    k_dram_window.update_physical_pages(k_physical_pages);
                 rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
 
                 // After sink→window transition (i_total_loops == num_sink_loop), V window
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
index 87db7b85b9..c441f57c86 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
@@ -117,6 +117,12 @@ struct BlockFmhaBatchPrefillPipelineProblem
     static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
                   "kPageBlockSize must be power of two");
 
+    // KV cache load addressing mode. GLOBAL_LOAD_LDS handles >2GB pools via
+    // 64-bit addressing; BUFFER_LOAD (default) uses SRD buffer_load for the
+    // <2GB fast path. The 2GB bound = INT32_MAX byte offset, matching CK's
+    // existing TwoGB convention.
+    static constexpr auto kKVLoadMode = Traits_::kKVLoadMode;
+
     static constexpr index_t kVectorSize  = 16 / sizeof(KDataType_); // Dwordx4
     static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
     static constexpr auto kKVLookupTable  = Traits_::kKVLookupTable;
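The trait value threads through three layers by plain constexpr forwarding; a stripped-down sketch of the pattern (toy structs, not the real parameter lists):

```cpp
enum class LoadMode { BUFFER_LOAD = 0, GLOBAL_LOAD_LDS = 1 };

// Traits: the codegen-facing knob (tile_fmha_traits.hpp level).
template <LoadMode kKVLoadMode_ = LoadMode::BUFFER_LOAD>
struct ToyTraits
{
    static constexpr auto kKVLoadMode = kKVLoadMode_;
};

// Problem: re-exports the trait (block_fmha_pipeline_problem.hpp level).
template <typename Traits_>
struct ToyProblem
{
    static constexpr auto kKVLoadMode = Traits_::kKVLoadMode;
};

// Pipeline: derives the bool actually used at the `if constexpr` sites.
template <typename Problem>
struct ToyPipeline
{
    static constexpr bool kUseGlobalLoad =
        (Problem::kKVLoadMode == LoadMode::GLOBAL_LOAD_LDS);
};

static_assert(ToyPipeline<ToyProblem<ToyTraits<LoadMode::GLOBAL_LOAD_LDS>>>::kUseGlobalLoad, "");
static_assert(!ToyPipeline<ToyProblem<ToyTraits<>>>::kUseGlobalLoad, "");
```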
diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
index 7df39c3d11..e7370cdb65 100644
--- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
@@ -5,6 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
 #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
@@ -58,7 +59,9 @@ template
-              BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
+              BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
+          BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
+              BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
 struct TileFmhaBatchPrefillTraits : public TileFmhaTraits