[FMHA] Batch Prefill Support Improvements: Change KV Cache Layout & Large Page Size Support (#3442)

* add page_block_size parameter

* add is_sglang_layout to  parameters

* add kv_offset_array_transform to batch async for page size 16

* add kv_last_page_lens to kernel

* change kv layout to [num_total_pages, page_block_size, hdim]

* format

* - enable codegen of batch_prefill kernels
- create new problem struct BlockFmhaBatchPrefillPipelineProblem for
  batch prefill kernels
- generate different page sizes of batch prefill kernels (1, 16)

* 1. fix wrong calculation of page id in kv_offset_array_transform in gfx950
2. support page size 1024

* fix python format

* change kv cache layout to [num_blocks, num_kv_heads, head_size/x,
block_size, x] and [num_blocks, num_kv_heads, block_size/X, head_size, X]

* 1. Introduced `kVectorSize` in BlockFmhaBatchPrefillPipelineProblem instead of using hardcode values
2. Makes batch prefill kernel traits structures inherent from fmha fwd
   traits
3. Add some static check for Page size, vector size, hdim, ..., etc.

* [Refactor] Replace is_sglang_layout with Enums for KV cache configuration

Refactored `fmha_batch_prefill` to use `BlockAttentionKVCacheMemoryLayoutEnum` (VECTORIZED/LINEAR) and `BlockAttentionKVCacheLookupTableEnum` (SGLANG_1D/VLLM_2D) instead of a single
boolean.

**Changes:**
*   Added Enum definitions in `block_attention_kvcache_layout_enum.hpp`.
*   Updated Kernel, Pipeline, and Traits to template on these Enums.
*   Implemented `kv_offset_array_transform` logic based on `kKVMemoryLayout`.
*   Refactored `PageBlockTableKargs` to adapt to `kKVLookupTable`.
*   Updated CodeGen scripts to support new parameters.

This decouples memory layout from the paging mechanism, enabling flexible KV cache configurations.

* 1. remove batch prefill pipeline with sk_pad=false
2. correct some comments
3. add static assert to make sure v offsets is in same page within a tile.

* fix vgpr spill count

* remove unnecessary t2s functions

* add fp8 support for receipt 200 and 600 in fmha_bath_prefill.py

* support linear kv cache layout

* Remove block_table_ptr from fwd_batch_prefill_args. Instead, reuse
kv_page_indices as a pointer of the lookup table.

* 1. merge multiple transforms into single transform.
2. add static check to make sure vlayout is row-major.

* move FmhaFwdCommonKargs::seqlen_k_ptr to VllmPageTableKargs.

* update changelog

---------

Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: PoYen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
Jeff Huang
2026-01-05 18:41:47 +08:00
committed by GitHub
parent e339101e9c
commit cc75a1dc5f
9 changed files with 983 additions and 430 deletions

View File

@@ -11,6 +11,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
* Added FP8 KV cache support for FMHA batch prefill.
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
### Changed

View File

@@ -36,6 +36,19 @@ DTYPE_BITS = {
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
SUPPORTED_PAGE_SIZE = [128, 256, 1024]
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
KV_MEMORY_LAYOUT_ENUM_MAP = {
"vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
"linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
}
KV_LOOKUP_TABLE_ENUM_MAP = {
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
}
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
}
@@ -59,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
@@ -69,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_lse},
{F_dropout},
{F_qscale},
{F_occupancy}>;
{F_occupancy},
false,
{F_page_size},
{F_kv_memory_layout},
{F_kv_lookup_table}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
@@ -92,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
fmha_variant_{F_idx},
fmha_mask_{F_idx},
false,
{F_page_size},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
@@ -105,8 +123,8 @@ using fmha_epilogue_{F_idx} =
using fmha_kernel_{F_idx} =
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
#include <iostream>
@@ -184,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@@ -230,12 +248,15 @@ class FmhaFwdApiTrait:
dpad: str
dvpad: str
constraint: CppConstraint
kv_memory_layout: str
kv_lookup_table: str
page_size: int = 1 # page block size
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
)
@property
@@ -322,6 +343,8 @@ class FmhaFwdPipeline:
F_dropout: str #
F_qscale: str # no/pertensor
F_mask: str # value from MASK_MAP
F_kv_memory_layout: str #
F_kv_lookup_table: str #
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@property
@@ -382,6 +405,8 @@ class FmhaFwdPipeline:
n += f"_{self.F_qscale}"
else:
n += "_nqscale"
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
return n
@@ -440,6 +465,13 @@ class FmhaFwdApiPool:
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
trait.kv_memory_layout
],
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
trait.kv_lookup_table
],
F_page_size=trait.page_size,
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
@@ -497,6 +529,7 @@ class FmhaFwdKernel:
F_tile: FmhaFwdTileSize
F_pipeline: FmhaFwdPipeline
mask_impl: str
F_page_size: int = 1 # page block size
@property
def template(self) -> str:
@@ -534,17 +567,24 @@ class FmhaFwdKernel:
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
F_occupancy=self.F_tile.F_occupancy,
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
self.F_pipeline.F_kv_memory_layout
],
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
self.F_pipeline.F_kv_lookup_table
],
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
F_page_size=self.F_page_size,
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return (
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
@@ -578,6 +618,9 @@ class FmhaFwdKernel:
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
page_size=self.F_page_size,
)
@@ -604,23 +647,42 @@ class KernelComponentFactory:
pipelines = []
if dtype in ["fp16", "bf16"]:
qscale = "no"
for logits, mask, bias, lse, dropout in itertools.product(
for (
logits,
mask,
bias,
lse,
dropout,
kv_memory_layout,
kv_lookup_table,
) in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
elif dtype in ["fp8bf16"]:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
for (
logits,
qscale,
mask,
bias,
kv_memory_layout,
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor"],
get_mask_map(mask_impl).keys(),
["no"],
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
else:
assert False
return pipelines
@@ -672,69 +734,73 @@ def get_fwd_blobs(
or pipeline.F_logits == "f"
):
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
# Generate kernels for both page_size=16 and page_size=1024
for page_size in SUPPORTED_PAGE_SIZE:
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
F_page_size=page_size,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)

View File

@@ -529,14 +529,25 @@ struct fmha_batch_prefill_args
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
// SGLang-style page table
int32_t num_total_pages;
void* kv_indptr;
void* kv_page_indices;
#if 0 // we assume page_block_size=1 for now
void* kv_last_page_lens;
ck_tile::index_t page_block_size;
#endif
// KV cache page table fields (kv_lookup_table selects interpretation):
// - SGLANG_PAGE_TABLE_1D:
// kv_indptr: prefix-sum [batch+1] into kv_page_indices
// kv_page_indices: 1D list of physical page ids, length = num_total_pages
// kv_last_page_lens: per-batch last page lengths [batch]
// - VLLM_BLOCK_TABLE_2D:
// kv_page_indices: block_table [batch, max_blocks_per_seq] (2D)
// batch_stride_block_table: row stride for block_table
// seqlen_k_ptr: per-batch seqlen_k [batch]
int32_t num_total_pages; // total physical pages in KV cache (SGLang/vLLM)
ck_tile::index_t page_block_size; // tokens per page (SGLang/vLLM)
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum
kv_memory_layout; // KV memory layout (SGLang/vLLM)
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector
void* kv_indptr; // SGLang: prefix-sum; vLLM: unused
void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D
void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused
void* seqlen_k_ptr; // vLLM: per-batch seqlen_k; SGLang: unused
ck_tile::index_t batch_stride_block_table; // vLLM: row stride; SGLang: unused
float scale_s;
float scale_p;
@@ -1113,6 +1124,22 @@ template <typename FmhaKernel>
auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
using PageTableKargs = typename FmhaKernel::PageBlockTableKargs;
const PageTableKargs page_table = [&]() {
if constexpr(FmhaKernel::kKVLookupTable ==
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_indptr),
reinterpret_cast<const int32_t*>(args.kv_page_indices),
reinterpret_cast<const int32_t*>(args.kv_last_page_lens)};
}
else
{
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_page_indices),
args.batch_stride_block_table,
reinterpret_cast<const int32_t*>(args.seqlen_k_ptr)};
}
}();
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
@@ -1133,12 +1160,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
#if 0 // we assume page_block_size=1 for now
args.kv_last_page_lens,
args.page_block_size,
#endif
page_table,
args.scale_s,
args.scale_p,
args.scale_o,
@@ -1184,12 +1207,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
#if 0 // we assume page_block_size=1 for now
args.kv_last_page_lens,
args.page_block_size,
#endif
page_table,
args.scale_s,
args.scale_p,
args.scale_o,
@@ -1281,6 +1300,65 @@ struct fmha_fwd_traits_
static constexpr bool kHasSink = kHasSink_;
};
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
bool kHasLogitsSoftCap_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kHasDropout_,
ck_tile::BlockAttentionQuantScaleEnum QScaleEnum_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kUseTrLoad_,
bool kSkipMinSeqlenQ_ = false,
ck_tile::index_t kPageBlockSize_ = 1,
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
DataType_,
kIsGroupMode_,
kM0_,
kN0_,
kK0_,
kN1_,
kK1_,
kK0BlockLength_,
kIsVLayoutRowMajor_,
FmhaPipelineEnum_,
kHasLogitsSoftCap_,
FmhaMask_,
BiasEnum_,
kStoreLse_,
kHasDropout_,
QScaleEnum_,
kPadS_,
kPadSK_,
kPadD_,
kPadDv_,
kUseTrLoad_,
kSkipMinSeqlenQ_,
false>
{
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
};
template <typename Traits_, typename Arch = void>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
@@ -1527,7 +1605,15 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,
const ck_tile::stream_config&);
using fmha_batch_prefill_traits = fmha_fwd_traits;
struct fmha_batch_prefill_traits : public fmha_fwd_traits
{
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table =
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
int page_size = 1;
};
float fmha_batch_prefill(fmha_batch_prefill_traits,
fmha_batch_prefill_args,
const ck_tile::stream_config&);

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"

