diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a9b25b062..ce8e5197a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline. * Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel. * Added FP8 KV cache support for FMHA batch prefill. +* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. ### Changed diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 95e8379769..c4c70009d5 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -36,6 +36,19 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} +SUPPORTED_PAGE_SIZE = [128, 256, 1024] +SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"] +SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"] +KV_MEMORY_LAYOUT_ENUM_MAP = { + "vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT", + "linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT", +} +KV_LOOKUP_TABLE_ENUM_MAP = { + "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D", + "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D", +} + + FMHA_BATCH_PREFILL_PIPELINE_MAP = { "qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", } @@ -59,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, {F_vlayout}>; -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, +using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, @@ -69,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_lse}, {F_dropout}, {F_qscale}, - {F_occupancy}>; + {F_occupancy}, + false, + {F_page_size}, + {F_kv_memory_layout}, + {F_kv_lookup_table}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; using fmha_mask_{F_idx} = {F_mask}; -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem< typename FmhaFwdTypeConfig::QDataType, typename FmhaFwdTypeConfig::KDataType, typename FmhaFwdTypeConfig::VDataType, @@ -92,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< fmha_variant_{F_idx}, fmha_mask_{F_idx}, false, + {F_page_size}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< @@ -105,8 +123,8 @@ using fmha_epilogue_{F_idx} = using fmha_kernel_{F_idx} = ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; +using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; #include @@ -184,8 +202,8 
@@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v """ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{ + using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; return fmha_batch_prefill_(s, a); }} """ @@ -230,12 +248,15 @@ class FmhaFwdApiTrait: dpad: str dvpad: str constraint: CppConstraint + kv_memory_layout: str + kv_lookup_table: str + page_size: int = 1 # page block size @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}" ) @property @@ -322,6 +343,8 @@ class FmhaFwdPipeline: F_dropout: str # F_qscale: str # no/pertensor F_mask: str # value from MASK_MAP + F_kv_memory_layout: str # + F_kv_lookup_table: str # F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -382,6 +405,8 @@ class FmhaFwdPipeline: n += f"_{self.F_qscale}" else: n += "_nqscale" + + n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table return n @@ -440,6 +465,13 @@ class FmhaFwdApiPool: F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype], + F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[ + trait.kv_memory_layout + ], + F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[ + trait.kv_lookup_table + ], + F_page_size=trait.page_size, ) if_j = "if" if j == 0 else "else if" per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( @@ -497,6 +529,7 @@ class FmhaFwdKernel: F_tile: FmhaFwdTileSize F_pipeline: FmhaFwdPipeline mask_impl: str + F_page_size: int = 1 # page block size @property def template(self) -> str: @@ -534,17 +567,24 @@ class FmhaFwdKernel: F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale], F_occupancy=self.F_tile.F_occupancy, + F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[ + self.F_pipeline.F_kv_memory_layout + ], + F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[ + self.F_pipeline.F_kv_lookup_table + ], F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], 
F_mode=MODE_MAP[self.F_mode], F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], + F_page_size=self.F_page_size, ) @property def name(self) -> str: # TODO: we don't encode idx here return ( - f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_" + self.F_tile.name + "_" + self.F_pipeline.name @@ -578,6 +618,9 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + kv_memory_layout=self.F_pipeline.F_kv_memory_layout, + kv_lookup_table=self.F_pipeline.F_kv_lookup_table, + page_size=self.F_page_size, ) @@ -604,23 +647,42 @@ class KernelComponentFactory: pipelines = [] if dtype in ["fp16", "bf16"]: qscale = "no" - for logits, mask, bias, lse, dropout in itertools.product( + for ( + logits, + mask, + bias, + lse, + dropout, + kv_memory_layout, + kv_lookup_table, + ) in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], + SUPPORTED_KV_MEMORY_LAYOUT, + SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip elif dtype in ["fp8bf16"]: # no need lse/dropout kernels - for logits, qscale, mask, bias in itertools.product( + for ( + logits, + qscale, + mask, + bias, + kv_memory_layout, + kv_lookup_table, + ) in itertools.product( ["t", "f"], ["pertensor"], get_mask_map(mask_impl).keys(), ["no"], + SUPPORTED_KV_MEMORY_LAYOUT, + SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip else: assert False return pipelines @@ -672,69 +734,73 @@ def get_fwd_blobs( or pipeline.F_logits == "f" ): continue - k = FmhaFwdKernel( - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, - ) - if kernel_filter != "": - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # Aiter(mha_batch_prefill) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # aiter::mha_batch_prefill C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout ==
"row" - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # fp32 only - if receipt == 800 or receipt == 801: - cond = dtype == "fp32" - if not cond: - continue + # Generate kernels for each page size in SUPPORTED_PAGE_SIZE + for page_size in SUPPORTED_PAGE_SIZE: + k = FmhaFwdKernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + F_page_size=page_size, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "alibi"] + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "bias"] + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # Aiter(mha_batch_prefill) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16", "fp8bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_batch_prefill C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16", "fp8bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_qscale == "no" + if not cond: + continue - api_pool.register_traits(k.api_trait()) - gen.append(k) + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == "fp32" + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) return (api_pool, gen) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ba55d6d722..3ff4acfc15 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -529,14 +529,25 @@ struct fmha_batch_prefill_args ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; - // SGLang-style page table - int32_t num_total_pages; - void* kv_indptr; - void* kv_page_indices; -#if 0 // we assume page_block_size=1 for now - void* kv_last_page_lens; - ck_tile::index_t page_block_size; -#endif + // KV cache page table fields (kv_lookup_table selects interpretation): + // - SGLANG_PAGE_TABLE_1D: + // kv_indptr: prefix-sum [batch+1] into kv_page_indices + // kv_page_indices: 1D list of physical page ids, length = num_total_pages + // kv_last_page_lens: per-batch last page lengths [batch] + // - VLLM_BLOCK_TABLE_2D: + // kv_page_indices: block_table [batch, max_blocks_per_seq] (2D) + // batch_stride_block_table: row stride for block_table + // seqlen_k_ptr: per-batch seqlen_k [batch] + int32_t num_total_pages; // total physical pages in KV cache (SGLang/vLLM) + ck_tile::index_t page_block_size; // tokens per page (SGLang/vLLM) + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum + kv_memory_layout; // KV memory layout (SGLang/vLLM) + ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector + void* kv_indptr; // SGLang: prefix-sum; vLLM: unused + void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D + void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused + void* 
seqlen_k_ptr; // vLLM: per-batch seqlen_k; SGLang: unused + ck_tile::index_t batch_stride_block_table; // vLLM: row stride; SGLang: unused float scale_s; float scale_p; @@ -1113,6 +1124,22 @@ template <typename FmhaKernel> auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) { assert(args.nhead_q % args.nhead_k == 0); + using PageTableKargs = typename FmhaKernel::PageBlockTableKargs; + const PageTableKargs page_table = [&]() { + if constexpr(FmhaKernel::kKVLookupTable == + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) + { + return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_indptr), + reinterpret_cast<const int32_t*>(args.kv_page_indices), + reinterpret_cast<const int32_t*>(args.kv_last_page_lens)}; + } + else + { + return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_page_indices), + args.batch_stride_block_table, + reinterpret_cast<const int32_t*>(args.seqlen_k_ptr)}; + } + }(); auto kargs = [&] { // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) @@ -1133,12 +1160,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, -#if 0 // we assume page_block_size=1 for now - args.kv_last_page_lens, args.page_block_size, -#endif + page_table, args.scale_s, args.scale_p, args.scale_o, @@ -1184,12 +1207,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, -#if 0 // we assume page_block_size=1 for now - args.kv_last_page_lens, args.page_block_size, -#endif + page_table, args.scale_s, args.scale_p, args.scale_o, @@ -1281,6 +1300,65 @@ struct fmha_fwd_traits_ static constexpr bool kHasSink = kHasSink_; }; +template +struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_ +{ + static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; + static constexpr auto kKVLookupTable = kKVLookupTable_; + static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_; + static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout"); +}; + template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); @@ -1527,7 +1605,15 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const ck_tile::stream_config&); -using fmha_batch_prefill_traits = fmha_fwd_traits; +struct fmha_batch_prefill_traits : public fmha_fwd_traits +{ + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + int page_size = 1; +}; + float fmha_batch_prefill(fmha_batch_prefill_traits, fmha_batch_prefill_args, const ck_tile::stream_config&); diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 20714397c9..eb4aa16d05 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -3,6 +3,7 @@ #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp new file 
mode 100644 index 0000000000..c79e639469 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile { + +// KV cache memory layout selector. +// +// Layout summary (kVectorSize = 16 / sizeof(KDataType)): +// - VECTORIZED_LAYOUT (swizzled): +// K: [NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize] +// V: [NumBlocks, NumHeads, PageSize/kVectorSize, HeadDim, kVectorSize] +// - LINEAR_LAYOUT: +// K: [NumBlocks, PageSize, NumHeads, HeadDim] +// V: [NumBlocks, PageSize, NumHeads, HeadDim] +enum class BlockAttentionKVCacheMemoryLayoutEnum +{ + VECTORIZED_LAYOUT = 0, + LINEAR_LAYOUT = 1, +}; + +// KV cache lookup table layout selector. +// - VLLM_BLOCK_TABLE_2D: block_table[batch, max_blocks_per_seq] +// - SGLANG_PAGE_TABLE_1D: kv_page_indices[kv_indptr[b] ... kv_indptr[b+1]) +enum class BlockAttentionKVCacheLookupTableEnum +{ + VLLM_BLOCK_TABLE_2D = 0, + SGLANG_PAGE_TABLE_1D = 1, +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 73b6a329d1..9afd097eed 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" @@ -56,12 +57,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; + static constexpr auto kKVMemoryLayout = FmhaPipeline::Problem::kKVMemoryLayout; + static constexpr auto kKVLookupTable = FmhaPipeline::Problem::kKVLookupTable; + static constexpr index_t kPageBlockSize = FmhaPipeline::kPageBlockSize; + static constexpr index_t kVectorSize = FmhaPipeline::kVectorSize; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; - template // to avoid duplicated base class prblem, introduce an template // arg struct FmhaFwdEmptyKargs @@ -71,6 +75,26 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size // user need to use MakeKargs() function to create kargs. 
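For orientation, a minimal host-side sketch of the two lookup-table schemes defined in block_attention_kvcache_layout_enum.hpp; the helper names below are illustrative only, not part of the kernel API. Each returns the physical page id backing a given logical page of a batch:

#include <cstdint>
#include <vector>

// SGLANG_PAGE_TABLE_1D: the pages of batch b occupy
// kv_page_indices[kv_indptr[b] ... kv_indptr[b+1]).
std::int32_t sglang_page_id(const std::vector<std::int32_t>& kv_indptr,
                            const std::vector<std::int32_t>& kv_page_indices,
                            int batch, int logical_page)
{
    return kv_page_indices[kv_indptr[batch] + logical_page];
}

// VLLM_BLOCK_TABLE_2D: block_table is a row-major [batch, max_blocks_per_seq]
// matrix; batch_stride_block_table is its row stride.
std::int32_t vllm_page_id(const std::int32_t* block_table,
                          std::int64_t batch_stride_block_table,
                          int batch, int logical_page)
{
    return block_table[batch * batch_stride_block_table + logical_page];
}

The SglangPageTableKargs and VllmPageTableKargs structs that follow carry exactly the fields these two resolutions need.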
+ struct SglangPageTableKargs + { + const int32_t* kv_indptr; + const int32_t* kv_page_indices; + const int32_t* kv_last_page_lens; + }; + + struct VllmPageTableKargs + { + const int32_t* block_table_ptr; + ck_tile::index_t batch_stride_block_table; + const int32_t* seqlen_k_ptr; + }; + + using PageBlockTableKargs = + std::conditional_t<kKVLookupTable == BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D, SglangPageTableKargs, VllmPageTableKargs>; + struct FmhaFwdCommonKargs { const void* q_ptr; @@ -89,14 +113,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t nhead_ratio_qk; int32_t num_total_pages; - const int32_t* kv_indptr; - const int32_t* kv_page_indices; -#if 0 // we assume page_block_size=1 for now - const int32_t* kv_last_page_lens; ck_tile::index_t page_block_size; -#else - static constexpr ck_tile::index_t page_block_size = 1; -#endif + PageBlockTableKargs page_table; float scale_s; @@ -295,12 +313,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, - const void* kv_indptr, - const void* kv_page_indices, -#if 0 // we assume page_block_size=1 for now - const void* kv_last_page_lens, ck_tile::index_t page_block_size, -#endif + const PageBlockTableKargs& page_table, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, @@ -345,12 +359,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel num_head_q, nhead_ratio_qk, num_total_pages, - reinterpret_cast<const int32_t*>(kv_indptr), - reinterpret_cast<const int32_t*>(kv_page_indices), -#if 0 // we assume page_block_size=1 for now - reinterpret_cast<const int32_t*>(kv_last_page_lens), page_block_size, -#endif + page_table, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast<float>(scale_s * ck_tile::log2e_v<>), #else @@ -453,12 +463,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, - const void* kv_indptr, - const void* kv_page_indices, -#if 0 // we assume page_block_size=1 for now - const void* kv_last_page_lens, ck_tile::index_t page_block_size, -#endif + const PageBlockTableKargs& page_table, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, @@ -498,12 +504,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel num_head_q, nhead_ratio_qk, num_total_pages, - reinterpret_cast<const int32_t*>(kv_indptr), - reinterpret_cast<const int32_t*>(kv_page_indices), -#if 0 // we assume page_block_size=1 for now - reinterpret_cast<const int32_t*>(kv_last_page_lens), page_block_size, -#endif + page_table, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast<float>(scale_s * ck_tile::log2e_v<>), #else @@ -700,10 +702,46 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; - const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; -#if 0 // we assume page_block_size=1 for now - const int32_t last_page_len = kargs.kv_last_page_lens[i_batch]; -#endif + const index_t seqlen_k = [&]() { + if constexpr(kKVLookupTable == + BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) + { + const int32_t page_start = kargs.page_table.kv_indptr[i_batch]; + const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1]; + const int32_t num_page_blocks = page_end - page_start; + const int32_t last_page_len = [&]() { + if constexpr(kPageBlockSize == 1) + return static_cast<int32_t>(kPageBlockSize); + else + return kargs.page_table.kv_last_page_lens[i_batch]; + }(); + return num_page_blocks > 0 + ? 
static_cast((num_page_blocks - 1) * kargs.page_block_size + + last_page_len) + : 0; + } + else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D + { + if(kargs.page_table.seqlen_k_ptr != nullptr) + return static_cast(kargs.page_table.seqlen_k_ptr[i_batch]); + else + return kargs.seqlen_k; + } + }(); + const int32_t* page_idx = [&]() { + if constexpr(kKVLookupTable == + BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) + { + return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch]; + } + else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D + { + return kargs.page_table.block_table_ptr + + static_cast(i_batch) * + kargs.page_table.batch_stride_block_table; + } + }(); + if constexpr(kIsGroupMode) { // get starting offset for each batch @@ -711,8 +749,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel batch_offset_q = query_start * kargs.stride_q; - kargs.kv_page_indices += kargs.kv_indptr[i_batch]; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias; @@ -737,18 +773,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel return; } -#if 0 // we assume page_block_size=1 for now - kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; -#else - kargs.seqlen_k = num_page_blocks; -#endif + kargs.seqlen_k = seqlen_k; } else { batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - kargs.kv_page_indices += kargs.kv_indptr[i_batch]; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; @@ -764,11 +794,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; -#if 0 // we assume page_block_size=1 for now - kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; -#else - kargs.seqlen_k = num_page_blocks; -#endif + kargs.seqlen_k = seqlen_k; } // for simplicity, batch stride we just modify the pointer @@ -809,60 +835,137 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel } }(); const auto k_dram = [&]() { - const auto k_dram_naive = make_naive_tensor_view( - k_ptr, - make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - number{}, - number<1>{}); - - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? 
kPadSeqLenK : true; - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - const auto v_dram = [&]() { - if constexpr(std::is_same_v) + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), - number{}, + // Vectorized K Layout: [NumPages, D/kVectorSize, S, kVectorSize] + // Logical View for Pipeline: (TotalSeqK, D) + + // Define the naive physical view with 4D shape: (NumPages, HeadDim/kVectorSize, + // PageBlockSize, kVectorSize) + // Strides: (BatchStride, PageBlockSize*kVectorSize, kVectorSize, 1) + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.num_total_pages, + kargs.hdim_q / kVectorSize, + kargs.page_block_size, + kVectorSize), + make_tuple( + kargs.batch_stride_k, kargs.page_block_size * kVectorSize, kVectorSize, 1), + number{}, number<1>{}); - const auto v_dram_transposed = transform_tensor_view( - v_dram_naive, - make_tuple( - make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)), - make_tuple(sequence<1>{}, sequence<0>{}), + // Merge to (TotalSeqK, D) in a single transform: + // physical (Page, D/vec, S, vec) -> logical (TotalSeqK, D) + auto k_dram_2d = transform_tensor_view( + k_dram_naive, + make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages, + kargs.page_block_size)), // TotalSeqK + make_merge_transform( + make_tuple(static_cast(kargs.hdim_q / kVectorSize), + static_cast(kVectorSize)))), // D + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; return pad_tensor_view( - v_dram_transposed, + k_dram_2d, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + // Linear K Layout: [NumPages, PageSize, NumHeads, HeadDim] + // Logical View for Pipeline: (TotalSeqK, D) + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_q), + make_tuple(kargs.batch_stride_k, kargs.stride_k, 1), + number{}, + number<1>{}); + + // Merge to (TotalSeqK, D) in a single transform: + // physical (Page, S, D) -> logical (TotalSeqK, D) + auto k_dram_2d = transform_tensor_view( + k_dram_naive, + make_tuple(make_merge_transform( + make_tuple(kargs.num_total_pages, kargs.page_block_size)), + make_pass_through_transform(kargs.hdim_q)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? 
kPadSeqLenK : true; + return pad_tensor_view( + k_dram_2d, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto v_dram = [&]() { + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // Vectorized V Layout: [NumPages, S/kVectorSize, D, kVectorSize] + // Logical View for Pipeline: (D, TotalSeqK) - Transposed for GEMM + + // Define the naive physical view with 4D shape: (NumPages, + // PageBlockSize/kVectorSize, HeadDim, kVectorSize) + // Strides: (BatchStride, HeadDim*kVectorSize, kVectorSize, 1) + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.num_total_pages, + kargs.page_block_size / kVectorSize, + kargs.hdim_v, + kVectorSize), + make_tuple(kargs.batch_stride_v, kargs.hdim_v * kVectorSize, kVectorSize, 1), + number{}, + number<1>{}); + + // Merge to (D, TotalSeqK) in a single transform: + // physical (Page, S/vec, D, vec) -> logical (D, TotalSeqK) + auto v_dram_final = transform_tensor_view( + v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), // D + make_merge_transform(make_tuple(kargs.num_total_pages, + kargs.page_block_size / kVectorSize, + kVectorSize))), // TotalSeqK + make_tuple(sequence<2>{}, sequence<0, 1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; + return pad_tensor_view( + v_dram_final, make_tuple(number{}, number{}), sequence{}); } else { + // Linear V Layout: [NumPages, PageSize, NumHeads, HeadDim] + // Logical View for Pipeline: (D, TotalSeqK) const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size), - make_tuple(kargs.stride_v, 1), + make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_v), + make_tuple(kargs.batch_stride_v, kargs.stride_v, 1), number{}, number<1>{}); - constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; - return pad_tensor_view( + // Merge to (D, TotalSeqK) in a single transform: + // physical (Page, S, D) -> logical (D, TotalSeqK) + auto v_dram_final = transform_tensor_view( v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_merge_transform( + make_tuple(kargs.num_total_pages, kargs.page_block_size))), + make_tuple(sequence<2>{}, sequence<0, 1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; + return pad_tensor_view( + v_dram_final, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }(); - auto q_dram_window = make_tile_window( q_dram, [&]() { @@ -1070,6 +1173,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + const index_t stride_k_for_pipeline = + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT + ? kVectorSize + : kargs.stride_k; + const index_t stride_v_for_pipeline = + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT + ? 
kargs.hdim_v + : kargs.stride_v; + auto o_acc_tile = [&] { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { @@ -1108,9 +1220,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel variant_params, block_indices, smem_ptr, - kargs.kv_page_indices, - kargs.stride_k, - kargs.stride_v, + page_idx, + stride_k_for_pipeline, + stride_v_for_pipeline, + kargs.batch_stride_k, + kargs.batch_stride_v, dropout); } else @@ -1128,9 +1242,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel variant_params, block_indices, smem_ptr, - kargs.kv_page_indices, - kargs.stride_k, - kargs.stride_v, + page_idx, + stride_k_for_pipeline, + stride_v_for_pipeline, + kargs.batch_stride_k, + kargs.batch_stride_v, dropout); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 2102fe768f..0b47441995 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -6,12 +6,82 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { +template +CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec, + const index_t& stride_kv, + const index_t& page_stride_kv, + const CoordVecType& coord_vec, + OffsetVecType& kv_offset_vec, + index_t global_seq_offset = 0) +{ + const index_t& thread_coord_start = coord_vec[kCoordAxis]; + constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; + if constexpr(kIsKcache) + { + // for k offsets + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t page_id = global_token_idx >> kLog2PageSize; + const index_t page_offset = global_token_idx & kInPageOffsetMask; + kv_offset_vec[k0] = static_cast(page_vec[page_id]) * page_stride_kv + + static_cast(page_offset) * stride_kv; + }); + } + else + { + // for v offsets + const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); + const index_t lane0_page_id = + (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; + + const long_index_t page_loc = + static_cast(page_vec[lane0_page_id]) * page_stride_kv; + + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t page_offset = + (global_seq_offset + thread_coord_start + kLoopStart + k0.value) & + kInPageOffsetMask; + + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // Vectorized layout offset + // Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize] + // Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize) + const index_t s = page_offset; + const index_t D = stride_kv; + + const long_index_t s_offset = + static_cast((s / kVectorSize) * (D * kVectorSize)) + + (s % kVectorSize); + + kv_offset_vec[k0] = page_loc + s_offset; + } + else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT + { + kv_offset_vec[k0] = page_loc + 
static_cast(page_offset) * stride_kv; + } + }); + } +} // a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) template {}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - static constexpr auto I3 = number<3>{}; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr index_t kPageBlockSize = Problem::kPageBlockSize; + static constexpr index_t kLog2PageSize = Problem::kLog2PageSize; + static constexpr index_t kVectorSize = Problem::kVectorSize; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(kPageBlockSize % kN0 == 0, + "V offset assumes each tile stays within a page; kPageBlockSize must be " + "divisible by kN0."); static constexpr bool kIsGroupMode = Problem::kIsGroupMode; // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) @@ -68,6 +144,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -196,6 +273,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const index_t stride_k, const index_t stride_v, + const index_t page_stride_k, + const index_t page_stride_v, DropoutType& dropout) const { static_assert( @@ -325,9 +404,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync using KDstrEncode = typename decltype(k_dist)::DstrEncode; constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0]; statically_indexed_array k_offsets; - static_for<0, NRepeat, 1>{}([&](auto n0) { - k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; - }); + index_t current_seq_k = seqlen_k_start; + kv_offset_array_transform, + decltype(k_coord), + 0, + kPageBlockSize, + kLog2PageSize, + 0, + NRepeat, + kN0 / NRepeat, + kKVMemoryLayout, + true, + kVectorSize>( + page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); + auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window.get_window_lengths(), k_dram_block_window.get_window_origin(), @@ -360,10 +450,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync using VDstrEncode = typename decltype(v_dist)::DstrEncode; constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3]; statically_indexed_array v_offsets; - (void)stride_k; - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v; - }); + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + 0, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); auto 
v_dram_window = make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), @@ -425,13 +523,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync async_load_fence(); __builtin_amdgcn_s_barrier(); - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v; - }); - v_dram_window.update_page_idx(v_offsets); - __builtin_amdgcn_sched_barrier(0); { // tail gemm_0( @@ -444,49 +535,67 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } __builtin_amdgcn_sched_barrier(1); - // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); - tile_elementwise_inout( - [&](auto& x, const auto& y) { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - x += type_convert(bias_element_func(y)); -#else - x += log2e_v * - type_convert(bias_element_func(y)); -#endif - }, - s_acc, - bias_tile); - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + kK1, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + v_dram_window.update_page_idx(v_offsets); - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto p = [&]() { + const auto bias_tile = load_tile(bias_dram_window); // load bias tile - s_acc(i_j_idx) *= scale_s; - position_encoding.update(s_acc(i_j_idx), row, col); - }); - }); - } - else - { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - if constexpr(kHasLogitsSoftCap) + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - auto apply_logits_transform = - [&variant, &variant_params, &block_indices](auto& x) { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + 
s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = [&variant, &variant_params, &block_indices]( + auto& x) { x = variant.LogitsTransform(variant_params, variant.QueryTransform(variant_params, x), block_indices.batch_idx, @@ -494,216 +603,229 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync block_indices.kv_head_idx); }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { - apply_logits_transform(s_acc.thread_buf_[i]); - } + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } #else - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { #if(defined(__gfx90a__) || defined(__gfx94__)) && \ (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) - // Avoid data hazard if v_mfma is followed by inline asm consumer - // instructions. In this case, compiler won't add s_nop for us - if(i == s_acc.thread_buf_.size() / 2) - { - __builtin_amdgcn_sched_barrier(0); + // Avoid data hazard if v_mfma is followed by inline asm consumer + // instructions. In this case, compiler won't add s_nop for us + if(i == s_acc.thread_buf_.size() / 2) + { + __builtin_amdgcn_sched_barrier(0); + } +#endif + apply_logits_transform(s_acc.thread_buf_[i]); } #endif - apply_logits_transform(s_acc.thread_buf_[i]); - } -#endif - } - else - { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); -#endif - } - } - move_tile_window(bias_dram_window, {0, kN0}); - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}), - number{}, - number{}); - - if(need_perpixel_check) - { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !variant.LogitsMask(variant_params, - block_indices.batch_idx, - row, - col, - block_indices.qo_head_idx, - block_indices.kv_head_idx); - }); - } - } - - const auto s = cast_tile(s_acc); // S{j} - auto m_local = block_tile_reduce( - s, - sequence<1>{}, - f_max, - -numeric::infinity()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max, bool_constant{}); - - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} - - auto p_compute = make_static_distributed_tensor( - s.get_tile_distribution()); // Pcompute{j} - - __builtin_amdgcn_sched_barrier(0x7F); - // store & prefetch next v, after the max reduction - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_buf); - - auto v_lds_window_tmp = - get_slice_tile(v_lds_window, - 
sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - - store_tile( - v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - auto v_lds_window_tmp = - get_slice_tile(v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_buf)); // store the prefetch - } - - if constexpr(k1_loops > 1) - { - move_tile_window( - v_dram_window, - {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - v_buf = load_tile( - v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = - page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v; - }); - v_dram_window.update_page_idx(v_offsets); - } - __builtin_amdgcn_sched_barrier(0); - - static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration. alibi does not have this problem - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() - ? type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } - }; - - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); -#endif - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { - if constexpr(kHasLogitsSoftCap) +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout([](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, + m, + m_old, + m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + 
shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, + kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + 2 * kK1, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + v_dram_window.update_page_idx(v_offsets); + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } - } #else - p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); #endif + }); }); - }); - auto rowsum_p = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - 
get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } - } - }(); + }(); #else - const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); #endif - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - // FIXME: this use different equation from FA v2 paper, - // but produce correc result. - // Is the equation wrong? - o_acc(i_j_idx) *= tmp; + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this uses a different equation from the FA v2 paper, + // but produces the correct result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); }); - }); - if constexpr(kHasDropout) - { - auto randval_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.template Run( - randval_ptr, - seqlen_k_start + i_total_loops * kN0, - p_compute, - randval_dram_window); - } + if constexpr(kHasDropout) + { + auto randval_ptr = reinterpret_cast(smem_ptr) + + Policy::template GetSmemSizeKV(); + dropout + .template Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } - const auto p = [&]() { #if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN // For fp32 to fp16, // impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue, @@ -727,11 +849,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { v_buf = load_tile( v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 + - v_coord[VPageIndexDim] + k0.value] * - stride_v; - }); + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + (2 + i_k1.value) * kK1, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + v_dram_window.update_page_idx(v_offsets); } block_sync_lds(); @@ -772,14 +901,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync i_total_loops++; if(i_total_loops < num_total_loop) { - page_idx += kN0; + current_seq_k += kN0; // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); - static_for<0, NRepeat, 1>{}([&](auto n0) { - k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; - }); + kv_offset_array_transform, + decltype(k_coord), + 0, + kPageBlockSize, + kLog2PageSize, + 0, + NRepeat, + kN0 / NRepeat, + kKVMemoryLayout, + true, + kVectorSize>( + page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) @@ -887,6 +1025,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const 
index_t stride_k, const index_t stride_v, + const index_t page_stride_k, + const index_t page_stride_v, DropoutType& dropout) const { return operator()(q_dram_block_window_tmp, @@ -913,6 +1053,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync page_idx, stride_k, stride_v, + page_stride_k, + page_stride_v, dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index a192e3f7b0..f9dc94bc65 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" namespace ck_tile { @@ -65,6 +66,71 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasSink = Traits::kHasSink; }; +template +struct BlockFmhaBatchPrefillPipelineProblem + : public BlockFmhaPipelineProblem +{ + static constexpr index_t kPageBlockSize = kPageBlockSize_; + static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive"); + static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0, + "kPageBlockSize must be power of two"); + static constexpr index_t kLog2PageSize = []() constexpr { + index_t shift = 0; + index_t val = kPageBlockSize_; + while(val > 1) + { + val >>= 1; + shift++; + } + return shift; + }(); + + static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4 + static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout; + static constexpr auto kKVLookupTable = Traits_::kKVLookupTable; + static constexpr bool kIsVectorizedLayout = + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + + static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0, + "kQKHeaddim must be divisible by kVectorSize"); + static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0, + "kPageBlockSize must be divisible by kVectorSize for vectorized layout"); + static_assert(kIsGroupMode_, "Batch prefill requires group mode"); +}; + template +struct TileFmhaBatchPrefillTraits : public TileFmhaTraits +{ + static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; + static constexpr auto kKVLookupTable = kKVLookupTable_; + static constexpr index_t kPageBlockSize = kPageBlockSize_; + static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT || + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT, + "Batch prefill only supports vectorized or linear KV cache layout."); + static_assert(kPageBlockSize > 0 && ((kPageBlockSize & (kPageBlockSize - 1)) == 0), + "kPageBlockSize should be a power of 2 to support efficient page-based KV cache " + "addressing."); +}; + template
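To make the paging arithmetic concrete, here is a self-contained sanity-check sketch of the offset computation used by kv_offset_array_transform: the page id comes from a right shift by kLog2PageSize, the in-page offset from a power-of-two mask, and the in-page element offset then depends on the memory layout. All values below (page table, page size, head_dim, vector size, last-page length) are assumed examples, not the kernel's configuration, and the sketch addresses a single head slice for simplicity:

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    // Assumed example configuration, mirroring the constraints asserted above:
    // page size a power of two, head_dim divisible by the vector size.
    constexpr int page_size = 128; // kPageBlockSize
    constexpr int log2_page = 7;   // kLog2PageSize
    constexpr int head_dim  = 64;  // D
    constexpr int kv_vec    = 8;   // kVectorSize = 16 / sizeof(KDataType) for fp16
    static_assert((page_size & (page_size - 1)) == 0, "page size must be a power of two");

    // Hypothetical page table: logical pages 0..3 of this sequence map to
    // these physical pages.
    const std::vector<std::int32_t> page_vec = {5, 2, 9, 0};

    const int token       = 300;                     // global token index in the sequence
    const int page_id     = token >> log2_page;      // 300 / 128 = 2
    const int page_offset = token & (page_size - 1); // 300 % 128 = 44
    assert(page_id == 2 && page_offset == 44);

    // Per-page stride in elements (one head slice): page_size * head_dim.
    const std::int64_t page_stride = std::int64_t(page_size) * head_dim;
    const std::int64_t page_loc    = page_vec[page_id] * page_stride;

    const int d = 10; // head-dim column of the element we address

    // LINEAR_LAYOUT: rows of head_dim elements, one row per token slot.
    const std::int64_t linear_off = page_loc + std::int64_t(page_offset) * head_dim + d;

    // VECTORIZED_LAYOUT (V): in-page shape [page_size / kv_vec, head_dim, kv_vec].
    // Row base for slot s is (s / kv_vec) * (head_dim * kv_vec) + (s % kv_vec);
    // in the kernel the d * kv_vec term comes from the tile-window coordinate.
    const int s = page_offset;
    const std::int64_t vec_off =
        page_loc + std::int64_t(s / kv_vec) * (head_dim * kv_vec) + (s % kv_vec) +
        std::int64_t(d) * kv_vec;

    // SGLang-style seqlen_k reconstruction: (num_pages - 1) * page_size + last_page_len.
    const int num_pages     = static_cast<int>(page_vec.size());
    const int last_page_len = 45; // assumed: tokens used in the last page
    const int seqlen_k      = (num_pages - 1) * page_size + last_page_len; // 3*128 + 45 = 429

    std::printf("page_id=%d page_offset=%d linear=%lld vectorized=%lld seqlen_k=%d\n",
                page_id, page_offset, (long long)linear_off, (long long)vec_off, seqlen_k);
    return 0;
}

Compiled with any C++17 compiler, this prints the resolved page id, in-page offset, the element offsets under the two layouts, and the reconstructed seqlen_k; it also illustrates why kPageBlockSize is required to be a power of two (the shift/mask decomposition only works then).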