use statically_indexed_array instead of c-style array.

This commit is contained in:
Jeff Huang
2026-02-03 09:00:42 +08:00
parent e1af9b7afb
commit 4933100b0f

View File

@@ -21,7 +21,8 @@ namespace ck_tile {
// - Crosses pages: per-token lookup
// - Single page: lane0 lookup once, broadcast to all
// Output: physical_pages array with kLoopCount elements
template <typename CoordVecType,
template <typename IndexArrayType,
typename CoordVecType,
index_t kCoordAxis,
index_t kPageBlockSize,
index_t kLoopStart,
@@ -33,7 +34,7 @@ template <typename CoordVecType,
CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
const CoordVecType& coord_vec,
index_t global_seq_offset,
index_t (&physical_pages)[kLoopCount])
IndexArrayType& physical_pages)
{
static constexpr index_t kLog2PageSize = [] {
index_t shift = 0;
@@ -54,8 +55,8 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t page_id = global_token_idx >> kLog2PageSize;
physical_pages[k0.value] = page_idx[page_id];
const index_t page_id = global_token_idx >> kLog2PageSize;
physical_pages[k0] = page_idx[page_id];
});
}
else
@@ -73,7 +74,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
physical_pages[k0.value] = page_idx[global_token_idx];
physical_pages[k0] = page_idx[global_token_idx];
});
}
else if constexpr(kVTileCrossesPages)
@@ -83,8 +84,8 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t page_id = global_token_idx >> kLog2PageSize;
physical_pages[k0.value] = page_idx[page_id];
const index_t page_id = global_token_idx >> kLog2PageSize;
physical_pages[k0] = page_idx[page_id];
});
}
else
@@ -96,7 +97,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
const index_t shared_physical_page = page_idx[lane0_page_id];
static_for<0, kLoopCount, 1>{}(
[&](auto k0) { physical_pages[k0.value] = shared_physical_page; });
[&](auto k0) { physical_pages[k0] = shared_physical_page; });
}
}
}
@@ -123,7 +124,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
// LINEAR_LAYOUT: [page, token_in_page, head_dim]
// VECTORIZED_LAYOUT: [page, token_in_page/kVectorSize, head_dim, kVectorSize]
//
template <typename OffsetVecType,
template <typename IndexArrayType,
typename CoordVecType,
index_t kCoordAxis,
index_t kPageBlockSize,
@@ -134,11 +135,11 @@ template <typename OffsetVecType,
bool kIsKcache,
index_t kN0,
index_t kVectorSize>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_pages)[kLoopCount],
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
const index_t& stride_token,
const index_t& stride_page_block,
const CoordVecType& coord_vec,
OffsetVecType& kv_offset_vec,
IndexArrayType& kv_offset_vec,
index_t global_seq_offset = 0)
{
static constexpr index_t kLog2PageSize = [] {
@@ -164,7 +165,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_page
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
const index_t physical_page = physical_pages[k0.value];
const index_t physical_page = physical_pages[k0];
kv_offset_vec[k0] = static_cast<long_index_t>(physical_page) * stride_page_block +
static_cast<long_index_t>(token_idx_in_page) * stride_token;
@@ -181,7 +182,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_page
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
const index_t physical_page = physical_pages[k0.value];
const index_t physical_page = physical_pages[k0];
const long_index_t page_base_offset =
static_cast<long_index_t>(physical_page) * stride_page_block;
@@ -574,8 +575,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Load physical pages first, then compute offsets.
// k_physical_pages can be reused for descale lookup later.
index_t k_physical_pages[NRepeat] = {};
load_physical_pages<decltype(k_coord),
statically_indexed_array<index_t, NRepeat> k_physical_pages{};
load_physical_pages<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
kPageBlockSize,
0,
@@ -737,7 +739,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
statically_indexed_array<index_t, V_PageIdxRepeat> v_offsets;
// V physical pages array for use with kv_offset_array_transform
// For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner
index_t v_physical_pages[V_PageIdxRepeat] = {};
statically_indexed_array<index_t, V_PageIdxRepeat> v_physical_pages{};
// Prefetch V physical pages - can be called early to hide buffer load latency
auto prefetch_v_physical_pages = [&](auto k_loop_start) {
@@ -746,8 +748,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
static_for<0, V_KIterOuter, 1>{}([&](auto k2) {
// Load physical pages for this k2 slice into the appropriate portion of array
index_t v_physical_pages_k2[V_KIterInner] = {};
load_physical_pages<decltype(v_coord),
statically_indexed_array<index_t, V_KIterInner> v_physical_pages_k2{};
load_physical_pages<statically_indexed_array<index_t, V_KIterInner>,
decltype(v_coord),
I1,
kPageBlockSize,
kLoopStart + k2.value * V_KLanes * V_KIterInner,
@@ -759,14 +762,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Copy to merged array
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
constexpr auto idx = k1.value + k2.value * V_KIterInner;
v_physical_pages[idx] = v_physical_pages_k2[k1.value];
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
v_physical_pages[idx] = v_physical_pages_k2[k1];
});
});
}
else
{
load_physical_pages<decltype(v_coord),
load_physical_pages<statically_indexed_array<index_t, V_KIterInner>,
decltype(v_coord),
I1,
kPageBlockSize,
kLoopStart,
@@ -789,10 +793,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static_for<0, V_KIterOuter, 1>{}([&](auto k2) {
statically_indexed_array<index_t, V_KIterInner> v_offsets_k2;
// Extract physical pages for this k2 slice
index_t v_physical_pages_k2[V_KIterInner];
statically_indexed_array<index_t, V_KIterInner> v_physical_pages_k2;
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
constexpr auto idx = k1.value + k2.value * V_KIterInner;
v_physical_pages_k2[k1.value] = v_physical_pages[idx];
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
v_physical_pages_k2[k1] = v_physical_pages[idx];
});
kv_offset_array_transform<statically_indexed_array<index_t, V_KIterInner>,
@@ -893,7 +897,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
{
const index_t scale_offset =
k_physical_pages[0] * kv_block_descale_stride_block +
k_physical_pages[number<0>{}] * kv_block_descale_stride_block +
block_indices.kv_head_idx * kv_block_descale_stride_head;
k_descale = kv_block_descale_ptr[scale_offset + 0 * kv_block_descale_stride_kv];
v_descale = kv_block_descale_ptr[scale_offset + 1 * kv_block_descale_stride_kv];
@@ -1342,7 +1346,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
// KV_BLOCKSCALE: reload physical pages for the new tile
load_physical_pages<decltype(k_coord),
load_physical_pages<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
kPageBlockSize,
0,