Merge commit 'cc75a1dc5f18613af29d8821375f79b0f3c6410b' into develop

This commit is contained in:
assistant-librarian[bot]
2026-01-05 11:13:10 +00:00
parent c5abac7854
commit 85642a59c2
9 changed files with 983 additions and 430 deletions

View File

@@ -11,6 +11,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
* Added FP8 KV cache support for FMHA batch prefill.
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
### Changed

View File

@@ -36,6 +36,19 @@ DTYPE_BITS = {
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
SUPPORTED_PAGE_SIZE = [128, 256, 1024]
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
KV_MEMORY_LAYOUT_ENUM_MAP = {
"vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
"linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
}
KV_LOOKUP_TABLE_ENUM_MAP = {
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
}
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
}
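
As a quick illustration of how these tables are consumed downstream (a sketch assuming the maps above are in scope; `tpl` is a toy stand-in, not the real kernel template string):

tpl = "ck_tile::TileFmhaBatchPrefillTraits<..., {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>"
print(tpl.format(
    F_page_size=128,
    F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP["linear"],
    F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP["sglang"],
))
# -> ck_tile::TileFmhaBatchPrefillTraits<..., 128,
#        ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT,
#        ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>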
@@ -59,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
@@ -69,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_lse},
{F_dropout},
{F_qscale},
{F_occupancy}>;
{F_occupancy},
false,
{F_page_size},
{F_kv_memory_layout},
{F_kv_lookup_table}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
@@ -92,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
fmha_variant_{F_idx},
fmha_mask_{F_idx},
false,
{F_page_size},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
@@ -105,8 +123,8 @@ using fmha_epilogue_{F_idx} =
using fmha_kernel_{F_idx} =
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
#include <iostream>
@@ -184,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
"""
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@@ -230,12 +248,15 @@ class FmhaFwdApiTrait:
dpad: str
dvpad: str
constraint: CppConstraint
kv_memory_layout: str
kv_lookup_table: str
page_size: int = 1 # page block size
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
)
@property
@@ -322,6 +343,8 @@ class FmhaFwdPipeline:
F_dropout: str #
F_qscale: str # no/pertensor
F_mask: str # value from MASK_MAP
F_kv_memory_layout: str #
F_kv_lookup_table: str #
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@property
@@ -382,6 +405,8 @@ class FmhaFwdPipeline:
n += f"_{self.F_qscale}"
else:
n += "_nqscale"
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
return n
@@ -440,6 +465,13 @@ class FmhaFwdApiPool:
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
trait.kv_memory_layout
],
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
trait.kv_lookup_table
],
F_page_size=trait.page_size,
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
@@ -497,6 +529,7 @@ class FmhaFwdKernel:
F_tile: FmhaFwdTileSize
F_pipeline: FmhaFwdPipeline
mask_impl: str
F_page_size: int = 1 # page block size
@property
def template(self) -> str:
@@ -534,17 +567,24 @@ class FmhaFwdKernel:
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
F_occupancy=self.F_tile.F_occupancy,
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
self.F_pipeline.F_kv_memory_layout
],
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
self.F_pipeline.F_kv_lookup_table
],
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
F_page_size=self.F_page_size,
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return (
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
@@ -578,6 +618,9 @@ class FmhaFwdKernel:
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
page_size=self.F_page_size,
)
@@ -604,23 +647,42 @@ class KernelComponentFactory:
pipelines = []
if dtype in ["fp16", "bf16"]:
qscale = "no"
for logits, mask, bias, lse, dropout in itertools.product(
for (
logits,
mask,
bias,
lse,
dropout,
kv_memory_layout,
kv_lookup_table,
) in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
elif dtype in ["fp8bf16"]:
# no need for lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
for (
logits,
qscale,
mask,
bias,
kv_memory_layout,
kv_lookup_table,
) in itertools.product(
["t", "f"],
["pertensor"],
get_mask_map(mask_impl).keys(),
["no"],
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
else:
assert False
return pipelines
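
Note that the two new KV-cache axes multiply the variant count by four. A minimal sketch of the blow-up (the mask and bias lists are hypothetical stand-ins for get_mask_map()/BIAS_MAP):

import itertools

masks = ["no", "causal"]      # stand-in for get_mask_map(mask_impl).keys()
biases = ["no", "alibi"]      # stand-in for BIAS_MAP.keys()
variants = list(itertools.product(
    ["t", "f"],               # logits soft cap
    masks,
    biases,
    ["t", "f"],               # lse
    ["t", "f"],               # dropout
    ["vectorized", "linear"], # SUPPORTED_KV_MEMORY_LAYOUT
    ["vllm", "sglang"],       # SUPPORTED_KV_LOOKUP_TABLE
))
assert len(variants) == 128   # would be 32 without the two KV-cache axes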
@@ -672,69 +734,73 @@ def get_fwd_blobs(
or pipeline.F_logits == "f"
):
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
# Generate kernels for every page size in SUPPORTED_PAGE_SIZE
for page_size in SUPPORTED_PAGE_SIZE:
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
F_page_size=page_size,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16", "fp8bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_qscale == "no"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
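
The tail of the loop above fans each surviving pipeline out over every supported page size, and the chosen page size is baked into the kernel name via the `ps{...}` segment of `FmhaFwdKernel.name`. A toy sketch with hypothetical hdim/dtype/mode values:

SUPPORTED_PAGE_SIZE = [128, 256, 1024]  # mirrors the module-level constant
names = [f"fmha_batch_prefill_d128_bf16_group_ps{ps}" for ps in SUPPORTED_PAGE_SIZE]
assert names[-1].endswith("_ps1024")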

View File

@@ -529,14 +529,25 @@ struct fmha_batch_prefill_args
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
// SGLang-style page table
int32_t num_total_pages;
void* kv_indptr;
void* kv_page_indices;
#if 0 // we assume page_block_size=1 for now
void* kv_last_page_lens;
ck_tile::index_t page_block_size;
#endif
// KV cache page table fields (kv_lookup_table selects interpretation):
// - SGLANG_PAGE_TABLE_1D:
// kv_indptr: prefix-sum [batch+1] into kv_page_indices
// kv_page_indices: 1D list of physical page ids, length = num_total_pages
// kv_last_page_lens: per-batch last page lengths [batch]
// - VLLM_BLOCK_TABLE_2D:
// kv_page_indices: block_table [batch, max_blocks_per_seq] (2D)
// batch_stride_block_table: row stride for block_table
// seqlen_k_ptr: per-batch seqlen_k [batch]
int32_t num_total_pages; // total physical pages in KV cache (SGLang/vLLM)
ck_tile::index_t page_block_size; // tokens per page (SGLang/vLLM)
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum
kv_memory_layout; // KV memory layout (SGLang/vLLM)
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector
void* kv_indptr; // SGLang: prefix-sum; vLLM: unused
void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D
void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused
void* seqlen_k_ptr; // vLLM: per-batch seqlen_k; SGLang: unused
ck_tile::index_t batch_stride_block_table; // vLLM: row stride; SGLang: unused
float scale_s;
float scale_p;
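
A plain-Python sketch of how these fields are interpreted per batch under each kv_lookup_table value (hypothetical data; this mirrors the comment block above and the dispatch in fmha_batch_prefill_create_kargs_and_grids, not the C++ itself):

def sglang_pages_and_seqlen(kv_indptr, kv_page_indices, kv_last_page_lens,
                            page_block_size, b):
    # SGLANG_PAGE_TABLE_1D: indptr is a prefix sum into the 1D page list
    start, end = kv_indptr[b], kv_indptr[b + 1]
    pages = kv_page_indices[start:end]
    num_pages = end - start
    seqlen_k = ((num_pages - 1) * page_block_size + kv_last_page_lens[b]
                if num_pages > 0 else 0)
    return pages, seqlen_k

def vllm_pages_and_seqlen(block_table, batch_stride_block_table, seqlen_k_ptr, b):
    # VLLM_BLOCK_TABLE_2D: the table is row-indexed per batch; seqlen_k is explicit
    row = b * batch_stride_block_table
    pages = block_table[row:row + batch_stride_block_table]
    return pages, seqlen_k_ptr[b]

pages, s = sglang_pages_and_seqlen([0, 2], [7, 3], [5], 128, 0)
assert (pages, s) == ([7, 3], 133)   # one full 128-token page + 5 tokens
pages, s = vllm_pages_and_seqlen([7, 3, 9, 2], 2, [200, 150], 1)
assert (pages, s) == ([9, 2], 150)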
@@ -1113,6 +1124,22 @@ template <typename FmhaKernel>
auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
using PageTableKargs = typename FmhaKernel::PageBlockTableKargs;
const PageTableKargs page_table = [&]() {
if constexpr(FmhaKernel::kKVLookupTable ==
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_indptr),
reinterpret_cast<const int32_t*>(args.kv_page_indices),
reinterpret_cast<const int32_t*>(args.kv_last_page_lens)};
}
else
{
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_page_indices),
args.batch_stride_block_table,
reinterpret_cast<const int32_t*>(args.seqlen_k_ptr)};
}
}();
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
@@ -1133,12 +1160,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
#if 0 // we assume page_block_size=1 for now
args.kv_last_page_lens,
args.page_block_size,
#endif
page_table,
args.scale_s,
args.scale_p,
args.scale_o,
@@ -1184,12 +1207,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
#if 0 // we assume page_block_size=1 for now
args.kv_last_page_lens,
args.page_block_size,
#endif
page_table,
args.scale_s,
args.scale_p,
args.scale_o,
@@ -1281,6 +1300,65 @@ struct fmha_fwd_traits_
static constexpr bool kHasSink = kHasSink_;
};
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
bool kHasLogitsSoftCap_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kHasDropout_,
ck_tile::BlockAttentionQuantScaleEnum QScaleEnum_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kUseTrLoad_,
bool kSkipMinSeqlenQ_ = false,
ck_tile::index_t kPageBlockSize_ = 1,
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
DataType_,
kIsGroupMode_,
kM0_,
kN0_,
kK0_,
kN1_,
kK1_,
kK0BlockLength_,
kIsVLayoutRowMajor_,
FmhaPipelineEnum_,
kHasLogitsSoftCap_,
FmhaMask_,
BiasEnum_,
kStoreLse_,
kHasDropout_,
QScaleEnum_,
kPadS_,
kPadSK_,
kPadD_,
kPadDv_,
kUseTrLoad_,
kSkipMinSeqlenQ_,
false>
{
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
};
template <typename Traits_, typename Arch = void>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
@@ -1527,7 +1605,15 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,
const ck_tile::stream_config&);
using fmha_batch_prefill_traits = fmha_fwd_traits;
struct fmha_batch_prefill_traits : public fmha_fwd_traits
{
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table =
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
int page_size = 1;
};
float fmha_batch_prefill(fmha_batch_prefill_traits,
fmha_batch_prefill_args,
const ck_tile::stream_config&);

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"

View File

@@ -0,0 +1,32 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile {
// KV cache memory layout selector.
//
// Layout summary (kVectorSize = 16 / sizeof(KDataType)):
// - VECTORIZED_LAYOUT (swizzled):
// K: [NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize]
// V: [NumBlocks, NumHeads, PageSize/kVectorSize, HeadDim, kVectorSize]
// - LINEAR_LAYOUT:
// K: [NumBlocks, PageSize, NumHeads, HeadDim]
// V: [NumBlocks, PageSize, NumHeads, HeadDim]
enum class BlockAttentionKVCacheMemoryLayoutEnum
{
VECTORIZED_LAYOUT = 0,
LINEAR_LAYOUT = 1,
};
// KV cache lookup table layout selector.
// - VLLM_BLOCK_TABLE_2D: block_table[batch, max_blocks_per_seq]
// - SGLANG_PAGE_TABLE_1D: kv_page_indices[kv_indptr[b] ... kv_indptr[b+1])
enum class BlockAttentionKVCacheLookupTableEnum
{
VLLM_BLOCK_TABLE_2D = 0,
SGLANG_PAGE_TABLE_1D = 1,
};
} // namespace ck_tile
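
To make the layout comment concrete, a small offset sketch within one page (single head, so the NumHeads stride drops out; `vec` is kVectorSize):

def k_offset_vectorized(s, d, page_size, vec):
    # K: [HeadDim/vec, PageSize, vec] per page -> strides (page_size*vec, vec, 1)
    return (d // vec) * (page_size * vec) + s * vec + (d % vec)

def k_offset_linear(s, d, head_dim):
    # K: [PageSize, HeadDim] per page -> strides (head_dim, 1)
    return s * head_dim + d

vec, page_size, head_dim = 8, 128, 64    # e.g. fp16: vec = 16 / 2 = 8
# vectorized: one 16-byte vector of head-dim elements is contiguous ...
assert k_offset_vectorized(5, 3, page_size, vec) - k_offset_vectorized(5, 2, page_size, vec) == 1
# ... and stepping one token moves by one vector, not a full head_dim
assert k_offset_vectorized(6, 2, page_size, vec) - k_offset_vectorized(5, 2, page_size, vec) == vec
# linear: a whole token's head_dim is contiguous instead
assert k_offset_linear(5, 3, head_dim) - k_offset_linear(5, 2, head_dim) == 1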

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
@@ -56,12 +57,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
static constexpr auto kKVMemoryLayout = FmhaPipeline::Problem::kKVMemoryLayout;
static constexpr auto kKVLookupTable = FmhaPipeline::Problem::kKVLookupTable;
static constexpr index_t kPageBlockSize = FmhaPipeline::kPageBlockSize;
static constexpr index_t kVectorSize = FmhaPipeline::kVectorSize;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template arg
struct FmhaFwdEmptyKargs
@@ -71,6 +75,26 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
// kargs use aggregate initialization, so no constructor is provided
// use inheritance to minimize karg size
// users need to use the MakeKargs() function to create kargs.
struct SglangPageTableKargs
{
const int32_t* kv_indptr;
const int32_t* kv_page_indices;
const int32_t* kv_last_page_lens;
};
struct VllmPageTableKargs
{
const int32_t* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
const int32_t* seqlen_k_ptr;
};
using PageBlockTableKargs =
std::conditional_t<kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
SglangPageTableKargs,
VllmPageTableKargs>;
struct FmhaFwdCommonKargs
{
const void* q_ptr;
@@ -89,14 +113,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t nhead_ratio_qk;
int32_t num_total_pages;
const int32_t* kv_indptr;
const int32_t* kv_page_indices;
#if 0 // we assume page_block_size=1 for now
const int32_t* kv_last_page_lens;
ck_tile::index_t page_block_size;
#else
static constexpr ck_tile::index_t page_block_size = 1;
#endif
PageBlockTableKargs page_table;
float scale_s;
@@ -295,12 +313,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
int32_t num_total_pages,
const void* kv_indptr,
const void* kv_page_indices,
#if 0 // we assume page_block_size=1 for now
const void* kv_last_page_lens,
ck_tile::index_t page_block_size,
#endif
const PageBlockTableKargs& page_table,
float scale_s,
[[maybe_unused]] float scale_p,
[[maybe_unused]] float scale_o,
@@ -345,12 +359,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
num_head_q,
nhead_ratio_qk,
num_total_pages,
reinterpret_cast<const int32_t*>(kv_indptr),
reinterpret_cast<const int32_t*>(kv_page_indices),
#if 0 // we assume page_block_size=1 for now
reinterpret_cast<const int32_t*>(kv_last_page_lens),
page_block_size,
#endif
page_table,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
@@ -453,12 +463,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
int32_t num_total_pages,
const void* kv_indptr,
const void* kv_page_indices,
#if 0 // we assume page_block_size=1 for now
const void* kv_last_page_lens,
ck_tile::index_t page_block_size,
#endif
const PageBlockTableKargs& page_table,
float scale_s,
[[maybe_unused]] float scale_p,
[[maybe_unused]] float scale_o,
@@ -498,12 +504,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
num_head_q,
nhead_ratio_qk,
num_total_pages,
reinterpret_cast<const int32_t*>(kv_indptr),
reinterpret_cast<const int32_t*>(kv_page_indices),
#if 0 // we assume page_block_size=1 for now
reinterpret_cast<const int32_t*>(kv_last_page_lens),
page_block_size,
#endif
page_table,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
@@ -700,10 +702,46 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
#if 0 // we assume page_block_size=1 for now
const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
#endif
const index_t seqlen_k = [&]() {
if constexpr(kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
const int32_t num_page_blocks = page_end - page_start;
const int32_t last_page_len = [&]() {
if constexpr(kPageBlockSize == 1)
return static_cast<int32_t>(kPageBlockSize);
else
return kargs.page_table.kv_last_page_lens[i_batch];
}();
return num_page_blocks > 0
? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
last_page_len)
: 0;
}
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
{
if(kargs.page_table.seqlen_k_ptr != nullptr)
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
else
return kargs.seqlen_k;
}
}();
const int32_t* page_idx = [&]() {
if constexpr(kKVLookupTable ==
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
{
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
}
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
{
return kargs.page_table.block_table_ptr +
static_cast<long_index_t>(i_batch) *
kargs.page_table.batch_stride_block_table;
}
}();
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
@@ -711,8 +749,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
batch_offset_q = query_start * kargs.stride_q;
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;
@@ -737,18 +773,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
return;
}
#if 0 // we assume page_block_size=1 for now
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
#else
kargs.seqlen_k = num_page_blocks;
#endif
kargs.seqlen_k = seqlen_k;
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
kargs.kv_page_indices += kargs.kv_indptr[i_batch];
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
@@ -764,11 +794,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
#if 0 // we assume page_block_size=1 for now
kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
#else
kargs.seqlen_k = num_page_blocks;
#endif
kargs.seqlen_k = seqlen_k;
}
// for simplicity, for batch strides we just modify the pointer
@@ -809,60 +835,137 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
}
}();
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
// Vectorized K Layout: [NumPages, D/kVectorSize, S, kVectorSize]
// Logical View for Pipeline: (TotalSeqK, D)
// Define the naive physical view with 4D shape: (NumPages, HeadDim/kVectorSize,
// PageBlockSize, kVectorSize)
// Strides: (BatchStride, PageBlockSize*kVectorSize, kVectorSize, 1)
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.num_total_pages,
kargs.hdim_q / kVectorSize,
kargs.page_block_size,
kVectorSize),
make_tuple(
kargs.batch_stride_k, kargs.page_block_size * kVectorSize, kVectorSize, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(
make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
make_tuple(sequence<1>{}, sequence<0>{}),
// Merge to (TotalSeqK, D) in a single transform:
// physical (Page, D/vec, S, vec) -> logical (TotalSeqK, D)
auto k_dram_2d = transform_tensor_view(
k_dram_naive,
make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages,
kargs.page_block_size)), // TotalSeqK
make_merge_transform(
make_tuple(static_cast<int32_t>(kargs.hdim_q / kVectorSize),
static_cast<int32_t>(kVectorSize)))), // D
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
v_dram_transposed,
k_dram_2d,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
else
{
// Linear K Layout: [NumPages, PageSize, NumHeads, HeadDim]
// Logical View for Pipeline: (TotalSeqK, D)
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_q),
make_tuple(kargs.batch_stride_k, kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
// Merge to (TotalSeqK, D) in a single transform:
// physical (Page, S, D) -> logical (TotalSeqK, D)
auto k_dram_2d = transform_tensor_view(
k_dram_naive,
make_tuple(make_merge_transform(
make_tuple(kargs.num_total_pages, kargs.page_block_size)),
make_pass_through_transform(kargs.hdim_q)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
k_dram_2d,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
}();
const auto v_dram = [&]() {
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// Vectorized V Layout: [NumPages, S/kVectorSize, D, kVectorSize]
// Logical View for Pipeline: (D, TotalSeqK) - Transposed for GEMM
// Define the naive physical view with 4D shape: (NumPages,
// PageBlockSize/kVectorSize, HeadDim, kVectorSize)
// Strides: (BatchStride, HeadDim*kVectorSize, kVectorSize, 1)
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.num_total_pages,
kargs.page_block_size / kVectorSize,
kargs.hdim_v,
kVectorSize),
make_tuple(kargs.batch_stride_v, kargs.hdim_v * kVectorSize, kVectorSize, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
// Merge to (D, TotalSeqK) in a single transform:
// physical (Page, S/vec, D, vec) -> logical (D, TotalSeqK)
auto v_dram_final = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v), // D
make_merge_transform(make_tuple(kargs.num_total_pages,
kargs.page_block_size / kVectorSize,
kVectorSize))), // TotalSeqK
make_tuple(sequence<2>{}, sequence<0, 1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
v_dram_final,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
else
{
// Linear V Layout: [NumPages, PageSize, NumHeads, HeadDim]
// Logical View for Pipeline: (D, TotalSeqK)
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size),
make_tuple(kargs.stride_v, 1),
make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_v),
make_tuple(kargs.batch_stride_v, kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
return pad_tensor_view(
// Merge to (D, TotalSeqK) in a single transform:
// physical (Page, S, D) -> logical (D, TotalSeqK)
auto v_dram_final = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_merge_transform(
make_tuple(kargs.num_total_pages, kargs.page_block_size))),
make_tuple(sequence<2>{}, sequence<0, 1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
return pad_tensor_view(
v_dram_final,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV_, kPadSeqLenK>{});
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
}();
auto q_dram_window = make_tile_window(
q_dram,
[&]() {
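
A plain-Python check of what the vectorized-K merge transform above produces (hypothetical sizes; single head, densely packed cache so batch_stride_k = head_dim * page_size): logical (s, d) maps to physical 4D index (s // page_size, d // vec, s % page_size, d % vec), matching transform dims (0, 2) -> TotalSeqK and (1, 3) -> D.

num_pages, page_size, head_dim, vec = 4, 16, 8, 4

def vectorized_k_physical_offset(s, d):
    page, sp = divmod(s, page_size)
    dv, dl = divmod(d, vec)
    # strides: (head_dim*page_size, page_size*vec, vec, 1)
    return page * head_dim * page_size + dv * page_size * vec + sp * vec + dl

# fill a flat buffer so value == logical index s*head_dim + d, then check the
# mapping is a bijection and reads back correctly at a sample coordinate
flat = [0] * (num_pages * page_size * head_dim)
for s in range(num_pages * page_size):
    for d in range(head_dim):
        flat[vectorized_k_physical_offset(s, d)] = s * head_dim + d
assert sorted(flat) == list(range(num_pages * page_size * head_dim))
assert flat[vectorized_k_physical_offset(21, 5)] == 21 * head_dim + 5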
@@ -1070,6 +1173,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
const index_t stride_k_for_pipeline =
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
? kVectorSize
: kargs.stride_k;
const index_t stride_v_for_pipeline =
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT
? kargs.hdim_v
: kargs.stride_v;
auto o_acc_tile = [&] {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
@@ -1108,9 +1220,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
variant_params,
block_indices,
smem_ptr,
kargs.kv_page_indices,
kargs.stride_k,
kargs.stride_v,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
}
else
@@ -1128,9 +1242,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
variant_params,
block_indices,
smem_ptr,
kargs.kv_page_indices,
kargs.stride_k,
kargs.stride_v,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
}
}();

View File

@@ -6,12 +6,82 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename OffsetVecType,
typename CoordVecType,
index_t kCoordAxis,
index_t kPageBlockSize,
index_t kLog2PageSize,
index_t kLoopStart,
index_t kLoopCount,
index_t kLoopStride,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
bool kIsKcache,
index_t kVectorSize>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec,
const index_t& stride_kv,
const index_t& page_stride_kv,
const CoordVecType& coord_vec,
OffsetVecType& kv_offset_vec,
index_t global_seq_offset = 0)
{
const index_t& thread_coord_start = coord_vec[kCoordAxis];
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
if constexpr(kIsKcache)
{
// for k offsets
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t page_id = global_token_idx >> kLog2PageSize;
const index_t page_offset = global_token_idx & kInPageOffsetMask;
kv_offset_vec[k0] = static_cast<long_index_t>(page_vec[page_id]) * page_stride_kv +
static_cast<long_index_t>(page_offset) * stride_kv;
});
}
else
{
// for v offsets
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
const index_t lane0_page_id =
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
const long_index_t page_loc =
static_cast<long_index_t>(page_vec[lane0_page_id]) * page_stride_kv;
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t page_offset =
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
kInPageOffsetMask;
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// Vectorized layout offset
// Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize]
// Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize)
const index_t s = page_offset;
const index_t D = stride_kv;
const long_index_t s_offset =
static_cast<long_index_t>((s / kVectorSize) * (D * kVectorSize)) +
(s % kVectorSize);
kv_offset_vec[k0] = page_loc + s_offset;
}
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
{
kv_offset_vec[k0] = page_loc + static_cast<long_index_t>(page_offset) * stride_kv;
}
});
}
}
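
In short, kv_offset_array_transform splits a global token index into a physical page id and an in-page offset with one shift and one mask, which is why kPageBlockSize must be a power of two. A minimal sketch of the K-cache branch with hypothetical values:

log2_page = 7                      # page_block_size = 128
mask = (1 << log2_page) - 1
page_vec = [9, 4, 11]              # physical page ids for this sequence
stride_kv = 64                     # per-token stride within a page
page_stride_kv = 128 * 64          # per-page stride

def k_offset(global_token_idx):
    page_id = global_token_idx >> log2_page    # which logical page
    page_offset = global_token_idx & mask      # token index within that page
    return page_vec[page_id] * page_stride_kv + page_offset * stride_kv

# token 130 lives in logical page 1 (physical page 4), in-page offset 2
assert k_offset(130) == 4 * 128 * 64 + 2 * 64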
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_,
@@ -41,19 +111,25 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
static constexpr index_t kLog2PageSize = Problem::kLog2PageSize;
static constexpr index_t kVectorSize = Problem::kVectorSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static_assert(kPageBlockSize % kN0 == 0,
"V offset assumes each tile stays within a page; kPageBlockSize must be "
"divisible by kN0.");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always supports padding; hdim_q/v support multiples of the vector size (like 8x)
@@ -68,6 +144,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -196,6 +273,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
const index_t page_stride_k,
const index_t page_stride_v,
DropoutType& dropout) const
{
static_assert(
@@ -325,9 +404,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
statically_indexed_array<index_t, NRepeat> k_offsets;
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
});
index_t current_seq_k = seqlen_k_start;
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
kPageBlockSize,
kLog2PageSize,
0,
NRepeat,
kN0 / NRepeat,
kKVMemoryLayout,
true,
kVectorSize>(
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
@@ -360,10 +450,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
statically_indexed_array<index_t, V_KRepeat> v_offsets;
(void)stride_k;
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
});
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
0,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
auto v_dram_window =
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -425,13 +523,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
async_load_fence();
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
});
v_dram_window.update_page_idx(v_offsets);
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(
@@ -444,49 +535,67 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto p = [&]() {
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
if constexpr(kHasLogitsSoftCap)
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
auto apply_logits_transform =
[&variant, &variant_params, &block_indices](auto& x) {
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
if constexpr(kHasLogitsSoftCap)
{
auto apply_logits_transform = [&variant, &variant_params, &block_indices](
auto& x) {
x = variant.LogitsTransform(variant_params,
variant.QueryTransform(variant_params, x),
block_indices.batch_idx,
@@ -494,216 +603,229 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
block_indices.kv_head_idx);
};
#if !CK_TILE_FMHA_FWD_FAST_EXP2
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
apply_logits_transform(s_acc.thread_buf_[i]);
}
#else
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
{
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
(CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
// Avoid data hazard if v_mfma is followed by inline asm consumer
// instructions. In this case, compiler won't add s_nop for us
if(i == s_acc.thread_buf_.size() / 2)
{
__builtin_amdgcn_sched_barrier(0);
// Avoid data hazard if v_mfma is followed by inline asm consumer
// instructions. In this case, compiler won't add s_nop for us
if(i == s_acc.thread_buf_.size() / 2)
{
__builtin_amdgcn_sched_barrier(0);
}
#endif
apply_logits_transform(s_acc.thread_buf_[i]);
}
#endif
apply_logits_transform(s_acc.thread_buf_[i]);
}
#endif
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] =
page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v;
});
v_dram_window.update_page_idx(v_offsets);
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be a materialized mask including -inf values; needs
/// consideration. ALiBi does not have this problem
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
if constexpr(kHasLogitsSoftCap)
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout([](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); },
m,
m_old,
m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0,
kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
2 * kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be a materialized mask including -inf values; needs
/// consideration. ALiBi does not have this problem
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
}
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
if constexpr(kHasLogitsSoftCap)
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
if constexpr(kHasLogitsSoftCap)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}
}
}();
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this uses a different equation from the FA v2 paper,
// but produces the correct result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this uses a different equation from the FA v2 paper,
// but produces the correct result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
});
if constexpr(kHasDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
if constexpr(kHasDropout)
{
auto randval_ptr = reinterpret_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeKV<Problem>();
dropout
.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
const auto p = [&]() {
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
// For fp32 to fp16,
// impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue,
@@ -727,11 +849,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 +
v_coord[VPageIndexDim] + k0.value] *
stride_v;
});
kv_offset_array_transform<statically_indexed_array<index_t, V_KRepeat>,
decltype(v_coord),
VPageIndexDim,
kPageBlockSize,
kLog2PageSize,
(2 + i_k1.value) * kK1,
V_KRepeat,
1,
kKVMemoryLayout,
false,
kVectorSize>(
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
v_dram_window.update_page_idx(v_offsets);
}
block_sync_lds();
@@ -772,14 +901,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
i_total_loops++;
if(i_total_loops < num_total_loop)
{
page_idx += kN0;
current_seq_k += kN0;
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
});
kv_offset_array_transform<statically_indexed_array<index_t, NRepeat>,
decltype(k_coord),
0,
kPageBlockSize,
kLog2PageSize,
0,
NRepeat,
kN0 / NRepeat,
kKVMemoryLayout,
true,
kVectorSize>(
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
@@ -887,6 +1025,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
const index_t page_stride_k,
const index_t page_stride_v,
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
@@ -913,6 +1053,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
page_idx,
stride_k,
stride_v,
page_stride_k,
page_stride_v,
dropout);
}
};

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace ck_tile {
@@ -65,6 +66,71 @@ struct BlockFmhaPipelineProblem
static constexpr bool kHasSink = Traits::kHasSink;
};
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename SaccDataType_,
typename SMPLComputeDataType_,
typename BiasDataType_,
typename RandValOutputDataType_,
typename LSEDataType_,
typename PDataType_,
typename OaccDataType_,
typename ODataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename AttentionVariant_,
typename FmhaMask_,
bool kUseTrLoad_,
int kPageBlockSize_,
typename Traits_>
struct BlockFmhaBatchPrefillPipelineProblem
: public BlockFmhaPipelineProblem<QDataType_,
KDataType_,
VDataType_,
SaccDataType_,
SMPLComputeDataType_,
BiasDataType_,
RandValOutputDataType_,
LSEDataType_,
PDataType_,
OaccDataType_,
ODataType_,
BlockFmhaShape_,
kIsGroupMode_,
AttentionVariant_,
FmhaMask_,
kUseTrLoad_,
Traits_>
{
static constexpr index_t kPageBlockSize = kPageBlockSize_;
static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive");
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
"kPageBlockSize must be power of two");
static constexpr index_t kLog2PageSize = []() constexpr {
index_t shift = 0;
index_t val = kPageBlockSize_;
while(val > 1)
{
val >>= 1;
shift++;
}
return shift;
}();
static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
static constexpr auto kKVLookupTable = Traits_::kKVLookupTable;
static constexpr bool kIsVectorizedLayout =
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0,
"kQKHeaddim must be divisible by kVectorSize");
static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
"kPageBlockSize must be divisible by kVectorSize for vectorized layout");
static_assert(kIsGroupMode_, "Batch prefill requires group mode");
};
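
For reference, the constexpr loop above is just floor(log2(kPageBlockSize)) for a power-of-two page size; the equivalent check-and-compute in Python is a couple of lines (values match SUPPORTED_PAGE_SIZE in the generator script):

def log2_page_size(page_size: int) -> int:
    assert page_size > 0 and (page_size & (page_size - 1)) == 0, \
        "page size must be a power of two"
    return page_size.bit_length() - 1

assert [log2_page_size(p) for p in (128, 256, 1024)] == [7, 8, 10]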
template <typename QDataType_,
typename KDataType_,
typename VDataType_,

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
@@ -40,6 +41,48 @@ struct TileFmhaTraits
static constexpr bool kHasSink = kHasSink_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
bool kPadHeadDimV_ /* padding for hdim_v */,
bool kHasLogitsSoftCap_,
BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kStoreLSE_,
bool kHasDropout_,
BlockAttentionQuantScaleEnum QScaleEnum_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
index_t kPageBlockSize_ = 1,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
kPadSeqLenK_,
kPadHeadDimQ_,
kPadHeadDimV_,
kHasLogitsSoftCap_,
BiasEnum_,
kHasBiasGrad_,
kStoreLSE_,
kHasDropout_,
QScaleEnum_,
kBlockPerCu_,
kSkipMinSeqlenQ_,
false>
{
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;
static constexpr index_t kPageBlockSize = kPageBlockSize_;
static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT ||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT,
"Batch prefill only supports vectorized or linear KV cache layout.");
static_assert(kPageBlockSize > 0 && ((kPageBlockSize & (kPageBlockSize - 1)) == 0),
"kPageBlockSize should be a power of 2 to support efficient page-based KV cache "
"addressing.");
};
template <index_t kPadHeadDimQ_ /* padding for hdim_q */,
index_t kPadHeadDimV_ /* padding for hdim_v */,
BlockAttentionBiasEnum BiasEnum_,