[CK_TILE] fix(fmha): support >2GB KV cache in batch prefill via template dispatch (#6653)

## Motivation

The CK batch prefill kernel previously failed (silent overflow + page
faults) when the KV cache exceeded 2 GB, blocking long-context inference
workloads (e.g., 128K+ token contexts with paged KV).

Two distinct failure modes were addressed:

1. **>2 GB SRD `voffset` overflow (`page_size < kN0`):** the SRD
`buffer_load_dwordx4` path carries byte offsets in a 32-bit `voffset`
register. With small pages the SRD is never rebased and spans the full KV
pool, so offsets into pages past the 2 GB mark wrap and corrupt K/V loads.
2. **gfx950 page-table fault (`page_size >= kN0`):** On CDNA4 the
hardware validates the **full SRD `num_records` range** against
page-table permissions (CDNA3 only checks per-instruction `voffset`).
After per-tile SRD rebase, an un-trimmed `num_records` field extends
past the live page and faults on freed/protected memory.

## Technical Details

**Two-mode `tile_scatter_gather` selected by the `kUseGlobalLoad`
template parameter:**

| Case | `page_size` | KV cache size | Mode | Load path | Addressing |
|---|---|---|---|---|---|
| 1 | `>= kN0` (large pages) | any | SRD (`kUseGlobalLoad=false`) | `buffer_load_dwordx4` | 32-bit `voffset`, bounded by per-page rebase |
| 2 | `< kN0` (small pages) | `<= 2 GB` | SRD (`kUseGlobalLoad=false`) | `buffer_load_dwordx4` | 32-bit `voffset`, fits in INT32 byte range |
| 3 | `< kN0` (small pages) | `> 2 GB` | Global-load (`kUseGlobalLoad=true`) | `async_load_tile_raw_flat` (K) + `load_tile_flat` (V) | 64-bit |

**Dispatch:** the auto-gen API layer (`fmha_batch_prefill.py`) selects
the kernel instantiation at launch from `(page_block_size,
num_total_pages * batch_stride_k * kElementBytes)`, so the small-page
penalty is paid only when correctness requires it.

**gfx950 SRD `num_records` trimming:** in the K and V rebase lambdas of
`block_fmha_batch_prefill_pipeline_qr_ks_vs_async`,
`set_bottom_tensor_view_buffer_size(page_stride_k/v)` is called after
each rebase to constrain `num_records` to the live page. Required for
CDNA4 page-table validation; harmless on CDNA3.

**Pipeline sync for the global-load path:**
- V uses synchronous `load_tile_flat`; K uses
`async_load_tile_raw_flat`.
- `v_physical_pages_current` is double-buffered so the V flat load
doesn't race against the next iteration's K rebase computation.

**Arch guards:** `global_load_lds` intrinsics are gated to `__gfx94__` /
`__gfx950__` (CDNA3+). Other architectures hit a `dependent_false`
static_assert with a descriptive message.

**Device-side assertion convention:** SRD setters use
`__builtin_assume(cond)` (hint-only) rather than `<cassert>`'s
`assert()`. The latter introduces an `__assert_fail` call whose register
pressure scatters the K-SRD scalar register window across conditional
branches, corrupting `buffer_load_dwordx4` on gfx950.


## Test Plan

Tested on both MI308 (gfx942) and MI355 (gfx950) via the aiter wrapper
test suite. All coverage lives in **`op_tests/test_batch_prefill.py`**:

- **Functional matrix (96 cases)** — `test_batch_prefill`: `page_size ∈
{1, 16, 1024}` × `kv_layout ∈ {linear, vectorized}` × `dtype ∈ {bf16,
fp8 quant variants}` × `causal` × `soft_cap` × `LSE` × `batch_size ∈ {1,
4}` (parametrized to exercise per-sequence SRD rebase across batch
boundaries).
- **>2 GB coverage** — `test_batch_prefill_large_kvcache`: extended to
allocate a 5 GB+ KV cache pool and exercise both `kUseGlobalLoad=true`
(small-page) and `kUseGlobalLoad=false` (large-page rebase) paths.
Includes both single-batch and multi-batch (`batch_size=4`) cases to
exercise per-sequence SRD rebase across the >2 GB pool.
- Numerical reference: PyTorch SDPA, per-batch loop with `atol` / `rtol`
from the existing batch prefill test harness.

## Test Result

| Arch | `test_batch_prefill` | `test_batch_prefill_large_kvcache` (>2 GB) |
|------|----------------------|---------------------|
| MI308 (gfx942) | All passed | Passed |
| MI355 (gfx950) | All passed | Passed |

**Performance impact (gfx950, hot SRD path):**
- +2.67% kernel-time on `seqlen=1024 / page_sz=1024 / bf16 / sglang /
causal / soft_cap=30`, attributable in full to the two
`set_bottom_tensor_view_buffer_size` calls in the K/V rebase lambdas
(5-run median, signal/noise ≈ 9×).
- This cost is **mandatory for gfx950 correctness** on >2 GB workloads;
removing the setters re-introduces the page faults.
- gfx942: no regressions across the same sweep (all configs ≤ +0.97%).

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Jeff Huang
2026-04-24 07:08:41 +08:00
committed by GitHub
parent d7475e8125
commit b3d45b6fdb
10 changed files with 540 additions and 116 deletions

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"

View File

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// KV cache load addressing mode selector for batch_prefill / paged-attention pipelines.
// - BUFFER_LOAD: SGPR-based SRD via buffer_load_* (default; 32-bit byte addressing, <2GB pool)
// - GLOBAL_LOAD_LDS: direct global_load_lds_* (64-bit addressing, required for >2GB KV cache)
enum class BlockAttentionKVCacheLoadModeEnum
{
BUFFER_LOAD = 0,
GLOBAL_LOAD_LDS = 1,
};
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
@@ -134,7 +135,8 @@ template <typename IndexArrayType,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
bool kIsKcache,
index_t kN0,
index_t kVectorSize>
index_t kVectorSize,
bool kUseGlobalLoad_ = false>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
const index_t& stride_token,
const index_t& stride_page_block,
@@ -156,81 +158,65 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica
const index_t& thread_coord_start = coord_vec[kCoordAxis];
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
if constexpr(kIsKcache)
{
// K cache: per-token lookup
// Each token may be on a different page, so we use physical_pages[k0] for each.
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
// Addressing strategy — four cases controlled by (kPageBlockSize vs kN0, kUseGlobalLoad_):
//
// Case 1: kPageBlockSize >= kN0
// SRD is rebased per-tile to the page base (rebase_{k,v}_window in caller).
// Page base is absorbed into the SRD's 48-bit base pointer (SGPR-resident).
// This function writes within-page offset only.
//
// Case 2: kPageBlockSize < kN0 && kUseGlobalLoad_
// SRD cannot be rebased (multi-page wave). Loads use global_load_lds_*; the full
// 64-bit address is computed by tile_scatter_gather::load() in
// include/ck_tile/core/tensor/tile_scatter_gather.hpp from physical_pages_ +
// page_stride_elements_. This function writes within-page offset only.
//
// Case 3: kPageBlockSize < kN0 && !kUseGlobalLoad_ (kNeedFullOffset == true)
// SRD base is the entire KV buffer; the only place to encode page identity
// is the voffset itself. This function writes the FULL offset:
// page * stride_page_block + within_page
// Limited to <2GB total KV bytes by 32-bit voffset hardware width.
//
// Case 4: kPageBlockSize >= kN0 && kUseGlobalLoad_
// Not emitted by codegen. Backstop static_assert in
// BlockFmhaBatchPrefillPipelineQRKSVSAsync.
constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseGlobalLoad_;
if constexpr(kPageBlockSize >= kN0)
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
// Within-page offset (layout-dependent for V cache with VECTORIZED_LAYOUT)
const index_t within_page = [&]() {
if constexpr(!kIsKcache && kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// SRD rebasing mode: within-page offset only.
// The full page base is handled by rebasing the SRD pointer.
kv_offset_vec[k0] = token_idx_in_page * stride_token;
return (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
}
else
{
// Full global offset (original code path for ps1, ps16, etc.)
const index_t physical_page = physical_pages[k0];
kv_offset_vec[k0] =
physical_page * stride_page_block + token_idx_in_page * stride_token;
return token_idx_in_page * stride_token;
}
});
}
else // V cache
{
// V cache: use physical_pages[k0] for each token
// physical_pages was already populated correctly by load_physical_pages(), handling:
// - page_size=1: page_idx maps token_idx -> physical_page directly
// - V tile crosses pages: per-token page lookup
// - V tile in single page: lane0 lookup with broadcast to all lanes
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
}();
if constexpr(kPageBlockSize >= kN0)
{
// SRD rebasing mode: within-page offset only.
// The full page base is handled by rebasing the SRD pointer.
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const index_t token_offset =
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = token_offset;
}
else
{
kv_offset_vec[k0] = token_idx_in_page * stride_token;
}
}
else
{
// Full global offset (original code path for ps1, ps16, etc.)
const index_t physical_page = physical_pages[k0];
const long_index_t page_base_offset =
static_cast<long_index_t>(physical_page) * stride_page_block;
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const index_t token_offset =
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = page_base_offset + token_offset;
}
else
{
kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token;
}
}
});
}
// SRD + page_size < kN0: add page base to form complete voffset for buffer_load.
//
// 32-bit by hardware: SRD buffer_load voffset is fundamentally 32-bit (CDNA3 MUBUF
// microcode format), so this branch is only reachable when total KV bytes fit in
// INT32_MAX. The kUseGlobalLoad_ template path handles the >2GB case via 64-bit
// global_load_lds_*; widening kv_offset_vec here would not lift the 2GB ceiling
// because the hardware truncates voffset regardless.
if constexpr(kNeedFullOffset)
{
kv_offset_vec[k0] = physical_pages[k0] * stride_page_block + within_page;
}
else
{
kv_offset_vec[k0] = within_page;
}
});
}
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
@@ -270,10 +256,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
static constexpr index_t kVectorSize = Problem::kVectorSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
// Single load-mode selector for the whole pipeline. GLOBAL_LOAD_LDS routes K/V
// tiles through global_load_lds_* (handles >2GB KV cache); BUFFER_LOAD uses SRD
// buffer_load_*. The enum is named at the trait/Problem level; internally we
// derive a bool helper to keep `if constexpr` sites narrow. Codegen only emits
// GLOBAL_LOAD_LDS arms when page_size < kN0; the static_assert is a backstop.
static constexpr auto kKVLoadMode = Problem::kKVLoadMode;
static constexpr bool kUseGlobalLoad =
(kKVLoadMode == BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS);
static_assert(!kUseGlobalLoad || (kPageBlockSize < kN0),
"GLOBAL_LOAD_LDS load mode is only valid when kPageBlockSize < kN0; "
"codegen should not emit this instantiation otherwise.");
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
@@ -626,19 +623,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
true,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
k_dist,
k_offsets); // K DRAM tile window for
k_offsets,
bool_constant<kUseGlobalLoad>{},
page_stride_k);
if constexpr(kUseGlobalLoad)
{
k_dram_window.update_physical_pages(k_physical_pages);
}
k_dram_window.init_raw();
// SRD rebasing: move the buffer descriptor base pointer to each page's start address
// using 48-bit pointer arithmetic, so voffset only needs the small within-page offset.
// Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page).
// SRD rebasing for K: only for page_size >= kN0 (all threads on same page).
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
// addressing.
auto rebase_k_window = [&](auto& window, index_t physical_page) {
if constexpr(kPageBlockSize >= kN0)
{
@@ -649,24 +653,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const auto* page_ptr =
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_k;
window.set_bottom_tensor_view_data_ptr(page_ptr);
// Limit SRD num_records to one page worth of elements.
// Without this, the SRD claims validity for [page_ptr, page_ptr +
// full_buffer_size), which extends far beyond the allocated buffer when rebased to
// high pages. On gfx950, the hardware may validate the full SRD range against page
// table permissions, causing faults on freed/protected memory beyond the buffer.
window.set_bottom_tensor_view_buffer_size(page_stride_k);
window.init_raw();
}
};
// SRD rebasing for V: only for page_size >= kN0 (all threads on same page).
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
// addressing.
auto rebase_v_window = [&](auto& window, index_t physical_page) {
if constexpr(kPageBlockSize >= kN0)
{
// readfirstlane: make physical_page provably wave-uniform so the
// resulting SRD lands in SGPRs (required by buffer load instructions).
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
const auto* base_ptr =
v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_;
const auto* page_ptr =
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_v;
window.set_bottom_tensor_view_data_ptr(page_ptr);
window.set_bottom_tensor_view_buffer_size(page_stride_v);
window.init_raw();
}
};
// Initial K SRD rebase
// Initial K SRD rebase (no-op for page_size < kN0, uses flat loads instead)
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
constexpr auto k_oob_ck = bool_constant<true>{};
@@ -874,12 +890,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
false,
kN0,
kVectorSize>(v_physical_pages_k2,
stride_v,
page_stride_v,
v_coord,
v_offsets_k2,
current_seq_k);
kVectorSize,
kUseGlobalLoad>(v_physical_pages_k2,
stride_v,
page_stride_v,
v_coord,
v_offsets_k2,
current_seq_k);
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
@@ -899,9 +916,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
}
// v_offsets semantics — see the four-case addressing-strategy block above
// kNeedFullOffset in kv_offset_array_transform. Three cases reach this lambda:
// Case 1 (kPageBlockSize >= kN0): within-page offset; page base in SRD.
// Case 2 (page_size < kN0, kUseGlobalLoad): within-page offset; page base computed
// by tile_scatter_gather::load() from
// physical_pages_.
// Case 3 (page_size < kN0, !kUseGlobalLoad, == kNeedFullOffset):
// FULL offset (page * stride + within),
// carried in the 32-bit voffset (<2GB cap).
};
// Prefetch V physical pages early to hide buffer load latency
@@ -915,11 +943,32 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
v_offsets,
number<1>{}, // HsGatherDim
number<1>{}, // NumCoord
VPageIndexYDims);
VPageIndexYDims,
bool_constant<kUseGlobalLoad>{},
page_stride_v);
if constexpr(kUseGlobalLoad)
{
v_dram_window.update_physical_pages(v_physical_pages);
}
// Initial V SRD rebase
// Initial V SRD rebase. Single source of truth: rebase_v_window's own
// `if constexpr(kPageBlockSize >= kN0)` makes this a no-op for case 2/3.
// Do not re-add an outer guard here — it would duplicate the inner check
// and drift if the lambda's gating condition ever changes.
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
// Save the *current* tile's V physical pages into v_dram_window before
// prefetch_v_physical_pages overwrites the v_physical_pages buffer with the
// *next* tile's pages. Case-2 only (kUseGlobalLoad); case-1/3 don't read
// physical_pages_ from the window. Encapsulating the save+prefetch pair
// here makes the ordering invariant unmissable when a fourth prefetch site
// is added later.
auto save_and_prefetch_v_pages = [&](auto k_loop_start) {
if constexpr(kUseGlobalLoad)
v_dram_window.update_physical_pages(v_physical_pages);
prefetch_v_physical_pages(k_loop_start);
};
// prefetch K tile
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
@@ -972,7 +1021,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}
// Prefetch V physical pages early - overlaps with GEMM0 computation
prefetch_v_physical_pages(number<kK1>{});
save_and_prefetch_v_pages(number<kK1>{});
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
@@ -1166,7 +1215,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Prefetch V physical pages early - overlaps with softmax computation
if constexpr(k1_loops > 1)
{
prefetch_v_physical_pages(number<2 * kK1>{});
save_and_prefetch_v_pages(number<2 * kK1>{});
}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
@@ -1220,8 +1269,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
v_dram_window,
{0,
kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
update_v_offsets(number<2 * kK1>{});
v_dram_window.update_page_idx(v_offsets);
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
@@ -1390,8 +1438,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
// Update V offsets using previously prefetched physical pages
update_v_offsets(number<(2 + i_k1.value) * kK1>{});
v_dram_window.update_page_idx(v_offsets);
@@ -1401,7 +1448,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Prefetch V physical pages for NEXT iteration - overlaps with GEMM1
if constexpr(i_k1 + 1 < k1_loops - 1)
{
prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{});
save_and_prefetch_v_pages(number<(2 + i_k1.value + 1) * kK1>{});
}
block_sync_lds();
@@ -1481,9 +1528,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
true,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);
if constexpr(kUseGlobalLoad)
k_dram_window.update_physical_pages(k_physical_pages);
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
// After sink→window transition (i_total_loops == num_sink_loop), V window

View File

@@ -117,6 +117,12 @@ struct BlockFmhaBatchPrefillPipelineProblem
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
"kPageBlockSize must be power of two");
// KV cache load addressing mode. GLOBAL_LOAD_LDS handles >2GB pools via
// 64-bit addressing; BUFFER_LOAD (default) uses SRD buffer_load for the
// <2GB fast path. The 2GB bound = INT32_MAX byte offset, matching CK's
// existing TwoGB convention.
static constexpr auto kKVLoadMode = Traits_::kKVLoadMode;
static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
static constexpr auto kKVLookupTable = Traits_::kKVLookupTable;

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
@@ -58,7 +59,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
kPadSeqLenK_,
kPadHeadDimQ_,
@@ -76,6 +79,7 @@ struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;
static constexpr index_t kPageBlockSize = kPageBlockSize_;
static constexpr auto kKVLoadMode = kKVLoadMode_;
static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT ||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT,
"Batch prefill only supports vectorized or linear KV cache layout.");