Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-04 13:41:24 +00:00
[FMHA] Batch Prefill Support Improvements: Change KV Cache Layout & Large Page Size Support (#3442)
* Add page_block_size parameter.
* Add is_sglang_layout to parameters.
* Add kv_offset_array_transform to batch async for page size 16.
* Add kv_last_page_lens to kernel.
* Change KV layout to [num_total_pages, page_block_size, hdim].
* Format.
* Enable codegen of batch_prefill kernels; create a new problem struct, BlockFmhaBatchPrefillPipelineProblem, for batch prefill kernels; generate batch prefill kernels for different page sizes (1, 16).
* Fix wrong calculation of page id in kv_offset_array_transform on gfx950; support page size 1024.
* Fix Python formatting.
* Change KV cache layout to [num_blocks, num_kv_heads, head_size/x, block_size, x] and [num_blocks, num_kv_heads, block_size/x, head_size, x].
* Introduce `kVectorSize` in BlockFmhaBatchPrefillPipelineProblem instead of hardcoded values; make the batch prefill kernel traits structures inherit from the fmha fwd traits; add static checks for page size, vector size, hdim, etc.
* [Refactor] Replace is_sglang_layout with enums for KV cache configuration. Refactored `fmha_batch_prefill` to use `BlockAttentionKVCacheMemoryLayoutEnum` (VECTORIZED/LINEAR) and `BlockAttentionKVCacheLookupTableEnum` (SGLANG_1D/VLLM_2D) instead of a single boolean. **Changes:**
  * Added enum definitions in `block_attention_kvcache_layout_enum.hpp`.
  * Updated the kernel, pipeline, and traits to template on these enums.
  * Implemented the `kv_offset_array_transform` logic based on `kKVMemoryLayout`.
  * Refactored `PageBlockTableKargs` to adapt to `kKVLookupTable`.
  * Updated the codegen scripts to support the new parameters.
  This decouples the memory layout from the paging mechanism, enabling flexible KV cache configurations.
* Remove the batch prefill pipeline with sk_pad=false; correct some comments; add a static assert to ensure the V offsets stay within the same page inside a tile.
* Fix VGPR spill count.
* Remove unnecessary t2s functions.
* Add fp8 support for receipts 200 and 600 in fmha_batch_prefill.py.
* Support linear KV cache layout.
* Remove block_table_ptr from fwd_batch_prefill_args; instead, reuse kv_page_indices as a pointer to the lookup table.
* Merge multiple transforms into a single transform; add a static check to ensure vlayout is row-major.
* Move FmhaFwdCommonKargs::seqlen_k_ptr to VllmPageTableKargs.
* Update changelog.

---------

Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: PoYen, Chen <PoYen.Chen@amd.com>
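For concreteness, the two lookup-table conventions differ only in how a token position within a sequence is mapped to a physical page id. Below is a minimal C++ sketch, not the kernel's actual argument interface: the helper function names are hypothetical, while `block_table`, `kv_page_indices`, and `kv_indptr` follow the names used in the commit message and header comments.

#include <cstdint>

// VLLM_BLOCK_TABLE_2D: dense per-batch table, block_table[batch, max_blocks_per_seq].
int32_t page_id_vllm_2d(const int32_t* block_table, int32_t max_blocks_per_seq,
                        int32_t b, int32_t token_pos, int32_t page_size)
{
    return block_table[b * max_blocks_per_seq + token_pos / page_size];
}

// SGLANG_PAGE_TABLE_1D: flat page-index array; the pages of batch b occupy the
// range kv_page_indices[kv_indptr[b] ... kv_indptr[b+1]).
int32_t page_id_sglang_1d(const int32_t* kv_page_indices, const int32_t* kv_indptr,
                          int32_t b, int32_t token_pos, int32_t page_size)
{
    return kv_page_indices[kv_indptr[b] + token_pos / page_size];
}

Decoupling this lookup choice from the memory-layout choice (the point of the refactor) means either table convention can address pages stored in either the vectorized or the linear layout.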
@@ -0,0 +1,32 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

namespace ck_tile {

// KV cache memory layout selector.
//
// Layout summary (kVectorSize = 16 / sizeof(KDataType)):
// - VECTORIZED_LAYOUT (swizzled):
//     K: [NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize]
//     V: [NumBlocks, NumHeads, PageSize/kVectorSize, HeadDim, kVectorSize]
// - LINEAR_LAYOUT:
//     K: [NumBlocks, PageSize, NumHeads, HeadDim]
//     V: [NumBlocks, PageSize, NumHeads, HeadDim]
enum class BlockAttentionKVCacheMemoryLayoutEnum
{
    VECTORIZED_LAYOUT = 0,
    LINEAR_LAYOUT = 1,
};

// KV cache lookup table layout selector.
// - VLLM_BLOCK_TABLE_2D: block_table[batch, max_blocks_per_seq]
// - SGLANG_PAGE_TABLE_1D: kv_page_indices[kv_indptr[b] ... kv_indptr[b+1])
enum class BlockAttentionKVCacheLookupTableEnum
{
    VLLM_BLOCK_TABLE_2D = 0,
    SGLANG_PAGE_TABLE_1D = 1,
};

} // namespace ck_tile
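As a rough illustration of the two memory layouts documented in the header, the element offset of K[page, head, token, dim] can be sketched as follows. This is a sketch under stated assumptions, not code from the kernel: the helper names are hypothetical, extents are in elements, and vec stands for kVectorSize = 16 / sizeof(KDataType) as the header's comment defines it.

#include <cstddef>

// VECTORIZED_LAYOUT (swizzled): K is [NumBlocks, NumHeads, HeadDim/vec, PageSize, vec],
// so the vec elements of one head-dim chunk are contiguous per token, with
// successive tokens of that chunk adjacent at stride vec.
std::size_t k_offset_vectorized(std::size_t page, std::size_t head, std::size_t token,
                                std::size_t dim, std::size_t num_heads,
                                std::size_t head_dim, std::size_t page_size,
                                std::size_t vec)
{
    return (((page * num_heads + head) * (head_dim / vec) + dim / vec) * page_size + token)
               * vec
           + dim % vec;
}

// LINEAR_LAYOUT: K is [NumBlocks, PageSize, NumHeads, HeadDim], i.e. plain
// row-major [token, head, dim] order within each page.
std::size_t k_offset_linear(std::size_t page, std::size_t head, std::size_t token,
                            std::size_t dim, std::size_t num_heads,
                            std::size_t head_dim, std::size_t page_size)
{
    return ((page * page_size + token) * num_heads + head) * head_dim + dim;
}

A plausible reading of the design choice: the swizzled form keeps a kVectorSize-wide chunk contiguous per token so loads can be vectorized across tokens of a page, while the linear form matches unswizzled cache tensors as frameworks typically store them.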