mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[FMHA] Batch Prefill Support Improvements: Change KV Cache Layout & Large Page Size Support (#3442)
* add page_block_size parameter * add is_sglang_layout to parameters * add kv_offset_array_transform to batch async for page size 16 * add kv_last_page_lens to kernel * change kv layout to [num_total_pages, page_block_size, hdim] * format * - enable codegen of batch_prefill kernels - create new problem struct BlockFmhaBatchPrefillPipelineProblem for batch prefill kernels - generate different page sizes of batch prefill kernels (1, 16) * 1. fix wrong calculation of page id in kv_offset_array_transform in gfx950 2. support page size 1024 * fix python format * change kv cache layout to [num_blocks, num_kv_heads, head_size/x, block_size, x] and [num_blocks, num_kv_heads, block_size/X, head_size, X] * 1. Introduced `kVectorSize` in BlockFmhaBatchPrefillPipelineProblem instead of using hardcode values 2. Makes batch prefill kernel traits structures inherent from fmha fwd traits 3. Add some static check for Page size, vector size, hdim, ..., etc. * [Refactor] Replace is_sglang_layout with Enums for KV cache configuration Refactored `fmha_batch_prefill` to use `BlockAttentionKVCacheMemoryLayoutEnum` (VECTORIZED/LINEAR) and `BlockAttentionKVCacheLookupTableEnum` (SGLANG_1D/VLLM_2D) instead of a single boolean. **Changes:** * Added Enum definitions in `block_attention_kvcache_layout_enum.hpp`. * Updated Kernel, Pipeline, and Traits to template on these Enums. * Implemented `kv_offset_array_transform` logic based on `kKVMemoryLayout`. * Refactored `PageBlockTableKargs` to adapt to `kKVLookupTable`. * Updated CodeGen scripts to support new parameters. This decouples memory layout from the paging mechanism, enabling flexible KV cache configurations. * 1. remove batch prefill pipeline with sk_pad=false 2. correct some comments 3. add static assert to make sure v offsets is in same page within a tile. * fix vgpr spill count * remove unnecessary t2s functions * add fp8 support for receipt 200 and 600 in fmha_bath_prefill.py * support linear kv cache layout * Remove block_table_ptr from fwd_batch_prefill_args. Instead, reuse kv_page_indices as a pointer of the lookup table. * 1. merge multiple transforms into single transform. 2. add static check to make sure vlayout is row-major. * move FmhaFwdCommonKargs::seqlen_k_ptr to VllmPageTableKargs. * update changelog --------- Co-authored-by: ltqin <letaoqin@amd.com> Co-authored-by: PoYen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -36,6 +36,19 @@ DTYPE_BITS = {
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [128, 256, 1024]
|
||||
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
|
||||
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
|
||||
KV_MEMORY_LAYOUT_ENUM_MAP = {
|
||||
"vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
|
||||
"linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
|
||||
}
|
||||
KV_LOOKUP_TABLE_ENUM_MAP = {
|
||||
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
|
||||
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
|
||||
}
|
||||
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
|
||||
}
|
||||
@@ -59,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
@@ -69,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_qscale},
|
||||
{F_occupancy}>;
|
||||
{F_occupancy},
|
||||
false,
|
||||
{F_page_size},
|
||||
{F_kv_memory_layout},
|
||||
{F_kv_lookup_table}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
@@ -92,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
false,
|
||||
{F_page_size},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
@@ -105,8 +123,8 @@ using fmha_epilogue_{F_idx} =
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -184,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -230,12 +248,15 @@ class FmhaFwdApiTrait:
|
||||
dpad: str
|
||||
dvpad: str
|
||||
constraint: CppConstraint
|
||||
kv_memory_layout: str
|
||||
kv_lookup_table: str
|
||||
page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -322,6 +343,8 @@ class FmhaFwdPipeline:
|
||||
F_dropout: str #
|
||||
F_qscale: str # no/pertensor
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_kv_memory_layout: str #
|
||||
F_kv_lookup_table: str #
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
@@ -382,6 +405,8 @@ class FmhaFwdPipeline:
|
||||
n += f"_{self.F_qscale}"
|
||||
else:
|
||||
n += "_nqscale"
|
||||
|
||||
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
|
||||
return n
|
||||
|
||||
|
||||
@@ -440,6 +465,13 @@ class FmhaFwdApiPool:
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
trait.kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
trait.kv_lookup_table
|
||||
],
|
||||
F_page_size=trait.page_size,
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
@@ -497,6 +529,7 @@ class FmhaFwdKernel:
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
F_page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
@@ -534,17 +567,24 @@ class FmhaFwdKernel:
|
||||
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_lookup_table
|
||||
],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
@@ -578,6 +618,9 @@ class FmhaFwdKernel:
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
|
||||
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
|
||||
page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -604,23 +647,42 @@ class KernelComponentFactory:
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
qscale = "no"
|
||||
for logits, mask, bias, lse, dropout in itertools.product(
|
||||
for (
|
||||
logits,
|
||||
mask,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias in itertools.product(
|
||||
for (
|
||||
logits,
|
||||
qscale,
|
||||
mask,
|
||||
bias,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
["pertensor"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -672,69 +734,73 @@ def get_fwd_blobs(
|
||||
or pipeline.F_logits == "f"
|
||||
):
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
# Generate kernels for both page_size=16 and page_size=1024
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
F_page_size=page_size,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user