mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Add template argument 'kIsPagedKV' for splitkv kernels
This commit is contained in:
@@ -53,6 +53,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_pagedkv},
|
||||
kHasUnevenSplits,
|
||||
{F_occupancy}>;
|
||||
|
||||
@@ -97,8 +98,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
}};
|
||||
}}
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -225,14 +227,22 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
class FmhaFwdSplitKVApiTrait(FmhaFwdApiTrait):
    """API trait for the split-KV forward kernels.

    Carries every field of ``FmhaFwdApiTrait`` plus the paged-KV flag, and
    appends that flag to the trait name.
    """
    # paged-KV flag; filled from the pipeline's F_pagedkv and looked up in BOOL_MAP
    pagedkv : str

    @property
    def name(self) -> str:
        # Resolve the base name through the instance. Reading `FmhaFwdApiTrait.name`
        # on the class would return the property object itself (when the base
        # `name` is a property), and `property + str` raises TypeError.
        return super().name + f'-{self.pagedkv}'
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVPipeline:
|
||||
tag : str
|
||||
@@ -246,6 +256,7 @@ class FmhaFwdSplitKVPipeline:
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_pagedkv : str # t/f
|
||||
F_mask : str # value from MASK_MAP
|
||||
|
||||
@property
|
||||
@@ -269,6 +280,7 @@ class FmhaFwdSplitKVPipeline:
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
if self.F_pagedkv == 't' : n += '_pagedkv'
|
||||
return n
|
||||
|
||||
@dataclass
|
||||
@@ -300,7 +312,7 @@ class FmhaFwdSplitKVApiPool:
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
@@ -322,8 +334,8 @@ class FmhaFwdSplitKVApiPool:
|
||||
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
|
||||
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
@@ -385,6 +397,7 @@ class FmhaFwdSplitKVKernel:
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
@@ -401,8 +414,8 @@ class FmhaFwdSplitKVKernel:
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
def api_trait(self) -> FmhaFwdSplitKVApiTrait:
|
||||
return FmhaFwdSplitKVApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
@@ -419,6 +432,7 @@ class FmhaFwdSplitKVKernel:
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
pagedkv=self.F_pipeline.F_pagedkv,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
@@ -460,29 +474,6 @@ class FmhaFwdSplitKVCombineKernel:
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0blen=self.F_tile.F_bk0blen,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
@@ -534,26 +525,26 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
# splitkv kernels do not support dropout
|
||||
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]):
|
||||
for mask, bias, lse, dropout, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask))
|
||||
else:
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask))
|
||||
if receipt == 1:
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
# no need lse/dropout/paged-kv kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, 'f', mask))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
@@ -125,6 +125,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert(
|
||||
"rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
|
||||
.insert("rotary_interleaved", "1", "whether to apply interleaved RoPE")
|
||||
.insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe.")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
@@ -257,7 +258,7 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits,
|
||||
const ck_tile::stream_config& config)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(1 < args.num_splits)
|
||||
if(1 < args.num_splits || args.block_table_ptr != nullptr)
|
||||
{
|
||||
return fmha_fwd_splitkv(traits, args, config);
|
||||
}
|
||||
@@ -415,9 +416,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
|
||||
#endif
|
||||
|
||||
int num_splits = arg_parser.get_int("num_splits");
|
||||
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
|
||||
#if !CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(num_splits != 1)
|
||||
{
|
||||
@@ -425,6 +428,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_splits = 1;
|
||||
}
|
||||
#endif
|
||||
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
|
||||
#if !CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
|
||||
<< std::endl;
|
||||
page_block_size = 0;
|
||||
}
|
||||
#endif
|
||||
if(!(page_block_size % 256 == 0))
|
||||
{
|
||||
std::cerr << "only paged-kvcache block size divisible by 256 are currently supported"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
#define ENABLE_PAGED_KVCACHE 0
|
||||
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
@@ -486,6 +506,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
const ck_tile::index_t max_num_blocks =
|
||||
(ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? batch * std::min(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size))
|
||||
: 0);
|
||||
|
||||
// legalize num_splits according to other options
|
||||
if(num_splits < 1)
|
||||
{
|
||||
@@ -520,18 +545,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
|
||||
: seqstart_k_with_padding_host.back()));
|
||||
|
||||
std::cerr << "[POYENC] num_blocks: " << max_num_blocks << std::endl;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_q)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
|
||||
ck_tile::HostTensor<KDataType> knew_host(
|
||||
0 < seqlen_knew
|
||||
? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<VDataType> v_host(
|
||||
is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
|
||||
ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? (is_v_rowmajor
|
||||
? get_lengths(i_perm, max_num_blocks, nhead_k, page_block_size, hdim_v)
|
||||
: get_lengths(i_perm, max_num_blocks, nhead_k, hdim_v, page_block_size))
|
||||
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
|
||||
ck_tile::HostTensor<VDataType> vnew_host(
|
||||
0 < seqlen_knew
|
||||
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
|
||||
@@ -571,6 +604,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
|
||||
ck_tile::HostTensor<int32_t> block_table_host(
|
||||
ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? std::array<ck_tile::index_t, 2>{batch, max_num_blocks / batch}
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
|
||||
@@ -648,6 +686,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
}
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
|
||||
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
@@ -667,6 +706,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -682,6 +722,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
block_table_buf.ToDevice(block_table_host.data());
|
||||
|
||||
// clang-format off
|
||||
auto layout_str = [&](bool permute){
|
||||
@@ -842,7 +883,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
|
||||
return ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? (i_perm ? page_block_size : nhead_k * page_block_size)
|
||||
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
||||
@@ -850,12 +893,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? (i_perm ? page_block_size * hdim_q : hdim_q)
|
||||
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
|
||||
return ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? (i_perm ? page_block_size * hdim_v : hdim_v)
|
||||
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
|
||||
else
|
||||
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
|
||||
return ENABLE_PAGED_KVCACHE && 0 < page_block_size
|
||||
? (i_perm ? hdim_v * page_block_size : page_block_size)
|
||||
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
@@ -865,78 +914,86 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
(ENABLE_PAGED_KVCACHE && 0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
|
||||
: (nhead_k * shape_seqlen_k * hdim_q));
|
||||
const ck_tile::index_t batch_stride_v =
|
||||
(ENABLE_PAGED_KVCACHE && 0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
|
||||
: (nhead_k * hdim_v * shape_seqlen_k));
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_blocks / batch);
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
|
||||
|
||||
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
randval_buf.GetDeviceBuffer(),
|
||||
lse_acc_buf.GetDeviceBuffer(),
|
||||
o_acc_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(),
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
num_splits,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
|
||||
: stride_bias,
|
||||
stride_randval,
|
||||
stride_o_acc,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
batch_stride_o,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_drop,
|
||||
s_randval,
|
||||
{drop_seed, drop_offset}};
|
||||
return fmha_fwd_args{
|
||||
q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
randval_buf.GetDeviceBuffer(),
|
||||
1 < num_splits ? lse_acc_buf.GetDeviceBuffer() : nullptr,
|
||||
1 < num_splits ? o_acc_buf.GetDeviceBuffer() : nullptr,
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(),
|
||||
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr,
|
||||
batch_stride_block_table, // only used if 'block_table_ptr' is not nullptr
|
||||
page_block_size, // only used if 'block_table_ptr' is not nullptr
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
num_splits, // only used in splitkv kernel
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias,
|
||||
stride_randval,
|
||||
stride_o_acc,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
batch_stride_o,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_drop,
|
||||
s_randval,
|
||||
{drop_seed, drop_offset}};
|
||||
}();
|
||||
|
||||
const float fwd_ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
|
||||
@@ -1018,8 +1075,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
#endif
|
||||
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
if (ENABLE_PAGED_KVCACHE && 0 < page_block_size) {
|
||||
if(i_perm) {
|
||||
k_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]);
|
||||
});
|
||||
} else {
|
||||
k_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
}
|
||||
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
// copy Knew to the end of K
|
||||
@@ -1058,16 +1127,40 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
if (is_v_rowmajor) {
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
|
||||
}
|
||||
else {
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); });
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); });
|
||||
if (ENABLE_PAGED_KVCACHE && 0 < page_block_size) {
|
||||
if (is_v_rowmajor) {
|
||||
if(i_perm) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
|
||||
});
|
||||
} else {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]);
|
||||
});
|
||||
}
|
||||
}
|
||||
else {
|
||||
if(i_perm) {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size);
|
||||
});
|
||||
} else {
|
||||
v_host_ref.ForEach([&](auto& self, auto i) {
|
||||
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size);
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (is_v_rowmajor) {
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
|
||||
}
|
||||
else {
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); });
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); });
|
||||
}
|
||||
}
|
||||
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
|
||||
@@ -103,6 +103,9 @@ struct fmha_fwd_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table;
|
||||
ck_tile::index_t page_block_size;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
@@ -315,6 +318,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
@@ -359,6 +365,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
@@ -585,6 +594,51 @@ struct fmha_fwd_traits_
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
template <ck_tile::index_t HDim,
|
||||
typename DataType,
|
||||
bool kIsGroupMode,
|
||||
ck_tile::index_t kM0,
|
||||
ck_tile::index_t kN0,
|
||||
ck_tile::index_t kK0,
|
||||
ck_tile::index_t kN1,
|
||||
ck_tile::index_t kK1,
|
||||
ck_tile::index_t kK0BlockLength,
|
||||
bool kIsVLayoutRowMajor,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum,
|
||||
typename FmhaMask,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum,
|
||||
bool kStoreLse,
|
||||
bool kHasDropout,
|
||||
bool kDoFp8StaticQuant,
|
||||
bool kIsPagedKV_,
|
||||
bool kPadS,
|
||||
bool kPadSK,
|
||||
bool kPadD,
|
||||
bool kPadDv>
|
||||
struct fmha_fwd_splitkv_traits_ : fmha_fwd_traits_<HDim,
|
||||
DataType,
|
||||
kIsGroupMode,
|
||||
kM0,
|
||||
kN0,
|
||||
kK0,
|
||||
kN1,
|
||||
kK1,
|
||||
kK0BlockLength,
|
||||
kIsVLayoutRowMajor,
|
||||
FmhaPipelineEnum,
|
||||
FmhaMask,
|
||||
BiasEnum,
|
||||
kStoreLse,
|
||||
kHasDropout,
|
||||
kDoFp8StaticQuant,
|
||||
kPadS,
|
||||
kPadSK,
|
||||
kPadD,
|
||||
kPadDv>
|
||||
{
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
|
||||
@@ -3,16 +3,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
|
||||
@@ -209,3 +210,15 @@ auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
|
||||
std::uniform_int_distribution<Int> dist(low, high);
|
||||
return dist(engine);
|
||||
}
|
||||
|
||||
// Fills [first, last) with consecutive values starting at `value`
// (std::iota), then shuffles the range in place with a std::mt19937 engine.
// When `seed` holds a value the engine is seeded with it; otherwise the seed
// is drawn from std::random_device. Only participates in overload resolution
// for integral `Int`.
template <typename RandomAccessIterator, typename Int>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
                                                       RandomAccessIterator last,
                                                       Int value,
                                                       std::optional<unsigned> seed = std::nullopt)
{
    std::iota(first, last, value);

    // std::random_device is only touched when no explicit seed was supplied
    unsigned engine_seed;
    if(seed.has_value())
    {
        engine_seed = *seed;
    }
    else
    {
        engine_seed = std::random_device{}();
    }

    std::mt19937 rng(engine_seed);
    std::shuffle(first, last, rng);
}
|
||||
|
||||
@@ -48,6 +48,7 @@ struct FmhaFwdSplitKVKernel
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
@@ -122,6 +123,9 @@ struct FmhaFwdSplitKVKernel
|
||||
// if this param is larger than 1, indicate MQA/GQA case
|
||||
ck_tile::index_t nhead_ratio_qk;
|
||||
ck_tile::index_t num_splits;
|
||||
const void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table;
|
||||
ck_tile::index_t page_block_size;
|
||||
|
||||
float scale_s;
|
||||
|
||||
@@ -254,6 +258,9 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
ck_tile::index_t num_splits,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
ck_tile::index_t stride_q,
|
||||
@@ -299,6 +306,9 @@ struct FmhaFwdSplitKVKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
num_splits,
|
||||
block_table_ptr,
|
||||
batch_stride_block_table,
|
||||
page_block_size,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
@@ -379,6 +389,9 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
ck_tile::index_t num_splits,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
ck_tile::index_t stride_q,
|
||||
@@ -419,6 +432,9 @@ struct FmhaFwdSplitKVKernel
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
num_splits,
|
||||
block_table_ptr,
|
||||
batch_stride_block_table,
|
||||
page_block_size,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static_cast<float>(scale_s * ck_tile::log2e_v<>),
|
||||
#else
|
||||
@@ -565,8 +581,24 @@ struct FmhaFwdSplitKVKernel
|
||||
else
|
||||
{
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
|
||||
if(true || kargs.block_table_ptr == nullptr)
|
||||
{
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto* block_table = reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
|
||||
i_batch * kargs.batch_stride_block_table;
|
||||
const auto i_block =
|
||||
static_cast<long_index_t>(block_table[i_n1 / kargs.page_block_size]);
|
||||
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
// batch_offset_k = i_block * kargs.batch_stride_k;
|
||||
batch_offset_v = i_block * kargs.batch_stride_v;
|
||||
}
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
|
||||
@@ -51,6 +51,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = true; // always store LSE (acc)
|
||||
static constexpr bool kHasDropout = false; // ignore this flag
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
|
||||
@@ -56,6 +56,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = true; // always store LSE (acc)
|
||||
static constexpr bool kHasDropout = false; // ignore this flag
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
|
||||
@@ -85,6 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
|
||||
FmhaMask,
|
||||
Traits>
|
||||
{
|
||||
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
|
||||
};
|
||||
|
||||
|
||||
@@ -42,8 +42,9 @@ template <bool kPadSeqLenQ /* padding for seqlen_q */,
|
||||
bool kStoreLSE,
|
||||
bool kHasDropout,
|
||||
bool kDoFp8StaticQuant,
|
||||
bool kHasUnevenSplits_ = true,
|
||||
index_t kBlockPerCu = -1 /* overwrite occupancy if not -1 */>
|
||||
bool kIsPagedKV_,
|
||||
bool kHasUnevenSplits_,
|
||||
index_t kBlockPerCu = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ,
|
||||
kPadSeqLenK,
|
||||
kPadHeadDimQ,
|
||||
@@ -55,6 +56,7 @@ struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ,
|
||||
kDoFp8StaticQuant,
|
||||
kBlockPerCu>
|
||||
{
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
// determine if some split (length) is not divisible by tile size
|
||||
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user