mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
use statically_indexed_array instead of c-style array.
This commit is contained in:
@@ -21,7 +21,8 @@ namespace ck_tile {
|
||||
// - Crosses pages: per-token lookup
|
||||
// - Single page: lane0 lookup once, broadcast to all
|
||||
// Output: physical_pages array with kLoopCount elements
|
||||
template <typename CoordVecType,
|
||||
template <typename IndexArrayType,
|
||||
typename CoordVecType,
|
||||
index_t kCoordAxis,
|
||||
index_t kPageBlockSize,
|
||||
index_t kLoopStart,
|
||||
@@ -33,7 +34,7 @@ template <typename CoordVecType,
|
||||
CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
const CoordVecType& coord_vec,
|
||||
index_t global_seq_offset,
|
||||
index_t (&physical_pages)[kLoopCount])
|
||||
IndexArrayType& physical_pages)
|
||||
{
|
||||
static constexpr index_t kLog2PageSize = [] {
|
||||
index_t shift = 0;
|
||||
@@ -54,8 +55,8 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
physical_pages[k0.value] = page_idx[page_id];
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
physical_pages[k0] = page_idx[page_id];
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -73,7 +74,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
physical_pages[k0.value] = page_idx[global_token_idx];
|
||||
physical_pages[k0] = page_idx[global_token_idx];
|
||||
});
|
||||
}
|
||||
else if constexpr(kVTileCrossesPages)
|
||||
@@ -83,8 +84,8 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
physical_pages[k0.value] = page_idx[page_id];
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
physical_pages[k0] = page_idx[page_id];
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -96,7 +97,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
const index_t shared_physical_page = page_idx[lane0_page_id];
|
||||
|
||||
static_for<0, kLoopCount, 1>{}(
|
||||
[&](auto k0) { physical_pages[k0.value] = shared_physical_page; });
|
||||
[&](auto k0) { physical_pages[k0] = shared_physical_page; });
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -123,7 +124,7 @@ CK_TILE_DEVICE void load_physical_pages(const index_t* page_idx,
|
||||
// LINEAR_LAYOUT: [page, token_in_page, head_dim]
|
||||
// VECTORIZED_LAYOUT: [page, token_in_page/kVectorSize, head_dim, kVectorSize]
|
||||
//
|
||||
template <typename OffsetVecType,
|
||||
template <typename IndexArrayType,
|
||||
typename CoordVecType,
|
||||
index_t kCoordAxis,
|
||||
index_t kPageBlockSize,
|
||||
@@ -134,11 +135,11 @@ template <typename OffsetVecType,
|
||||
bool kIsKcache,
|
||||
index_t kN0,
|
||||
index_t kVectorSize>
|
||||
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_pages)[kLoopCount],
|
||||
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
|
||||
const index_t& stride_token,
|
||||
const index_t& stride_page_block,
|
||||
const CoordVecType& coord_vec,
|
||||
OffsetVecType& kv_offset_vec,
|
||||
IndexArrayType& kv_offset_vec,
|
||||
index_t global_seq_offset = 0)
|
||||
{
|
||||
static constexpr index_t kLog2PageSize = [] {
|
||||
@@ -164,7 +165,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_page
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
const index_t physical_page = physical_pages[k0.value];
|
||||
const index_t physical_page = physical_pages[k0];
|
||||
|
||||
kv_offset_vec[k0] = static_cast<long_index_t>(physical_page) * stride_page_block +
|
||||
static_cast<long_index_t>(token_idx_in_page) * stride_token;
|
||||
@@ -181,7 +182,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t (&physical_page
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
const index_t physical_page = physical_pages[k0.value];
|
||||
const index_t physical_page = physical_pages[k0];
|
||||
|
||||
const long_index_t page_base_offset =
|
||||
static_cast<long_index_t>(physical_page) * stride_page_block;
|
||||
@@ -574,8 +575,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
|
||||
// Load physical pages first, then compute offsets.
|
||||
// k_physical_pages can be reused for descale lookup later.
|
||||
index_t k_physical_pages[NRepeat] = {};
|
||||
load_physical_pages<decltype(k_coord),
|
||||
statically_indexed_array<index_t, NRepeat> k_physical_pages{};
|
||||
load_physical_pages<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
0,
|
||||
@@ -737,7 +739,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
statically_indexed_array<index_t, V_PageIdxRepeat> v_offsets;
|
||||
// V physical pages array for use with kv_offset_array_transform
|
||||
// For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner
|
||||
index_t v_physical_pages[V_PageIdxRepeat] = {};
|
||||
statically_indexed_array<index_t, V_PageIdxRepeat> v_physical_pages{};
|
||||
|
||||
// Prefetch V physical pages - can be called early to hide buffer load latency
|
||||
auto prefetch_v_physical_pages = [&](auto k_loop_start) {
|
||||
@@ -746,8 +748,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
static_for<0, V_KIterOuter, 1>{}([&](auto k2) {
|
||||
// Load physical pages for this k2 slice into the appropriate portion of array
|
||||
index_t v_physical_pages_k2[V_KIterInner] = {};
|
||||
load_physical_pages<decltype(v_coord),
|
||||
statically_indexed_array<index_t, V_KIterInner> v_physical_pages_k2{};
|
||||
load_physical_pages<statically_indexed_array<index_t, V_KIterInner>,
|
||||
decltype(v_coord),
|
||||
I1,
|
||||
kPageBlockSize,
|
||||
kLoopStart + k2.value * V_KLanes * V_KIterInner,
|
||||
@@ -759,14 +762,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
|
||||
// Copy to merged array
|
||||
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
|
||||
constexpr auto idx = k1.value + k2.value * V_KIterInner;
|
||||
v_physical_pages[idx] = v_physical_pages_k2[k1.value];
|
||||
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
|
||||
v_physical_pages[idx] = v_physical_pages_k2[k1];
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
load_physical_pages<decltype(v_coord),
|
||||
load_physical_pages<statically_indexed_array<index_t, V_KIterInner>,
|
||||
decltype(v_coord),
|
||||
I1,
|
||||
kPageBlockSize,
|
||||
kLoopStart,
|
||||
@@ -789,10 +793,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static_for<0, V_KIterOuter, 1>{}([&](auto k2) {
|
||||
statically_indexed_array<index_t, V_KIterInner> v_offsets_k2;
|
||||
// Extract physical pages for this k2 slice
|
||||
index_t v_physical_pages_k2[V_KIterInner];
|
||||
statically_indexed_array<index_t, V_KIterInner> v_physical_pages_k2;
|
||||
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
|
||||
constexpr auto idx = k1.value + k2.value * V_KIterInner;
|
||||
v_physical_pages_k2[k1.value] = v_physical_pages[idx];
|
||||
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
|
||||
v_physical_pages_k2[k1] = v_physical_pages[idx];
|
||||
});
|
||||
|
||||
kv_offset_array_transform<statically_indexed_array<index_t, V_KIterInner>,
|
||||
@@ -893,7 +897,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE)
|
||||
{
|
||||
const index_t scale_offset =
|
||||
k_physical_pages[0] * kv_block_descale_stride_block +
|
||||
k_physical_pages[number<0>{}] * kv_block_descale_stride_block +
|
||||
block_indices.kv_head_idx * kv_block_descale_stride_head;
|
||||
k_descale = kv_block_descale_ptr[scale_offset + 0 * kv_block_descale_stride_kv];
|
||||
v_descale = kv_block_descale_ptr[scale_offset + 1 * kv_block_descale_stride_kv];
|
||||
@@ -1342,7 +1346,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
// KV_BLOCKSCALE: reload physical pages for the new tile
|
||||
load_physical_pages<decltype(k_coord),
|
||||
load_physical_pages<statically_indexed_array<index_t, NRepeat>,
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
0,
|
||||
|
||||
Reference in New Issue
Block a user