mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[FMHA] Enable page size 16 for batch prefill kernel (#3568)
* [FMHA] Enable page size 16 for batch prefill kernel * Refactor batch prefill KV offset logic to simplify template arguments - Remove redundant `kLog2PageSize` and `kIsVTileFitsInPage` from template args. - Add static assert to forbid `page_size=1` with vectorized layout.
This commit is contained in:
@@ -17,12 +17,12 @@ template <typename OffsetVecType,
|
||||
typename CoordVecType,
|
||||
index_t kCoordAxis,
|
||||
index_t kPageBlockSize,
|
||||
index_t kLog2PageSize,
|
||||
index_t kLoopStart,
|
||||
index_t kLoopCount,
|
||||
index_t kLoopStride,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
|
||||
bool kIsKcache,
|
||||
index_t kN0,
|
||||
index_t kVectorSize>
|
||||
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
|
||||
const index_t& stride_token,
|
||||
@@ -31,6 +31,17 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
|
||||
OffsetVecType& kv_offset_vec,
|
||||
index_t global_seq_offset = 0)
|
||||
{
|
||||
static constexpr index_t kLog2PageSize = [] {
|
||||
index_t shift = 0;
|
||||
index_t val = kPageBlockSize;
|
||||
while(val > 1)
|
||||
{
|
||||
val >>= 1;
|
||||
shift++;
|
||||
}
|
||||
return shift;
|
||||
}();
|
||||
|
||||
const index_t& thread_coord_start = coord_vec[kCoordAxis];
|
||||
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
|
||||
if constexpr(kIsKcache)
|
||||
@@ -48,7 +59,10 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
|
||||
else
|
||||
{
|
||||
// for v offsets
|
||||
if constexpr(kLog2PageSize == 0 &&
|
||||
// for page_size > 1, the V tile crosses pages when page_size is not a multiple of kN0.
|
||||
static constexpr bool kVTileCrossesPages =
|
||||
(kPageBlockSize > 1) && (kPageBlockSize % kN0 != 0);
|
||||
if constexpr(kPageBlockSize == 1 &&
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
|
||||
{
|
||||
// page size = 1, per-token page lookup.
|
||||
@@ -64,11 +78,42 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
|
||||
kv_offset_vec[k0] = page_base_offset;
|
||||
});
|
||||
}
|
||||
else
|
||||
else if constexpr(kVTileCrossesPages)
|
||||
{
|
||||
// This path handles page_size > 1 and/or non-linear KV layout, where page_idx is
|
||||
// indexed by page_id (token_idx >> log2_page_size) with an in-page offset.
|
||||
// Assumes the V tile stays within a single page so lane0 can broadcast the page id.
|
||||
// V tile crosses multiple pages (e.g., page_size < kN0), so page_id must be computed
|
||||
// per token.
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
|
||||
const long_index_t page_base_offset =
|
||||
static_cast<long_index_t>(page_idx[page_id]) * stride_page_block;
|
||||
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// Vectorized layout uses a packed [token/kVectorSize, head_dim, kVectorSize]
|
||||
// address pattern.
|
||||
const long_index_t token_offset =
|
||||
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
|
||||
(stride_token * kVectorSize)) +
|
||||
(token_idx_in_page % kVectorSize);
|
||||
|
||||
kv_offset_vec[k0] = page_base_offset + token_offset;
|
||||
}
|
||||
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
|
||||
{
|
||||
kv_offset_vec[k0] = page_base_offset +
|
||||
static_cast<long_index_t>(token_idx_in_page) * stride_token;
|
||||
}
|
||||
});
|
||||
}
|
||||
else // !kVTileCrossesPages
|
||||
{
|
||||
// V tile is fully contained in one page, so page_id is shared.
|
||||
// Use lane0 to compute page_id once and broadcast page_base_offset.
|
||||
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
|
||||
const index_t lane0_page_id =
|
||||
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
|
||||
@@ -77,8 +122,9 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
|
||||
static_cast<long_index_t>(page_idx[lane0_page_id]) * stride_page_block;
|
||||
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
// kLoopStride allows non-unit token spacing in the tile distribution.
|
||||
const index_t token_idx_in_page =
|
||||
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
|
||||
(global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value) &
|
||||
kInPageOffsetMask;
|
||||
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
@@ -142,7 +188,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
|
||||
static constexpr index_t kLog2PageSize = Problem::kLog2PageSize;
|
||||
static constexpr index_t kVectorSize = Problem::kVectorSize;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
@@ -150,9 +195,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static constexpr auto I3 = number<3>{};
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
static_assert(kPageBlockSize % kN0 == 0 || kLog2PageSize == 0,
|
||||
"Page size must be 1, or a multiple of the tile size (kN0).");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
|
||||
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
|
||||
@@ -456,12 +498,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
0,
|
||||
NRepeat,
|
||||
kN0 / NRepeat,
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
|
||||
@@ -501,12 +543,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
0,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
|
||||
@@ -587,12 +629,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
@@ -761,12 +803,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
2 * kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
@@ -900,12 +942,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
decltype(v_coord),
|
||||
VPageIndexDim,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
(2 + i_k1.value) * kK1,
|
||||
V_KRepeat,
|
||||
1,
|
||||
kKVMemoryLayout,
|
||||
false,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
@@ -957,12 +999,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
decltype(k_coord),
|
||||
0,
|
||||
kPageBlockSize,
|
||||
kLog2PageSize,
|
||||
0,
|
||||
NRepeat,
|
||||
kN0 / NRepeat,
|
||||
kKVMemoryLayout,
|
||||
true,
|
||||
kN0,
|
||||
kVectorSize>(
|
||||
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
k_dram_window.update_page_idx(k_offsets);
|
||||
|
||||
@@ -107,16 +107,6 @@ struct BlockFmhaBatchPrefillPipelineProblem
|
||||
static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive");
|
||||
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
|
||||
"kPageBlockSize must be power of two");
|
||||
static constexpr index_t kLog2PageSize = []() constexpr {
|
||||
index_t shift = 0;
|
||||
index_t val = kPageBlockSize_;
|
||||
while(val > 1)
|
||||
{
|
||||
val >>= 1;
|
||||
shift++;
|
||||
}
|
||||
return shift;
|
||||
}();
|
||||
|
||||
static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
|
||||
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
|
||||
@@ -126,6 +116,8 @@ struct BlockFmhaBatchPrefillPipelineProblem
|
||||
|
||||
static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0,
|
||||
"kQKHeaddim must be divisible by kVectorSize");
|
||||
static_assert(!(kPageBlockSize == 1 && kIsVectorizedLayout),
|
||||
"page_size=1 only supports linear KV cache layout");
|
||||
static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
|
||||
"kPageBlockSize must be divisible by kVectorSize for vectorized layout");
|
||||
static_assert(kIsGroupMode_, "Batch prefill requires group mode");
|
||||
|
||||
Reference in New Issue
Block a user