View File

@@ -0,0 +1,32 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile {
// KV cache memory layout selector.
//
// Layout summary (kVectorSize = 16 / sizeof(KDataType)):
// - VECTORIZED_LAYOUT (swizzled):
// K: [NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize]
// V: [NumBlocks, NumHeads, PageSize/kVectorSize, HeadDim, kVectorSize]
// - LINEAR_LAYOUT:
// K: [NumBlocks, PageSize, NumHeads, HeadDim]
// V: [NumBlocks, PageSize, NumHeads, HeadDim]
enum class BlockAttentionKVCacheMemoryLayoutEnum
{
VECTORIZED_LAYOUT = 0,
LINEAR_LAYOUT = 1,
};
// KV cache lookup table layout selector.
// - VLLM_BLOCK_TABLE_2D: block_table[batch, max_blocks_per_seq]
// - SGLANG_PAGE_TABLE_1D: kv_page_indices[kv_indptr[b] ... kv_indptr[b+1])
enum class BlockAttentionKVCacheLookupTableEnum
{
VLLM_BLOCK_TABLE_2D = 0,
SGLANG_PAGE_TABLE_1D = 1,
};
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
@@ -56,12 +57,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
static constexpr auto kKVMemoryLayout = FmhaPipeline::Problem::kKVMemoryLayout;
static constexpr auto kKVLookupTable = FmhaPipeline::Problem::kKVLookupTable;
static constexpr index_t kPageBlockSize = FmhaPipeline::kPageBlockSize;
static constexpr index_t kVectorSize = FmhaPipeline::kVectorSize;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
// arg
struct FmhaFwdEmptyKargs
@@ -71,6 +75,26 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct SglangPageTableKargs
{
const int32_t* kv_indptr;
const int32_t* kv_page_indices;
const int32_t* kv_last_page_lens;
};
struct VllmPageTableKargs
{
const int32_t* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
const int32_t* seqlen_k_ptr;
};
using PageBlockTableKargs =
std::conditional_t<kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
SglangPageTableKargs,
VllmPageTableKargs>;
struct FmhaFwdCommonKargs
{
const void* q_ptr;
@@ -89,14 +113,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t nhead_ratio_qk;
int32_t num_total_pages;
const int32_t* kv_indptr;
const int32_t* kv_page_indices;
#if 0 // we assume page_block_size=1 for now
const int32_t* kv_last_page_lens;
ck_tile::index_t page_block_size;
#else
static constexpr ck_tile::index_t page_block_size = 1;
#endif
PageBlockTableKargs page_table;
float scale_s;
@@ -295,12 +313,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
int32_t num_total_pages,
const void* kv_indptr,
const void* kv_page_indices,
#if 0 // we assume page_block_size=1 for now
const void* kv_last_page_lens,
ck_tile::index_t page_block_size,
#endif
const PageBlockTableKargs& page_table,
float scale_s,
[[maybe_unused]] float scale_p,
[[maybe_unused]] float scale_o,
@@ -345,12 +359,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
num_head_q,
nhead_ratio_qk,
num_total_pages,
reinterpret_cast<const int32_t*>(kv_indptr),
reinterpret_cast<const int32_t*>(kv_page_indices),
#if 0 // we assume page_block_size=1 for now
reinterpret_cast<const int32_t*>(kv_last_page_lens),
page_block_size,
#endif
page_table,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
@@ -453,12 +463,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
int32_t num_total_pages,
const void* kv_indptr,
const void* kv_page_indices,
#if 0 // we assume page_block_size=1 for now
const void* kv_last_page_lens,
ck_tile::index_t page_block_size,
#endif
const PageBlockTableKargs& page_table,
float scale_s,
[[maybe_unused]] float scale_p,
[[maybe_unused]] float scale_o,
@@ -498,12 +504,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
num_head_q,
nhead_ratio_qk,
num_total_pages,
reinterpret_cast<const int32_t*>(kv_indptr),
reinterpret_cast<const int32_t*>(kv_page_indices),
#if 0 // we assume page_block_size=1 for now
reinterpret_cast<const int32_t*>(kv_last_page_lens),
page_block_size,
#endif
page_table,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
@@ -700,10 +702,46 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
#if 0 // we assume page_block_size=1 for now
const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
#endif
const index_t seqlen_k = [&]() {
if constexpr(kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
const int32_t num_page_blocks = page_end - page_start;
const int32_t last_page_len = [&]() {
if constexpr(kPageBlockSize == 1)
return static_cast<int32_t>(kPageBlockSize);
else
return kargs.page_table.kv_last_page_lens[i_batch];
}();
return num_page_blocks > 0
? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
last_page_len)
: 0;
}
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
{
if(kargs.page_table.seqlen_k_ptr != nullptr)
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
else
return kargs.seqlen_k;
}
}();
const int32_t* page_idx = [&]() {
if constexpr(kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
}
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
{
return kargs.page_table.block_table_ptr +
static_cast<long_index_t>(i_batch) *
kargs.page_table.batch_stride_block_table;
}
}();
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
@@ -711,8 +749,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
batch_offset_q = query_start * kargs.stride_q;
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;
@@ -737,18 +773,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
return;
}
#if 0 // we assume page_block_size=1 for now
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
#else
kargs.seqlen_k = num_page_blocks;
#endif
kargs.seqlen_k = seqlen_k;
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
@@ -764,11 +794,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
#if 0 // we assume page_block_size=1 for now
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
#else
kargs.seqlen_k = num_page_blocks;
#endif
kargs.seqlen_k = seqlen_k;
}
// for simplicity, batch stride we just modify the pointer
@@ -809,60 +835,137 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
}
}();
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
// Vectorized K Layout: [NumPages, D/kVectorSize, S, kVectorSize]
// Logical View for Pipeline: (TotalSeqK, D)
// Define the naive physical view with 4D shape: (NumPages, HeadDim/kVectorSize,
// PageBlockSize, kVectorSize)
// Strides: (BatchStride, PageBlockSize*kVectorSize, kVectorSize, 1)
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.num_total_pages,
kargs.hdim_q / kVectorSize,
kargs.page_block_size,
kVectorSize),
make_tuple(
kargs.batch_stride_k, kargs.page_block_size * kVectorSize, kVectorSize, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(
make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
make_tuple(sequence<1>{}, sequence<0>{}),
// Merge to (TotalSeqK, D) in a single transform:
// physical (Page, D/vec, S, vec) -> logical (TotalSeqK, D)
auto k_dram_2d = transform_tensor_view(
k_dram_naive,
make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages,
kargs.page_block_size)), // TotalSeqK
make_merge_transform(
make_tuple(static_cast<int32_t>(kargs.hdim_q / kVectorSize),
static_cast<int32_t>(kVectorSize)))), // D
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
v_dram_transposed,
k_dram_2d,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
else
{
// Linear K Layout: [NumPages, PageSize, NumHeads, HeadDim]
// Logical View for Pipeline: (TotalSeqK, D)
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_q),
make_tuple(kargs.batch_stride_k, kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
// Merge to (TotalSeqK, D) in a single transform:
// physical (Page, S, D) -> logical (TotalSeqK, D)
auto k_dram_2d = transform_tensor_view(
k_dram_naive,
make_tuple(make_merge_transform(
make_tuple(kargs.num_total_pages, kargs.page_block_size)),
make_pass_through_transform(kargs.hdim_q)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
k_dram_2d,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
}();
const auto v_dram = [&]() {
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// Vectorized V Layout: [NumPages, S/kVectorSize, D, kVectorSize]
// Logical View for Pipeline: (D, TotalSeqK) - Transposed for GEMM
// Define the naive physical view with 4D shape: (NumPages,
// PageBlockSize/kVectorSize, HeadDim, kVectorSize)
// Strides: (BatchStride, HeadDim*kVectorSize, kVectorSize, 1)
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.num_total_pages,
kargs.page_block_size / kVectorSize,
kargs.hdim_v,
kVectorSize),
make_tuple(kargs.batch_stride_v, kargs.hdim_v * kVectorSize, kVectorSize, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
// Merge to (D, TotalSeqK) in a single transform:
// physical (Page, S/vec, D, vec) -> logical (D, TotalSeqK)
auto v_dram_final = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v), // D
make_merge_transform(make_tuple(kargs.num_total_pages,
kargs.page_block_size / kVectorSize,
kVectorSize))), // TotalSeqK
make_tuple(sequence<2>{}, sequence<0, 1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
v_dram_final,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
else
{
// Linear V Layout: [NumPages, PageSize, NumHeads, HeadDim]
// Logical View for Pipeline: (D, TotalSeqK)
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size),
make_tuple(kargs.stride_v, 1),
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_v),
make_tuple(kargs.batch_stride_v, kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
return pad_tensor_view(
// Merge to (D, TotalSeqK) in a single transform:
// physical (Page, S, D) -> logical (D, TotalSeqK)
auto v_dram_final = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_merge_transform(
make_tuple(kargs.num_total_pages, kargs.page_block_size))),
make_tuple(sequence<2>{}, sequence<0, 1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
v_dram_final,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV_, kPadSeqLenK>{});
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
}();
auto q_dram_window = make_tile_window(
q_dram,
[&]() {
@@ -1070,6 +1173,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
const index_t stride_k_for_pipeline =
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
? kVectorSize
: kargs.stride_k;
const index_t stride_v_for_pipeline =
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
? kargs.hdim_v
: kargs.stride_v;
auto o_acc_tile = [&] {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
@@ -1108,9 +1220,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
variant_params,
block_indices,
smem_ptr,
kargs.kv_page_indices,
kargs.stride_k,
kargs.stride_v,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
}
else
@@ -1128,9 +1242,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
variant_params,
block_indices,
smem_ptr,
kargs.kv_page_indices,
kargs.stride_k,
kargs.stride_v,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
}
}();

View File

@@ -6,12 +6,82 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename OffsetVecType,
typename CoordVecType,
index_t kCoordAxis,
index_t kPageBlockSize,
index_t kLog2PageSize,
index_t kLoopStart,
index_t kLoopCount,
index_t kLoopStride,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
bool kIsKcache,
index_t kVectorSize>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
const index_t& stride_kv,
const index_t& page_stride_kv,
const CoordVecType& coord_vec,
OffsetVecType& kv_offset_vec,
index_t global_seq_offset = 0)
{
const index_t& thread_coord_start = coord_vec[kCoordAxis];
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
if constexpr(kIsKcache)
{
// for k offsets
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t page_id = global_token_idx >> kLog2PageSize;
const index_t page_offset = global_token_idx & kInPageOffsetMask;
kv_offset_vec[k0] = static_cast<long_index_t>(page_vec[page_id]) * page_stride_kv +
static_cast<long_index_t>(page_offset) * stride_kv;
});
}
else
{
// for v offsets
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
const index_t lane0_page_id =
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
const long_index_t page_loc =
static_cast<long_index_t>(page_vec[lane0_page_id]) * page_stride_kv;
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t page_offset =
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
kInPageOffsetMask;
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// Vectorized layout offset
// Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
// Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize)
const index_t s = page_offset;
const index_t D = stride_kv;
const long_index_t s_offset =
static_cast<long_index_t>((s / kVectorSize) * (D * kVectorSize)) +
(s % kVectorSize);
kv_offset_vec[k0] = page_loc + s_offset;
}
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
{
kv_offset_vec[k0] = page_loc + static_cast<long_index_t>(page_offset) * stride_kv;
}
});
}
}
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_,
@@ -41,19 +111,25 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
static constexpr index_t kLog2PageSize = Problem::kLog2PageSize;
static constexpr index_t kVectorSize = Problem::kVectorSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(kPageBlockSize % kN0 == 0,
"V offset assumes each tile stays within a page; kPageBlockSize must be "
"divisible by kN0.");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
@@ -68,6 +144,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -196,6 +273,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
const index_t page_stride_k,
const index_t page_stride_v,
DropoutType& dropout) const
{
static_assert(
@@ -325,9 +404,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
statically_indexed_array<index_t, NRepeat> k_offsets;
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
});
index_t current_seq_k = seqlen_k_start;
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
kPageBlockSize,
kLog2PageSize,
0,
NRepeat,
kN0 / NRepeat,
kKVMemoryLayout,
true,
kVectorSize>(
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
@@ -360,10 +450,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
statically_indexed_array<index_t, V_KRepeat> v_offsets;
(void)stride_k;
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
});
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
0,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
auto v_dram_window =
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -425,13 +523,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
async_load_fence();
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
});
v_dram_window.update_page_idx(v_offsets);
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(
@@ -444,49 +535,67 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto p = [&]() {
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
if constexpr(kHasLogitsSoftCap)
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
auto apply_logits_transform =
[&variant, &variant_params, &block_indices](auto& x) {
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform = [&variant, &variant_params, &block_indices](
auto& x) {
x = variant.LogitsTransform(variant_params,
variant.QueryTransform(variant_params, x),
block_indices.batch_idx,
@@ -494,216 +603,229 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
block_indices.kv_head_idx);
};
#if !CK_TILE_FMHA_FWD_FAST_EXP2
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
#else
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
(CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
// Avoid data hazard if v_mfma is followed by inline asm consumer
// instructions. In this case, compiler won't add s_nop for us
if(i == s_acc.thread_buf_.size() / 2)
{
__builtin_amdgcn_sched_barrier(0);
// Avoid data hazard if v_mfma is followed by inline asm consumer
// instructions. In this case, compiler won't add s_nop for us
if(i == s_acc.thread_buf_.size() / 2)
{
__builtin_amdgcn_sched_barrier(0);
}
#endif
apply_logits_transform(s_acc.thread_buf_[i]);
}
#endif
apply_logits_transform(s_acc.thread_buf_[i]);
}
#endif
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] =
page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v;
});
v_dram_window.update_page_idx(v_offsets);
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration. alibi does not have this problem
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
if constexpr(kHasLogitsSoftCap)
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout([](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); },
m,
m_old,
m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0,
kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
2 * kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration. alibi does not have this problem
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
}
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
if constexpr(kHasLogitsSoftCap)
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}
}();
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
});
if constexpr(kHasDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
if constexpr(kHasDropout)
{
auto randval_ptr = reinterpret_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeKV<Problem>();
dropout
.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
const auto p = [&]() {
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
// For fp32 to fp16,
// impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue,
@@ -727,11 +849,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 +
v_coord[VPageIndexDim] + k0.value] *
stride_v;
});
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
(2 + i_k1.value) * kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
}
block_sync_lds();
@@ -772,14 +901,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
i_total_loops++;
if(i_total_loops < num_total_loop)
{
page_idx += kN0;
current_seq_k += kN0;
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
});
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
kPageBlockSize,
kLog2PageSize,
0,
NRepeat,
kN0 / NRepeat,
kKVMemoryLayout,
true,
kVectorSize>(
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
@@ -887,6 +1025,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
const index_t page_stride_k,
const index_t page_stride_v,
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
@@ -913,6 +1053,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
page_idx,
stride_k,
stride_v,
page_stride_k,
page_stride_v,
dropout);
}
};

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace ck_tile {
@@ -65,6 +66,71 @@ struct BlockFmhaPipelineProblem
static constexpr bool kHasSink = Traits::kHasSink;
};
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename SaccDataType_,
typename SMPLComputeDataType_,
typename BiasDataType_,
typename RandValOutputDataType_,
typename LSEDataType_,
typename PDataType_,
typename OaccDataType_,
typename ODataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename AttentionVariant_,
typename FmhaMask_,
bool kUseTrLoad_,
int kPageBlockSize_,
typename Traits_>
struct BlockFmhaBatchPrefillPipelineProblem
: public BlockFmhaPipelineProblem<QDataType_,
KDataType_,
VDataType_,
SaccDataType_,
SMPLComputeDataType_,
BiasDataType_,
RandValOutputDataType_,
LSEDataType_,
PDataType_,
OaccDataType_,
ODataType_,
BlockFmhaShape_,
kIsGroupMode_,
AttentionVariant_,
FmhaMask_,
kUseTrLoad_,
Traits_>
{
static constexpr index_t kPageBlockSize = kPageBlockSize_;
static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive");
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
"kPageBlockSize must be power of two");
static constexpr index_t kLog2PageSize = []() constexpr {
index_t shift = 0;
index_t val = kPageBlockSize_;
while(val > 1)
{
val >>= 1;
shift++;
}
return shift;
}();
static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
static constexpr auto kKVLookupTable = Traits_::kKVLookupTable;
static constexpr bool kIsVectorizedLayout =
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0,
"kQKHeaddim must be divisible by kVectorSize");
static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
"kPageBlockSize must be divisible by kVectorSize for vectorized layout");
static_assert(kIsGroupMode_, "Batch prefill requires group mode");
};
template <typename QDataType_,
typename KDataType_,
typename VDataType_,

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
@@ -40,6 +41,48 @@ struct TileFmhaTraits
static constexpr bool kHasSink = kHasSink_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
bool kPadHeadDimV_ /* padding for hdim_v */,
bool kHasLogitsSoftCap_,
BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kStoreLSE_,
bool kHasDropout_,
BlockAttentionQuantScaleEnum QScaleEnum_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
index_t kPageBlockSize_ = 1,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
kPadSeqLenK_,
kPadHeadDimQ_,
kPadHeadDimV_,
kHasLogitsSoftCap_,
BiasEnum_,
kHasBiasGrad_,
kStoreLSE_,
kHasDropout_,
QScaleEnum_,
kBlockPerCu_,
kSkipMinSeqlenQ_,
false>
{
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;
static constexpr index_t kPageBlockSize = kPageBlockSize_;
static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT ||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT,
"Batch prefill only supports vectorized or linear KV cache layout.");
static_assert(kPageBlockSize > 0 && ((kPageBlockSize & (kPageBlockSize - 1)) == 0),
"kPageBlockSize should be a power of 2 to support efficient page-based KV cache "
"addressing.");
};
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
index_t kPadHeadDimV_ /* paddding for hdim_v */,
BlockAttentionBiasEnum BiasEnum_,