[FMHA] Enable page size 16 for batch prefill kernel (#3568)

* [FMHA] Enable page size 16 for batch prefill kernel

* Refactor batch prefill KV offset logic to simplify template arguments
- Remove redundant `kLog2PageSize` and `kIsVTileFitsInPage` from template args.
- Add static assert to forbid `page_size=1` with vectorized layout.
This commit is contained in:
Jeff Huang
2026-01-15 22:11:44 +08:00
committed by GitHub
parent 5122637215
commit 993d3e2f0e
3 changed files with 62 additions and 28 deletions

View File

@@ -17,12 +17,12 @@ template <typename OffsetVecType,
typename CoordVecType,
index_t kCoordAxis,
index_t kPageBlockSize,
index_t kLog2PageSize,
index_t kLoopStart,
index_t kLoopCount,
index_t kLoopStride,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
bool kIsKcache,
index_t kN0,
index_t kVectorSize>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
const index_t& stride_token,
@@ -31,6 +31,17 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
OffsetVecType& kv_offset_vec,
index_t global_seq_offset = 0)
{
static constexpr index_t kLog2PageSize = [] {
index_t shift = 0;
index_t val = kPageBlockSize;
while(val > 1)
{
val >>= 1;
shift++;
}
return shift;
}();
const index_t& thread_coord_start = coord_vec[kCoordAxis];
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
if constexpr(kIsKcache)
@@ -48,7 +59,10 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
else
{
// for v offsets
if constexpr(kLog2PageSize == 0 &&
// For page_size > 1, the V tile crosses page boundaries whenever page_size is not a
// multiple of kN0 (e.g., page_size < kN0).
static constexpr bool kVTileCrossesPages =
(kPageBlockSize > 1) && (kPageBlockSize % kN0 != 0);
if constexpr(kPageBlockSize == 1 &&
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
{
// page size = 1, per-token page lookup.
@@ -64,11 +78,42 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
kv_offset_vec[k0] = page_base_offset;
});
}
else
else if constexpr(kVTileCrossesPages)
{
// This path handles page_size > 1 and/or non-linear KV layout, where page_idx is
// indexed by page_id (token_idx >> log2_page_size) with an in-page offset.
// Assumes the V tile stays within a single page so lane0 can broadcast the page id.
// V tile crosses multiple pages (e.g., page_size < kN0), so page_id must be computed
// per token.
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t page_id = global_token_idx >> kLog2PageSize;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
const long_index_t page_base_offset =
static_cast<long_index_t>(page_idx[page_id]) * stride_page_block;
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// Vectorized layout uses a packed [token/kVectorSize, head_dim, kVectorSize]
// address pattern.
const long_index_t token_offset =
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
(stride_token * kVectorSize)) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = page_base_offset + token_offset;
}
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
{
kv_offset_vec[k0] = page_base_offset +
static_cast<long_index_t>(token_idx_in_page) * stride_token;
}
});
}
else // !kVTileCrossesPages
{
// V tile is fully contained in one page, so page_id is shared.
// Use lane0 to compute page_id once and broadcast page_base_offset.
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
const index_t lane0_page_id =
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
@@ -77,8 +122,9 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
static_cast<long_index_t>(page_idx[lane0_page_id]) * stride_page_block;
static_for<0, kLoopCount, 1>{}([&](auto k0) {
// kLoopStride allows non-unit token spacing in the tile distribution.
const index_t token_idx_in_page =
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
(global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value) &
kInPageOffsetMask;
if constexpr(kKVMemoryLayout ==
@@ -142,7 +188,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
static constexpr index_t kLog2PageSize = Problem::kLog2PageSize;
static constexpr index_t kVectorSize = Problem::kVectorSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
@@ -150,9 +195,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr auto I3 = number<3>{};
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(kPageBlockSize % kN0 == 0 || kLog2PageSize == 0,
"Page size must be 1, or a multiple of the tile size (kN0).");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
@@ -456,12 +498,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
decltype(k_coord),
0,
kPageBlockSize,
kLog2PageSize,
0,
NRepeat,
kN0 / NRepeat,
kKVMemoryLayout,
true,
kN0,
kVectorSize>(
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
@@ -501,12 +543,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
0,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
@@ -587,12 +629,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
@@ -761,12 +803,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
2 * kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
@@ -900,12 +942,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
(2 + i_k1.value) * kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
@@ -957,12 +999,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
decltype(k_coord),
0,
kPageBlockSize,
kLog2PageSize,
0,
NRepeat,
kN0 / NRepeat,
kKVMemoryLayout,
true,
kN0,
kVectorSize>(
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);

View File

@@ -107,16 +107,6 @@ struct BlockFmhaBatchPrefillPipelineProblem
static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive");
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
"kPageBlockSize must be power of two");
static constexpr index_t kLog2PageSize = []() constexpr {
index_t shift = 0;
index_t val = kPageBlockSize_;
while(val > 1)
{
val >>= 1;
shift++;
}
return shift;
}();
static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
@@ -126,6 +116,8 @@ struct BlockFmhaBatchPrefillPipelineProblem
static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0,
"kQKHeaddim must be divisible by kVectorSize");
static_assert(!(kPageBlockSize == 1 && kIsVectorizedLayout),
"page_size=1 only supports linear KV cache layout");
static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
"kPageBlockSize must be divisible by kVectorSize for vectorized layout");
static_assert(kIsGroupMode_, "Batch prefill requires group mode");