mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[FMHA] Support page_size=1 (linear layout) in batch prefill pipeline (#3545)
- Enable page_size=1 support in batch prefill codegen (linear layout only). - Implement per-token page lookup in `kv_offset_array_transform` for page_size=1 to handle 3D input tensors correctly. - Relax `kPageBlockSize` alignment assertion for the page_size=1 case.
This commit is contained in:
@@ -36,7 +36,7 @@ DTYPE_BITS = {
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [128, 256, 1024]
|
||||
SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024]
|
||||
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
|
||||
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
|
||||
KV_MEMORY_LAYOUT_ENUM_MAP = {
|
||||
@@ -737,6 +737,8 @@ def get_fwd_blobs(
|
||||
|
||||
# Generate kernels for every page size in SUPPORTED_PAGE_SIZE (1, 128, 256, 1024)
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
|
||||
@@ -24,9 +24,9 @@ template <typename OffsetVecType,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
|
||||
bool kIsKcache,
|
||||
index_t kVectorSize>
|
||||
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
|
||||
const index_t& stride_kv,
|
||||
const index_t& page_stride_kv,
|
||||
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
|
||||
const index_t& stride_token,
|
||||
const index_t& stride_page_block,
|
||||
const CoordVecType& coord_vec,
|
||||
OffsetVecType& kv_offset_vec,
|
||||
index_t global_seq_offset = 0)
|
||||
@@ -39,47 +39,70 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
const index_t page_offset = global_token_idx & kInPageOffsetMask;
|
||||
kv_offset_vec[k0] = static_cast<long_index_t>(page_vec[page_id]) * page_stride_kv +
|
||||
static_cast<long_index_t>(page_offset) * stride_kv;
|
||||
const index_t page_id = global_token_idx >> kLog2PageSize;
|
||||
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
|
||||
kv_offset_vec[k0] = static_cast<long_index_t>(page_idx[page_id]) * stride_page_block +
|
||||
static_cast<long_index_t>(token_idx_in_page) * stride_token;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// for v offsets
|
||||
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
|
||||
const index_t lane0_page_id =
|
||||
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
|
||||
if constexpr(kLog2PageSize == 0 &&
|
||||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
|
||||
{
|
||||
// page size = 1, per-token page lookup.
|
||||
// Here page_idx maps token_idx -> physical_page_id, so global_seq_offset must be
|
||||
// the absolute token index within the batch's kv_page_indices slice.
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t global_token_idx =
|
||||
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
|
||||
|
||||
const long_index_t page_loc =
|
||||
static_cast<long_index_t>(page_vec[lane0_page_id]) * page_stride_kv;
|
||||
const long_index_t page_base_offset =
|
||||
static_cast<long_index_t>(page_idx[global_token_idx]) * stride_page_block;
|
||||
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t page_offset =
|
||||
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
|
||||
kInPageOffsetMask;
|
||||
kv_offset_vec[k0] = page_base_offset;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// This path handles page_size > 1 and/or non-linear KV layout, where page_idx is
|
||||
// indexed by page_id (token_idx >> log2_page_size) with an in-page offset.
|
||||
// Assumes the V tile stays within a single page so lane0 can broadcast the page id.
|
||||
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
|
||||
const index_t lane0_page_id =
|
||||
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
|
||||
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// Vectorized layout offset
|
||||
// Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
|
||||
// Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize)
|
||||
const index_t s = page_offset;
|
||||
const index_t D = stride_kv;
|
||||
const long_index_t page_base_offset =
|
||||
static_cast<long_index_t>(page_idx[lane0_page_id]) * stride_page_block;
|
||||
|
||||
const long_index_t s_offset =
|
||||
static_cast<long_index_t>((s / kVectorSize) * (D * kVectorSize)) +
|
||||
(s % kVectorSize);
|
||||
static_for<0, kLoopCount, 1>{}([&](auto k0) {
|
||||
const index_t token_idx_in_page =
|
||||
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
|
||||
kInPageOffsetMask;
|
||||
|
||||
kv_offset_vec[k0] = page_loc + s_offset;
|
||||
}
|
||||
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
|
||||
{
|
||||
kv_offset_vec[k0] = page_loc + static_cast<long_index_t>(page_offset) * stride_kv;
|
||||
}
|
||||
});
|
||||
if constexpr(kKVMemoryLayout ==
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
|
||||
{
|
||||
// Vectorized layout offset
|
||||
// Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
|
||||
// Offset = (token_idx_in_page / kVectorSize) * (HeadDim * kVectorSize) +
|
||||
// (token_idx_in_page % kVectorSize)
|
||||
|
||||
const long_index_t token_offset =
|
||||
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
|
||||
(stride_token * kVectorSize)) +
|
||||
(token_idx_in_page % kVectorSize);
|
||||
|
||||
kv_offset_vec[k0] = page_base_offset + token_offset;
|
||||
}
|
||||
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
|
||||
{
|
||||
kv_offset_vec[k0] = page_base_offset +
|
||||
static_cast<long_index_t>(token_idx_in_page) * stride_token;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,9 +150,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static constexpr auto I3 = number<3>{};
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
static_assert(kPageBlockSize % kN0 == 0,
|
||||
"V offset assumes each tile stays within a page; kPageBlockSize must be "
|
||||
"divisible by kN0.");
|
||||
static_assert(kPageBlockSize % kN0 == 0 || kLog2PageSize == 0,
|
||||
"Page size must be 1, or a multiple of the tile size (kN0).");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
|
||||
|
||||
Reference in New Issue
Block a user