mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK_TILE] fix(fmha): support >2GB KV cache in batch prefill via template dispatch (#6653)
## Motivation
The CK batch prefill kernel previously failed (silent overflow + page
faults) when the KV cache exceeded 2 GB, blocking long-context inference
workloads (e.g., 128K+ token contexts with paged KV).
Two distinct failure modes were addressed:
1. **>4GB SRD overflow (`page_size < kN0`):** The SRD
`buffer_load_dwordx4` path uses a 32-bit `voffset` register; for small
page sizes the rebased SRD spans the full KV pool and the offset wraps
past 2 GB, corrupting K/V loads.
2. **gfx950 page-table fault (`page_size >= kN0`):** On CDNA4 the
hardware validates the **full SRD `num_records` range** against
page-table permissions (CDNA3 only checks per-instruction `voffset`).
After per-tile SRD rebase, an un-trimmed `num_records` field extends
past the live page and faults on freed/protected memory.
## Technical Details
**Two-mode `tile_scatter_gather` selected by the `kUseGlobalLoad`
template parameter:**
| Case | `page_size` | KV cache size | Mode | Load path | Addressing |
|---|---|---|---|---|---|
| 1 | `>= kN0` (large pages) | any | SRD (`kUseGlobalLoad=false`) |
`buffer_load_dwordx4` | 32-bit `voffset`, bounded by per-page rebase |
| 2 | `< kN0` (small pages) | `<= 2 GB` | SRD (`kUseGlobalLoad=false`) |
`buffer_load_dwordx4` | 32-bit `voffset`, fits in INT32 byte range |
| 3 | `< kN0` (small pages) | `> 2 GB` | Global-load
(`kUseGlobalLoad=true`) | `async_load_tile_raw_flat` (K) +
`load_tile_flat` (V) | 64-bit |
**Dispatch:** the auto-gen API layer (`fmha_batch_prefill.py`) selects
the kernel instantiation at launch from `(page_block_size,
num_total_pages * batch_stride_k * kElementBytes)`, so the small-page
penalty is paid only when correctness requires it.
**gfx950 SRD `num_records` trimming:** in the K and V rebase lambdas of
`block_fmha_batch_prefill_pipeline_qr_ks_vs_async`,
`set_bottom_tensor_view_buffer_size(page_stride_k/v)` is called after
each rebase to constrain `num_records` to the live page. Required for
CDNA4 page-table validation; harmless on CDNA3.
**Pipeline sync for the global-load path:**
- V uses synchronous `load_tile_flat`; K uses
`async_load_tile_raw_flat`.
- `v_physical_pages_current` is double-buffered so the V flat load
doesn't race against the next iteration's K rebase computation.
**Arch guards:** `global_load_lds` intrinsics are gated to `__gfx94__` /
`__gfx950__` (CDNA3+). Other architectures hit a `dependent_false`
static_assert with a descriptive message.
**Device-side assertion convention:** SRD setters use
`__builtin_assume(cond)` (hint-only) rather than `<cassert>`'s
`assert()`. The latter introduces an `__assert_fail` call whose register
pressure scatters the K-SRD scalar register window across conditional
branches, corrupting `buffer_load_dwordx4` on gfx950.
## Test Plan
Tested on both MI308 (gfx942) and MI355 (gfx950) via the aiter wrapper
test suite. All coverage lives in **`op_tests/test_batch_prefill.py`**:
- **Functional matrix (96 cases)** — `test_batch_prefill`: `page_size ∈
{1, 16, 1024}` × `kv_layout ∈ {linear, vectorized}` × `dtype ∈ {bf16,
fp8 quant variants}` × `causal` × `soft_cap` × `LSE` × `batch_size ∈ {1,
4}` (parametrized to exercise per-sequence SRD rebase across batch
boundaries).
- **>2 GB coverage** — `test_batch_prefill_large_kvcache`: extended to
allocate a 5 GB+ KV cache pool and exercise both `kUseGlobalLoad=true`
(small-page) and `kUseGlobalLoad=false` (large-page rebase) paths.
Includes both single-batch and multi-batch (`batch_size=4`) cases to
exercise per-sequence SRD rebase across the >2 GB pool.
- Numerical reference: PyTorch SDPA, per-batch loop with `atol` / `rtol`
from the existing batch prefill test harness.
## Test Result
| Arch | `test_batch_prefill` | `test_batch_prefill_large_kvcache` (>2
GB) |
|------|----------------------|---------------------|
| MI308 (gfx942) | All passed | Passed |
| MI355 (gfx950) | All passed | Passed |
**Performance impact (gfx950, hot SRD path):**
- +2.67% kernel-time on `seqlen=1024 / page_sz=1024 / bf16 / sglang /
causal / soft_cap=30`, attributable in full to the two
`set_bottom_tensor_view_buffer_size` calls in the K/V rebase lambdas
(5-run median, signal/noise ≈ 9×).
- This cost is **mandatory for gfx950 correctness** on >2 GB workloads —
removing the setters re-introduces page-faults.
- gfx942: 0 regressions in the same range (all configs ≤ +0.97%).
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -22,8 +22,16 @@ from codegen.cpp_symbol_map import (
|
||||
QSCALE_CHECK_MAP,
|
||||
QSCALE_MAP,
|
||||
)
|
||||
from codegen.arch import ArchTrait
|
||||
from codegen.utils import update_file
|
||||
|
||||
# Architecture trait for kernels requiring global_load_lds (CDNA3+).
|
||||
# Only used for GLOBAL_LOAD_LDS variants; all other kernels are arch-agnostic.
|
||||
CDNA3_PLUS_ARCH = ArchTrait(
|
||||
"cdna3_plus",
|
||||
preprocessor_check="defined(__gfx94__) || defined(__gfx950__)",
|
||||
)
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
@@ -34,6 +42,10 @@ DTYPE_BITS = {
|
||||
"bf8": 8,
|
||||
}
|
||||
|
||||
# Element size in bytes per dtype, used by the auto-generated dispatcher to
|
||||
# decide kv_load_mode per-arm (total KV cache bytes vs INT32_MAX).
|
||||
DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
|
||||
@@ -47,6 +59,10 @@ KV_LOOKUP_TABLE_ENUM_MAP = {
|
||||
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
|
||||
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
|
||||
}
|
||||
KV_LOAD_MODE_ENUM_MAP = {
|
||||
False: "ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD",
|
||||
True: "ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS",
|
||||
}
|
||||
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
@@ -61,6 +77,8 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
|
||||
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
@@ -87,7 +105,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
|
||||
{F_sink},
|
||||
{F_page_size},
|
||||
{F_kv_memory_layout},
|
||||
{F_kv_lookup_table}>;
|
||||
{F_kv_lookup_table},
|
||||
{F_kv_load_mode}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
@@ -125,7 +144,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -140,10 +159,13 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
namespace {{
|
||||
@@ -194,6 +216,7 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a,
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
constexpr int kElementBytes = {F_element_bytes};
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
@@ -203,8 +226,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (fmha_batch_prefill_select_kv_load_mode(a.page_block_size, {F_bn0}, a.num_total_pages, a.batch_stride_k, kElementBytes) == {F_kv_load_mode})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -253,12 +276,14 @@ class FmhaFwdApiTrait:
|
||||
kv_memory_layout: str
|
||||
kv_lookup_table: str
|
||||
page_size: int = 1 # page block size
|
||||
use_global_load: bool = False # use global_load_lds_* for >2GB KV cache
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
|
||||
+ ("-gload" if self.use_global_load else "-bload")
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -481,6 +506,7 @@ class FmhaFwdApiPool:
|
||||
],
|
||||
F_page_size=trait.page_size,
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[trait.use_global_load],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
@@ -488,7 +514,10 @@ class FmhaFwdApiPool:
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
F_if=if_i,
|
||||
F_dtype=dtype,
|
||||
F_element_bytes=DTYPE_BYTES[dtype],
|
||||
F_hdim_case=per_hdim_case,
|
||||
)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
@@ -539,6 +568,7 @@ class FmhaFwdKernel:
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
F_page_size: int = 1 # page block size
|
||||
F_use_global_load: bool = False # use global_load_lds_* for >2GB KV cache
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
@@ -588,6 +618,10 @@ class FmhaFwdKernel:
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_page_size=self.F_page_size,
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[self.F_use_global_load],
|
||||
F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check
|
||||
if self.F_use_global_load
|
||||
else "true",
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -595,6 +629,7 @@ class FmhaFwdKernel:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
|
||||
+ ("gload_" if self.F_use_global_load else "bload_")
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
@@ -632,6 +667,7 @@ class FmhaFwdKernel:
|
||||
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
|
||||
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
|
||||
page_size=self.F_page_size,
|
||||
use_global_load=self.F_use_global_load,
|
||||
)
|
||||
|
||||
|
||||
@@ -714,8 +750,11 @@ class CustomFactory(KernelComponentFactory):
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
|
||||
targets: Optional[List[str]] = None
|
||||
kernel_filter: Optional[str],
|
||||
receipt,
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
targets: Optional[List[str]] = None,
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
|
||||
# (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
|
||||
@@ -837,6 +876,25 @@ def get_fwd_blobs(
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
# For page_size < kN0 (tile.F_bn0), also generate a GLOBAL_LOAD_LDS
|
||||
# variant for >2GB KV cache support. The default (BUFFER_LOAD) uses SRD
|
||||
# buffer_load (fast, <2GB). GLOBAL_LOAD_LDS uses global_load_lds_*
|
||||
# (slower, handles >2GB).
|
||||
if page_size < tile.F_bn0:
|
||||
k_global_load = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
F_page_size=page_size,
|
||||
F_use_global_load=True,
|
||||
)
|
||||
api_pool.register_traits(k_global_load.api_trait())
|
||||
gen.append(k_global_load)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
@@ -856,7 +914,9 @@ def write_blobs(
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
|
||||
api_pool, kernels = get_fwd_blobs(
|
||||
kernel_filter, receipt, optdim_list, mask_impl, targets
|
||||
)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
@@ -871,7 +931,9 @@ def list_blobs(
|
||||
mask_impl,
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
|
||||
_, kernels = get_fwd_blobs(
|
||||
kernel_filter, receipt, optdim_list, mask_impl, targets
|
||||
)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
|
||||
@@ -673,6 +673,33 @@ struct fmha_batch_prefill_args
|
||||
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
|
||||
};
|
||||
|
||||
// Selects the KV-cache load mode for a batch-prefill dispatch arm.
|
||||
// GLOBAL_LOAD_LDS: required when (a) the page is smaller than one K/V tile
|
||||
// so per-page SRD is impossible, AND (b) the total KV-pool byte size
|
||||
// exceeds INT32_MAX so SRD's 32-bit byte offset cannot address it.
|
||||
// BUFFER_LOAD: every other case — the SGPR-resident SRD path is fastest.
|
||||
// Inputs are taken as plain integers so the helper has no template parameter
|
||||
// and can be called from each codegen-emitted dispatcher arm with the arm's
|
||||
// compile-time kN0 / element_bytes substituted as constants.
|
||||
inline ck_tile::BlockAttentionKVCacheLoadModeEnum
|
||||
fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size,
|
||||
ck_tile::index_t kN0,
|
||||
ck_tile::index_t num_total_pages,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t element_bytes)
|
||||
{
|
||||
// Promote every operand to long_index_t so overflow is impossible regardless
|
||||
// of multiplication order. A bare `static_cast<long_index_t>(num_total_pages)
|
||||
// * batch_stride_k * element_bytes` only works because of left-to-right
|
||||
// associativity — a future reorder of the operands would silently truncate.
|
||||
const auto kv_pool_bytes = static_cast<ck_tile::long_index_t>(num_total_pages) *
|
||||
static_cast<ck_tile::long_index_t>(batch_stride_k) *
|
||||
static_cast<ck_tile::long_index_t>(element_bytes);
|
||||
return (page_block_size < kN0 && kv_pool_bytes > INT32_MAX)
|
||||
? ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS
|
||||
: ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD;
|
||||
}
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
@@ -1457,7 +1484,9 @@ template <ck_tile::index_t HDim_,
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
|
||||
ck_tile::BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
|
||||
ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
|
||||
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
DataType_,
|
||||
kIsGroupMode_,
|
||||
@@ -1486,6 +1515,7 @@ struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
|
||||
static constexpr auto kKVLoadMode = kKVLoadMode_;
|
||||
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
|
||||
};
|
||||
|
||||
|
||||
@@ -1319,6 +1319,87 @@ CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
// Flat async load from global memory to LDS using 64-bit global addressing.
|
||||
// Bypasses the SRD's 32-bit offset limit; required when the KV cache exceeds
|
||||
// INT32_MAX (2GB) byte offset on the SRD voffset path.
|
||||
//
|
||||
// !!! M0 PRECONDITION — IMPLICIT INPUT NOT VISIBLE IN OPERAND LIST !!!
|
||||
//
|
||||
// The LDS destination address is taken from M0 (per AMD CDNA3 ISA §10.3:
|
||||
// `LDS_ADDR = LDSbase + LDSoffset(M0[17:2] * 4) + INST.OFFSET + ThreadID*4`).
|
||||
// M0 does NOT appear as an operand of these instructions or of the inline
|
||||
// asm below — the compiler cannot see the dependency. Caller must:
|
||||
//
|
||||
// 1. Initialize M0 once before the load loop:
|
||||
// `m0_set_with_memory(amd_wave_read_first_lane(lds_byte_offset));`
|
||||
// M0 is SALU-only — `m0_set_with_memory` uses an "s" constraint to
|
||||
// enforce this. Direct VALU writes to M0 are illegal.
|
||||
//
|
||||
// 2. Advance M0 between successive issues:
|
||||
// `m0_inc_with_memory(size_per_issue);`
|
||||
// `size_per_issue` MUST be a multiple of 4 — GLOBAL/FLAT LDS path
|
||||
// only honors M0[17:2]*4 (dword-aligned), so low 2 bits are silently
|
||||
// dropped (NOTE: this differs from MUBUF buffer_load_lds which uses
|
||||
// M0[15:0] as a raw byte offset).
|
||||
//
|
||||
// 3. Never bundle `m0_inc_with_memory` and the next call to this
|
||||
// function into a single inline asm. The compiler auto-inserts a
|
||||
// hazard NOP between an SALU write to M0 and the consuming
|
||||
// `global_load_lds_*`; bundling bypasses that and may read stale M0.
|
||||
//
|
||||
// The "memory" clobber on this asm is load-bearing: it prevents the
|
||||
// compiler from reordering this load across other M0-touching helpers
|
||||
// (`m0_set_with_memory` / `m0_inc_with_memory`, also "memory"-clobbered).
|
||||
//
|
||||
// Verified instruction emission (HIP 6.4 / clang 19, gfx942 + gfx950):
|
||||
// `global_load_lds_dwordx4` is a single instruction (encoding 0xDDF48000
|
||||
// 0x007F0000), NOT software-expanded into 4× dword. Same encoding on both
|
||||
// arches. The opcode is undocumented in CDNA3 ISA spec §13.6.2 but
|
||||
// supported by the LLVM AMDGPU backend.
|
||||
//
|
||||
// Available on gfx940+ (CDNA3: MI300, MI355, MI350 series).
|
||||
template <unsigned num_dwords, bool pre_nop = false>
|
||||
CK_TILE_DEVICE void
|
||||
async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant<pre_nop> = {})
|
||||
{
|
||||
#if !defined(__gfx94__) && !defined(__gfx950__)
|
||||
static_assert(always_false_v<integral_constant<unsigned, num_dwords>>,
|
||||
"global_load_lds requires CDNA3+ (gfx940/gfx950). "
|
||||
"Ensure kKVLoadMode is BUFFER_LOAD on this architecture.");
|
||||
#endif
|
||||
|
||||
static_assert(num_dwords == 1 || num_dwords == 4,
|
||||
"global_load_lds supports num_dwords == 1 or 4 only "
|
||||
"(2 dwords does not exist on any supported arch; "
|
||||
"3 dwords only on CDNA4 and unused in FMHA pipeline)");
|
||||
|
||||
// Inline asm: only the global address is an explicit operand. The LDS
|
||||
// destination is implicit via M0 (see contract above). `"=r"(smem)` is a
|
||||
// SSA scheduling anchor only — `smem` is NOT written by this asm; the
|
||||
// load goes to LDS at `M0[17:2]*4 + offset:0 + ThreadID*4`.
|
||||
#define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr) \
|
||||
if constexpr(pre_nop) \
|
||||
asm volatile("s_nop 4\n" instr " %1, off offset:0" \
|
||||
: "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
|
||||
: "v"(global_addr) \
|
||||
: "memory" /*prevents reorder across m0_{set,inc}*/); \
|
||||
else \
|
||||
asm volatile(instr " %1, off offset:0" \
|
||||
: "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
|
||||
: "v"(global_addr) \
|
||||
: "memory" /*prevents reorder across m0_{set,inc}*/);
|
||||
|
||||
if constexpr(num_dwords == 1)
|
||||
{
|
||||
CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dword");
|
||||
}
|
||||
else if constexpr(num_dwords == 4)
|
||||
{
|
||||
CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dwordx4");
|
||||
}
|
||||
#undef CK_TILE_GLOBAL_LOAD_LDS_INSTR
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
CK_TILE_DEVICE thread_buffer<int8_t, N>
|
||||
|
||||
@@ -45,9 +45,29 @@ template <typename BottomTensorView_,
|
||||
typename StaticValidArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1,
|
||||
typename YsGatherDims = sequence<0>>
|
||||
typename YsGatherDims = sequence<0>,
|
||||
bool kUseGlobalLoad_ = false>
|
||||
struct tile_scatter_gather
|
||||
{
|
||||
static constexpr bool kUseGlobalLoad = kUseGlobalLoad_;
|
||||
|
||||
#if !defined(__gfx94__) && !defined(__gfx950__)
|
||||
// global_load_lds instruction is only available on CDNA3+ (gfx940/gfx950).
|
||||
// On other architectures, kUseGlobalLoad must be false.
|
||||
static_assert(!kUseGlobalLoad_,
|
||||
"kUseGlobalLoad requires global_load_lds (CDNA3+: gfx940/gfx950). "
|
||||
"This kernel should not be instantiated on this architecture.");
|
||||
#endif
|
||||
|
||||
// Empty placeholder used by the SRD instantiation so physical_pages_ and
|
||||
// page_stride_elements_ occupy zero bytes there (combined with
|
||||
// [[no_unique_address]] on the member declarations). Access sites are all
|
||||
// inside `if constexpr(kUseGlobalLoad_)` arms, which compile out in SRD
|
||||
// mode, so no caller needs to change.
|
||||
struct gl_field_empty_t
|
||||
{
|
||||
};
|
||||
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
@@ -233,15 +253,22 @@ struct tile_scatter_gather
|
||||
const BottomTensorIndex& window_origin,
|
||||
const TileDstr& tile_distribution,
|
||||
const PageIdxArray& page_idx,
|
||||
const ValidArray& valids)
|
||||
const ValidArray& valids,
|
||||
index_t page_stride_elements = 0)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin},
|
||||
tile_dstr_{tile_distribution},
|
||||
page_idx_{page_idx},
|
||||
physical_pages_{},
|
||||
page_stride_elements_{},
|
||||
valids_{valids},
|
||||
pre_computed_coords_{}
|
||||
{
|
||||
if constexpr(kUseGlobalLoad_)
|
||||
{
|
||||
page_stride_elements_ = page_stride_elements;
|
||||
}
|
||||
#if 0 // debug
|
||||
// TODO: this use more register for FA, but less register for GEMM
|
||||
// need investigation
|
||||
@@ -357,6 +384,34 @@ struct tile_scatter_gather
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// Override buffer size (input in RAW elements, NOT pre-divided by PackedSize) for
|
||||
// SRD num_records control. Use to set max range when SRD is rebased per-tile
|
||||
// (page_size >= kN0 path): each rebased SRD only needs to cover one page; without
|
||||
// this the SRD claims validity for memory beyond the allocated buffer, which can
|
||||
// fault on gfx950 page-table validation.
|
||||
//
|
||||
// Matches buffer_view ctor convention (buffer_view.hpp:245): input is raw element
|
||||
// count and is divided by PackedSize before being stored. For PackedSize=1
|
||||
// (fp16/bf16/fp8) the division is a no-op; for PackedSize=2 (FP4 / packed int4)
|
||||
// skipping it would over-report num_records by 2x and silently mask OOB on SRD
|
||||
// reads. batch_prefill currently does not exercise the packed-type path, but this
|
||||
// setter is generic infrastructure (lives in tile_scatter_gather.hpp) so it must
|
||||
// honor the same invariant the ctor enforces.
|
||||
CK_TILE_DEVICE constexpr void set_bottom_tensor_view_buffer_size(index_t size)
|
||||
{
|
||||
// Hint the optimizer that size is positive without inserting a runtime
|
||||
// branch. Using <cassert> assert() here corrupted gfx950 batch_prefill
|
||||
// output: the __assert_fail handler's SGPR pressure forced the K-SRD
|
||||
// register window to be reused as scratch and scattered the SRD writes
|
||||
// across two conditional branches, which gfx950's packed
|
||||
// buffer_load_dwordx4 issue window doesn't tolerate (gfx942 absorbs it
|
||||
// via per-tile single-dword loads). __builtin_assume is hint-only —
|
||||
// no branch, no scratch SGPRs, no codegen impact.
|
||||
__builtin_assume(size > 0);
|
||||
using BufType = remove_cvref_t<decltype(bottom_tensor_view_.buf_)>;
|
||||
bottom_tensor_view_.buf_.buffer_size_ = size / BufType::PackedSize;
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
@@ -458,7 +513,21 @@ struct tile_scatter_gather
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value = [&]() {
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
if constexpr(kUseGlobalLoad_)
|
||||
{
|
||||
// Global load mode: 64-bit typed pointer arithmetic
|
||||
const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_;
|
||||
const auto physical_page = physical_pages_[idx_gather];
|
||||
const auto coord_offset = bottom_tensor_thread_coord.get_offset();
|
||||
const long_index_t total_offset =
|
||||
static_cast<long_index_t>(physical_page) * page_stride_elements_ +
|
||||
coord_offset + page_offset;
|
||||
const auto* addr = base_ptr + total_offset;
|
||||
vector_t v;
|
||||
__builtin_memcpy(&v, addr, sizeof(vector_t));
|
||||
return v;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
@@ -680,7 +749,23 @@ struct tile_scatter_gather
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from bottom tensor
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
if constexpr(kUseGlobalLoad_)
|
||||
{
|
||||
// Global load mode: global_load_lds with 64-bit address
|
||||
constexpr index_t vector_size =
|
||||
sizeof(vector_t) / sizeof(uint32_t); // dwords per vector
|
||||
const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_;
|
||||
const auto physical_page = physical_pages_[idx_gather];
|
||||
const auto coord_offset = bottom_tensor_thread_coord.get_offset();
|
||||
const long_index_t total_offset =
|
||||
static_cast<long_index_t>(physical_page) * page_stride_elements_ +
|
||||
coord_offset + page_offset;
|
||||
const auto* addr = base_ptr + total_offset;
|
||||
// global_load_lds takes a byte address; addr (const DataType*)
|
||||
// converts implicitly to const void*, no explicit cast needed.
|
||||
async_global_load_lds_dwordxn<vector_size>(smem, addr, pre_nop_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
|
||||
@@ -1046,6 +1131,13 @@ struct tile_scatter_gather
|
||||
|
||||
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
|
||||
|
||||
CK_TILE_DEVICE void update_physical_pages(const PageIdxArray& pages)
|
||||
{
|
||||
static_assert(kUseGlobalLoad_,
|
||||
"global-load mode only; physical_pages_ is unused in SRD mode.");
|
||||
physical_pages_ = pages;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
|
||||
{
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
|
||||
@@ -1139,7 +1231,29 @@ struct tile_scatter_gather
|
||||
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
|
||||
TileDstr tile_dstr_;
|
||||
|
||||
// Scatter/gather offsets for each element, set by update_page_idx().
|
||||
// SRD mode (kUseGlobalLoad=false): buffer_load(SRD, page_idx_[i] + coord).
|
||||
// page_idx_[i] = within-page offset when kPageBlockSize >= kN0 (SRD rebased to page base)
|
||||
// page_idx_[i] = page_base + within-page offset when kPageBlockSize < kN0 (full voffset)
|
||||
// Global load mode (kUseGlobalLoad=true): page_idx_[i] = within-page offset only.
|
||||
// Full address = base + physical_pages_[i] * page_stride_elements_ + page_idx_[i] + coord
|
||||
PageIdxArray page_idx_;
|
||||
|
||||
// Physical page indices for global load mode (kUseGlobalLoad=true only).
|
||||
// Maps each gather element to its physical page in a paged memory pool.
|
||||
// Updated via update_physical_pages() before each load call.
|
||||
// SRD mode: collapsed to gl_field_empty_t so the storage disappears.
|
||||
[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, PageIdxArray, gl_field_empty_t>
|
||||
physical_pages_;
|
||||
|
||||
// Page stride in elements for global load mode (kUseGlobalLoad=true only).
|
||||
// physical_pages_[i] * page_stride_elements_ gives the page base offset in elements.
|
||||
// Set at construction time via the make_tile_scatter_gather overload that
|
||||
// takes bool_constant<kUseGlobalLoad>; immutable thereafter.
|
||||
// SRD mode: collapsed to gl_field_empty_t so the storage disappears.
|
||||
[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, index_t, gl_field_empty_t>
|
||||
page_stride_elements_;
|
||||
|
||||
ValidArray valids_;
|
||||
|
||||
// this contains:
|
||||
@@ -1178,7 +1292,8 @@ template <typename TensorView_,
|
||||
typename StaticPageIndexArray_,
|
||||
index_t HsGatherDim,
|
||||
index_t NumCoord,
|
||||
index_t... YsGatherDims>
|
||||
index_t... YsGatherDims,
|
||||
bool UseGlobalLoad = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
@@ -1187,7 +1302,9 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const StaticPageIndexArray_& page_idx,
|
||||
number<HsGatherDim>,
|
||||
number<NumCoord>,
|
||||
sequence<YsGatherDims...>)
|
||||
sequence<YsGatherDims...>,
|
||||
bool_constant<UseGlobalLoad> = {},
|
||||
index_t page_stride_elements = 0)
|
||||
{
|
||||
return tile_scatter_gather<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
@@ -1196,11 +1313,17 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
std::nullptr_t,
|
||||
HsGatherDim,
|
||||
NumCoord,
|
||||
sequence<YsGatherDims...>>{
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
|
||||
sequence<YsGatherDims...>,
|
||||
UseGlobalLoad>{tensor_view,
|
||||
window_lengths,
|
||||
origin,
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
nullptr,
|
||||
page_stride_elements};
|
||||
}
|
||||
|
||||
// Legacy overload (compatible with original API)
|
||||
// Legacy overload (compatible with original API, kUseGlobalLoad=false)
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
@@ -1227,6 +1350,42 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
|
||||
}
|
||||
|
||||
// Overload with kUseGlobalLoad (simple, used by K cache).
|
||||
// page_stride_elements is forwarded to the constructor; required (non-zero)
|
||||
// when UseGlobalLoad=true so that physical_pages_[i] * page_stride_elements_
|
||||
// produces a valid address. Defaulting to 0 keeps SRD-mode call sites unchanged
|
||||
// (page_stride_elements_ is unread in SRD mode).
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
bool UseGlobalLoad>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
const StaticPageIndexArray_& page_idx,
|
||||
bool_constant<UseGlobalLoad>,
|
||||
index_t page_stride_elements = 0)
|
||||
{
|
||||
return tile_scatter_gather<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
remove_cvref_t<StaticPageIndexArray_>,
|
||||
std::nullptr_t,
|
||||
0,
|
||||
1,
|
||||
sequence<0>,
|
||||
UseGlobalLoad>{tensor_view,
|
||||
window_lengths,
|
||||
origin,
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
nullptr,
|
||||
page_stride_elements};
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
typename WindowLengths,
|
||||
typename StaticTileDistribution,
|
||||
|
||||
@@ -12,6 +12,20 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// `always_false_v<T...>` — a value-template that is always `false` but whose
|
||||
// evaluation is deferred until template instantiation. The canonical use is
|
||||
// inside the `else` arm of an `if constexpr` chain or under an arch-gated
|
||||
// `#if` to fire a `static_assert` ONLY when the offending instantiation is
|
||||
// actually requested, e.g.:
|
||||
//
|
||||
// if constexpr (...) { ... }
|
||||
// else { static_assert(always_false_v<T>, "unsupported T"); }
|
||||
//
|
||||
// A bare `static_assert(false, ...)` would fire at template-definition
|
||||
// parse time on conforming compilers, breaking the whole TU.
|
||||
template <typename...>
|
||||
inline constexpr bool always_false_v = false;
|
||||
|
||||
// remove_cvref_t
|
||||
template <typename T>
|
||||
using remove_reference_t = typename std::remove_reference<T>::type;
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// KV cache load addressing mode selector for batch_prefill / paged-attention pipelines.
|
||||
// - BUFFER_LOAD: SGPR-based SRD via buffer_load_* (default; 32-bit byte addressing, <2GB pool)
|
||||
// - GLOBAL_LOAD_LDS: direct global_load_lds_* (64-bit addressing, required for >2GB KV cache)
|
||||
enum class BlockAttentionKVCacheLoadModeEnum
|
||||
{
|
||||
BUFFER_LOAD = 0,
|
||||
GLOBAL_LOAD_LDS = 1,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
@@ -134,7 +135,8 @@ template <typename IndexArrayType,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
|
||||
bool kIsKcache,
|
||||
index_t kN0,
|
||||
index_t kVectorSize>
|
||||
index_t kVectorSize,
|
||||
bool kUseGlobalLoad_ = false>
|
||||
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
|
||||
const index_t& stride_token,
|
||||
const index_t& stride_page_block,
|
||||
@@ -156,81 +158,65 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica
|
||||
const index_t& thread_coord_start = coord_vec[kCoordAxis];
|
||||
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
|
||||
|
||||
if constexpr(kIsKcache)
|
||||
{
|
||||
// K cache: per-token lookup
|
||||
// Each token may be on a different page, so we use physical_pages[k0] for each.
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
// Addressing strategy — four cases controlled by (kPageBlockSize vs kN0, kUseGlobalLoad_):
|
||||
//
|
||||
// Case 1: kPageBlockSize >= kN0
|
||||
// SRD is rebased per-tile to the page base (rebase_{k,v}_window in caller).
|
||||
// Page base is absorbed into the SRD's 48-bit base pointer (SGPR-resident).
|
||||
// This function writes within-page offset only.
|
||||
//
|
||||
// Case 2: kPageBlockSize < kN0 && kUseGlobalLoad_
|
||||
// SRD cannot be rebased (multi-page wave). Loads use global_load_lds_*; the full
|
||||
// 64-bit address is computed by tile_scatter_gather::load() in
|
||||
// include/ck_tile/core/tensor/tile_scatter_gather.hpp from physical_pages_ +
|
||||
// page_stride_elements_. This function writes within-page offset only.
|
||||
//
|
||||
// Case 3: kPageBlockSize < kN0 && !kUseGlobalLoad_ (kNeedFullOffset == true)
|
||||
// SRD base is the entire KV buffer; the only place to encode page identity
|
||||
// is the voffset itself. This function writes the FULL offset:
|
||||
// page * stride_page_block + within_page
|
||||
// Limited to <2GB total KV bytes by 32-bit voffset hardware width.
|
||||
//
|
||||
// Case 4: kPageBlockSize >= kN0 && kUseGlobalLoad_
|
||||
// Not emitted by codegen. Backstop static_assert in
|
||||
// BlockFmhaBatchPrefillPipelineQRKSVSAsync.
|
||||
constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseGlobalLoad_;
|
||||
|
||||
if constexpr(kPageBlockSize >= kN0)
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
|
||||
// Within-page offset (layout-dependent for V cache with VECTORIZED_LAYOUT)
|
||||
const index_t within_page = [&]() {
|
||||
if constexpr(!kIsKcache && kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// SRD rebasing mode: within-page offset only.
|
||||
// The full page base is handled by rebasing the SRD pointer.
|
||||
kv_offset_vec[k0] = token_idx_in_page * stride_token;
|
||||
return (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
|
||||
(token_idx_in_page % kVectorSize);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Full global offset (original code path for ps1, ps16, etc.)
|
||||
const index_t physical_page = physical_pages[k0];
|
||||
kv_offset_vec[k0] =
|
||||
physical_page * stride_page_block + token_idx_in_page * stride_token;
|
||||
return token_idx_in_page * stride_token;
|
||||
}
|
||||
});
|
||||
}
|
||||
else // V cache
|
||||
{
|
||||
// V cache: use physical_pages[k0] for each token
|
||||
// physical_pages was already populated correctly by load_physical_pages(), handling:
|
||||
// - page_size=1: page_idx maps token_idx -> physical_page directly
|
||||
// - V tile crosses pages: per-token page lookup
|
||||
// - V tile in single page: lane0 lookup with broadcast to all lanes
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
}();
|
||||
|
||||
if constexpr(kPageBlockSize >= kN0)
|
||||
{
|
||||
// SRD rebasing mode: within-page offset only.
|
||||
// The full page base is handled by rebasing the SRD pointer.
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
const index_t token_offset =
|
||||
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
|
||||
(token_idx_in_page % kVectorSize);
|
||||
kv_offset_vec[k0] = token_offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
kv_offset_vec[k0] = token_idx_in_page * stride_token;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Full global offset (original code path for ps1, ps16, etc.)
|
||||
const index_t physical_page = physical_pages[k0];
|
||||
const long_index_t page_base_offset =
|
||||
static_cast<long_index_t>(physical_page) * stride_page_block;
|
||||
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
const index_t token_offset =
|
||||
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
|
||||
(token_idx_in_page % kVectorSize);
|
||||
kv_offset_vec[k0] = page_base_offset + token_offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
// SRD + page_size < kN0: add page base to form complete voffset for buffer_load.
|
||||
//
|
||||
// 32-bit by hardware: SRD buffer_load voffset is fundamentally 32-bit (CDNA3 MUBUF
|
||||
// microcode format), so this branch is only reachable when total KV bytes fit in
|
||||
// INT32_MAX. The kUseGlobalLoad_ template path handles the >2GB case via 64-bit
|
||||
// global_load_lds_*; widening kv_offset_vec here would not lift the 2GB ceiling
|
||||
// because the hardware truncates voffset regardless.
|
||||
if constexpr(kNeedFullOffset)
|
||||
{
|
||||
kv_offset_vec[k0] = physical_pages[k0] * stride_page_block + within_page;
|
||||
}
|
||||
else
|
||||
{
|
||||
kv_offset_vec[k0] = within_page;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
|
||||
@@ -270,10 +256,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
|
||||
static constexpr index_t kVectorSize = Problem::kVectorSize;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
// Single load-mode selector for the whole pipeline. GLOBAL_LOAD_LDS routes K/V
|
||||
// tiles through global_load_lds_* (handles >2GB KV cache); BUFFER_LOAD uses SRD
|
||||
// buffer_load_*. The enum is named at the trait/Problem level; internally we
|
||||
// derive a bool helper to keep `if constexpr` sites narrow. Codegen only emits
|
||||
// GLOBAL_LOAD_LDS arms when page_size < kN0; the static_assert is a backstop.
|
||||
static constexpr auto kKVLoadMode = Problem::kKVLoadMode;
|
||||
static constexpr bool kUseGlobalLoad =
|
||||
(kKVLoadMode == BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS);
|
||||
static_assert(!kUseGlobalLoad || (kPageBlockSize < kN0),
|
||||
"GLOBAL_LOAD_LDS load mode is only valid when kPageBlockSize < kN0; "
|
||||
"codegen should not emit this instantiation otherwise.");
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
@@ -626,19 +623,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
kVectorSize,
|
||||
kUseGlobalLoad>(
|
||||
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
|
||||
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
k_dist,
|
||||
k_offsets); // K DRAM tile window for
|
||||
k_offsets,
|
||||
bool_constant<kUseGlobalLoad>{},
|
||||
page_stride_k);
|
||||
if constexpr(kUseGlobalLoad)
|
||||
{
|
||||
k_dram_window.update_physical_pages(k_physical_pages);
|
||||
}
|
||||
k_dram_window.init_raw();
|
||||
|
||||
// SRD rebasing: move the buffer descriptor base pointer to each page's start address
|
||||
// using 48-bit pointer arithmetic, so voffset only needs the small within-page offset.
|
||||
// Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page).
|
||||
// SRD rebasing for K: only for page_size >= kN0 (all threads on same page).
|
||||
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
|
||||
// addressing.
|
||||
auto rebase_k_window = [&](auto& window, index_t physical_page) {
|
||||
if constexpr(kPageBlockSize >= kN0)
|
||||
{
|
||||
@@ -649,24 +653,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const auto* page_ptr =
|
||||
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_k;
|
||||
window.set_bottom_tensor_view_data_ptr(page_ptr);
|
||||
// Limit SRD num_records to one page worth of elements.
|
||||
// Without this, the SRD claims validity for [page_ptr, page_ptr +
|
||||
// full_buffer_size), which extends far beyond the allocated buffer when rebased to
|
||||
// high pages. On gfx950, the hardware may validate the full SRD range against page
|
||||
// table permissions, causing faults on freed/protected memory beyond the buffer.
|
||||
window.set_bottom_tensor_view_buffer_size(page_stride_k);
|
||||
window.init_raw();
|
||||
}
|
||||
};
|
||||
|
||||
// SRD rebasing for V: only for page_size >= kN0 (all threads on same page).
|
||||
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
|
||||
// addressing.
|
||||
auto rebase_v_window = [&](auto& window, index_t physical_page) {
|
||||
if constexpr(kPageBlockSize >= kN0)
|
||||
{
|
||||
// readfirstlane: make physical_page provably wave-uniform so the
|
||||
// resulting SRD lands in SGPRs (required by buffer load instructions).
|
||||
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
|
||||
const auto* base_ptr =
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_;
|
||||
const auto* page_ptr =
|
||||
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_v;
|
||||
window.set_bottom_tensor_view_data_ptr(page_ptr);
|
||||
window.set_bottom_tensor_view_buffer_size(page_stride_v);
|
||||
window.init_raw();
|
||||
}
|
||||
};
|
||||
|
||||
// Initial K SRD rebase
|
||||
// Initial K SRD rebase (no-op for page_size < kN0, uses flat loads instead)
|
||||
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
|
||||
|
||||
constexpr auto k_oob_ck = bool_constant<true>{};
|
||||
@@ -874,12 +890,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kN0,
|
||||
kVectorSize>(v_physical_pages_k2,
|
||||
stride_v,
|
||||
page_stride_v,
|
||||
v_coord,
|
||||
v_offsets_k2,
|
||||
current_seq_k);
|
||||
kVectorSize,
|
||||
kUseGlobalLoad>(v_physical_pages_k2,
|
||||
stride_v,
|
||||
page_stride_v,
|
||||
v_coord,
|
||||
v_offsets_k2,
|
||||
current_seq_k);
|
||||
|
||||
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
|
||||
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
|
||||
@@ -899,9 +916,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
kVectorSize,
|
||||
kUseGlobalLoad>(
|
||||
v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
}
|
||||
|
||||
// v_offsets semantics — see the four-case addressing-strategy block above
|
||||
// kNeedFullOffset in kv_offset_array_transform. Three cases reach this lambda:
|
||||
// Case 1 (kPageBlockSize >= kN0): within-page offset; page base in SRD.
|
||||
// Case 2 (page_size < kN0, kUseGlobalLoad): within-page offset; page base computed
|
||||
// by tile_scatter_gather::load() from
|
||||
// physical_pages_.
|
||||
// Case 3 (page_size < kN0, !kUseGlobalLoad, == kNeedFullOffset):
|
||||
// FULL offset (page * stride + within),
|
||||
// carried in the 32-bit voffset (<2GB cap).
|
||||
};
|
||||
|
||||
// Prefetch V physical pages early to hide buffer load latency
|
||||
@@ -915,11 +943,32 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
v_offsets,
|
||||
number<1>{}, // HsGatherDim
|
||||
number<1>{}, // NumCoord
|
||||
VPageIndexYDims);
|
||||
VPageIndexYDims,
|
||||
bool_constant<kUseGlobalLoad>{},
|
||||
page_stride_v);
|
||||
if constexpr(kUseGlobalLoad)
|
||||
{
|
||||
v_dram_window.update_physical_pages(v_physical_pages);
|
||||
}
|
||||
|
||||
// Initial V SRD rebase
|
||||
// Initial V SRD rebase. Single source of truth: rebase_v_window's own
|
||||
// `if constexpr(kPageBlockSize >= kN0)` makes this a no-op for case 2/3.
|
||||
// Do not re-add an outer guard here — it would duplicate the inner check
|
||||
// and drift if the lambda's gating condition ever changes.
|
||||
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
|
||||
|
||||
// Save the *current* tile's V physical pages into v_dram_window before
|
||||
// prefetch_v_physical_pages overwrites the v_physical_pages buffer with the
|
||||
// *next* tile's pages. Case-2 only (kUseGlobalLoad); case-1/3 don't read
|
||||
// physical_pages_ from the window. Encapsulating the save+prefetch pair
|
||||
// here makes the ordering invariant unmissable when a fourth prefetch site
|
||||
// is added later.
|
||||
auto save_and_prefetch_v_pages = [&](auto k_loop_start) {
|
||||
if constexpr(kUseGlobalLoad)
|
||||
v_dram_window.update_physical_pages(v_physical_pages);
|
||||
prefetch_v_physical_pages(k_loop_start);
|
||||
};
|
||||
|
||||
// prefetch K tile
|
||||
async_load_tile_raw(
|
||||
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
|
||||
@@ -972,7 +1021,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
}
|
||||
|
||||
// Prefetch V physical pages early - overlaps with GEMM0 computation
|
||||
prefetch_v_physical_pages(number<kK1>{});
|
||||
save_and_prefetch_v_pages(number<kK1>{});
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
@@ -1166,7 +1215,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// Prefetch V physical pages early - overlaps with softmax computation
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
prefetch_v_physical_pages(number<2 * kK1>{});
|
||||
save_and_prefetch_v_pages(number<2 * kK1>{});
|
||||
}
|
||||
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
@@ -1220,8 +1269,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
v_dram_window,
|
||||
{0,
|
||||
kK1}); // will have scratch if move this right after load_tile(v_dram)...
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
update_v_offsets(number<2 * kK1>{});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
|
||||
@@ -1390,8 +1438,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
|
||||
{
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
// Update V offsets using previously prefetched physical pages
|
||||
update_v_offsets(number<(2 + i_k1.value) * kK1>{});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
@@ -1401,7 +1448,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// Prefetch V physical pages for NEXT iteration - overlaps with GEMM1
|
||||
if constexpr(i_k1 + 1 < k1_loops - 1)
|
||||
{
|
||||
prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{});
|
||||
save_and_prefetch_v_pages(number<(2 + i_k1.value + 1) * kK1>{});
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
@@ -1481,9 +1528,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
kVectorSize,
|
||||
kUseGlobalLoad>(
|
||||
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
k_dram_window.update_page_idx(k_offsets);
|
||||
if constexpr(kUseGlobalLoad)
|
||||
k_dram_window.update_physical_pages(k_physical_pages);
|
||||
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
|
||||
|
||||
// After sink→window transition (i_total_loops == num_sink_loop), V window
|
||||
|
||||
@@ -117,6 +117,12 @@ struct BlockFmhaBatchPrefillPipelineProblem
|
||||
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
|
||||
"kPageBlockSize must be power of two");
|
||||
|
||||
// KV cache load addressing mode. GLOBAL_LOAD_LDS handles >2GB pools via
|
||||
// 64-bit addressing; BUFFER_LOAD (default) uses SRD buffer_load for the
|
||||
// <2GB fast path. The 2GB bound = INT32_MAX byte offset, matching CK's
|
||||
// existing TwoGB convention.
|
||||
static constexpr auto kKVLoadMode = Traits_::kKVLoadMode;
|
||||
|
||||
static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
|
||||
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
|
||||
static constexpr auto kKVLookupTable = Traits_::kKVLookupTable;
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
@@ -58,7 +59,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
|
||||
BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
|
||||
BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
|
||||
struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
|
||||
kPadSeqLenK_,
|
||||
kPadHeadDimQ_,
|
||||
@@ -76,6 +79,7 @@ struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
static constexpr index_t kPageBlockSize = kPageBlockSize_;
|
||||
static constexpr auto kKVLoadMode = kKVLoadMode_;
|
||||
static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT ||
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT,
|
||||
"Batch prefill only supports vectorized or linear KV cache layout.");
|
||||
|
||||
Reference in New Issue
Block a user