mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Remove group mode from appendkv kernel
This commit is contained in:
@@ -41,7 +41,6 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl
|
||||
{F_bd},
|
||||
{F_bdv},
|
||||
{F_vlayout},
|
||||
{F_mode},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
|
||||
@@ -51,7 +50,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
|
||||
fmha_pipeline_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
|
||||
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
|
||||
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
|
||||
|
||||
#include <iostream>
|
||||
@@ -78,10 +77,10 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) &&
|
||||
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
|
||||
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
|
||||
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
|
||||
return fmha_fwd_appendkv_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -91,7 +90,6 @@ class FmhaFwdAppendKVApiTrait:
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bs : int # tile size along q seqlen
|
||||
bsk : int # tile size along k seqlen
|
||||
bd : int # tile size along qk gemm unroll
|
||||
@@ -106,13 +104,11 @@ class FmhaFwdAppendKVApiTrait:
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-'+\
|
||||
f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-'+\
|
||||
f'{self.pagedkv}'
|
||||
return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\
|
||||
f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
|
||||
else : return f'a.seqlen_q % {self.bs} == 0'
|
||||
|
||||
@@ -183,7 +179,7 @@ class FmhaFwdAppendKVApiPool:
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
@@ -210,7 +206,6 @@ class FmhaFwdAppendKVKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdAppendKVTileSize
|
||||
F_pipeline : FmhaFwdAppendKVPipeline
|
||||
mask_impl : str
|
||||
@@ -234,13 +229,12 @@ class FmhaFwdAppendKVKernel:
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
|
||||
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_mode = MODE_MAP[self.F_mode])
|
||||
F_occupancy = self.F_tile.F_occupancy)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
@@ -251,7 +245,6 @@ class FmhaFwdAppendKVKernel:
|
||||
return FmhaFwdAppendKVApiTrait(
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bs=self.F_tile.F_bs,
|
||||
bsk=self.F_tile.F_bsk,
|
||||
bd=self.F_tile.F_bd,
|
||||
@@ -320,19 +313,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
for hdim_str in d.keys():
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
k = FmhaFwdAppendKVKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
|
||||
@@ -258,7 +258,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
@@ -278,7 +278,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
|
||||
#if !CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
#if !CK_TILE_FMHA_FWD_APPENDKV_API || !CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(seqlen_knew != 0)
|
||||
{
|
||||
std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
|
||||
@@ -290,6 +290,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seqlen_knew = randint<ck_tile::index_t>(1, arg_parser.get_int("s"), seed);
|
||||
}
|
||||
|
||||
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
|
||||
#if !CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
|
||||
<< std::endl;
|
||||
page_block_size = 0;
|
||||
}
|
||||
#endif
|
||||
if(!(page_block_size % 128 == 0))
|
||||
{
|
||||
std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
if ((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch) {
|
||||
std::cerr << "kvcache enabled. ignoring the 'mode' option"
|
||||
<< std::endl;
|
||||
mode = mode_enum::batch;
|
||||
}
|
||||
|
||||
auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
|
||||
batch,
|
||||
arg_parser.get_str("s"),
|
||||
@@ -420,21 +443,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_splits = 1;
|
||||
}
|
||||
#endif
|
||||
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
|
||||
#if !CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
|
||||
<< std::endl;
|
||||
page_block_size = 0;
|
||||
}
|
||||
#endif
|
||||
if(!(page_block_size % 128 == 0))
|
||||
{
|
||||
std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
@@ -581,11 +589,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
generate_rotary_cos_sin<KDataType>(shape_seqlen_k, rotary_dim, seed);
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 < num_splits || 0 < page_block_size
|
||||
1 < num_splits || 0 < seqlen_knew || 0 < page_block_size
|
||||
? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
||||
1 < num_splits || 0 < page_block_size
|
||||
1 < num_splits || 0 < seqlen_knew || 0 < page_block_size
|
||||
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
@@ -762,7 +770,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
traits.hdim_q = hdim_q;
|
||||
traits.hdim_v = hdim_v;
|
||||
traits.data_type = data_type;
|
||||
traits.is_group_mode = (mode == mode_enum::group);
|
||||
traits.is_v_rowmajor = is_v_rowmajor;
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_appendkv_traits, std::decay_t<decltype(traits)>>)
|
||||
@@ -773,6 +780,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
else // fmha_fwd_traits or fmha_splitkv_traits
|
||||
{
|
||||
traits.is_group_mode = (mode == mode_enum::group);
|
||||
traits.mask_type = mask.type;
|
||||
traits.bias_type = bias.type;
|
||||
traits.has_lse = lse;
|
||||
@@ -863,12 +871,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer();
|
||||
args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer();
|
||||
|
||||
args.seqlen_q = shape_seqlen_q;
|
||||
args.batch = batch;
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
args.seqlen_q = shape_seqlen_q;
|
||||
args.hdim_q = hdim_q;
|
||||
args.hdim_v = hdim_v;
|
||||
args.nhead_q = nhead;
|
||||
@@ -891,7 +895,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.seqlen_knew = seqlen_knew;
|
||||
|
||||
args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
|
||||
args.seqlen_k = shape_seqlen_k - seqlen_knew; // kvcache seqlen for batch mode
|
||||
|
||||
args.rotary_cos_ptr = rotary_cos_buf.GetDeviceBuffer();
|
||||
args.rotary_sin_ptr = rotary_sin_buf.GetDeviceBuffer();
|
||||
@@ -916,8 +919,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.lse_ptr = lse_buf.GetDeviceBuffer();
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer();
|
||||
args.seqlen_k = shape_seqlen_k;
|
||||
args.seqstart_q_ptr = (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
|
||||
|
||||
args.seqlen_k = (args.seqlen_k_ptr == nullptr ? shape_seqlen_k : -1);
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
args.scale_s = scale_s;
|
||||
args.scale_p = scale_p;
|
||||
@@ -990,7 +997,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
const float fwd_ave_time = [&] {
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(1 < num_splits || 0 < page_block_size)
|
||||
if(1 < num_splits || 0 < seqlen_knew || 0 < page_block_size)
|
||||
{
|
||||
fmha_fwd_splitkv_traits fmha_splitkv_traits;
|
||||
init_traits(fmha_splitkv_traits);
|
||||
|
||||
@@ -165,10 +165,11 @@ struct fmha_fwd_splitkv_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void*
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr, or
|
||||
// kvcache is used
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t seqlen_k; // only used if 'seqlen_k_ptr' is nullptr
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
@@ -219,16 +220,11 @@ struct fmha_fwd_appendkv_args
|
||||
void* v_ptr;
|
||||
const void* vnew_ptr;
|
||||
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void*
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
|
||||
const void* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t seqlen_knew;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
@@ -371,7 +367,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.batch,
|
||||
args.max_seqlen_q,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
@@ -415,9 +410,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.batch,
|
||||
args.max_seqlen_q,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
@@ -525,53 +520,13 @@ template <typename Kernel>
|
||||
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(Kernel::kIsGroupMode)
|
||||
{
|
||||
return Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.knew_ptr,
|
||||
args.v_ptr,
|
||||
args.vnew_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.seqlen_knew,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.rotary_cos_ptr,
|
||||
args.rotary_sin_ptr,
|
||||
args.rotary_dim,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_knew,
|
||||
args.stride_v,
|
||||
args.stride_vnew,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_knew,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_vnew,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_knew,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_vnew);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return Kernel::MakeKargs(args.q_ptr,
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.knew_ptr,
|
||||
args.v_ptr,
|
||||
args.vnew_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.seqlen_k_ptr,
|
||||
args.seqlen_knew,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
@@ -598,10 +553,8 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
args.batch_stride_knew,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_vnew);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.seqlen_knew);
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
|
||||
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
@@ -735,7 +688,6 @@ std::string fmha_fwd_splitkv_combine_get_name_();
|
||||
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kTileSizeS_,
|
||||
ck_tile::index_t kTileSizeSk_,
|
||||
ck_tile::index_t kTileSizeD_,
|
||||
@@ -751,7 +703,6 @@ struct fmha_fwd_appendkv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
|
||||
static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
|
||||
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
|
||||
@@ -807,7 +758,6 @@ struct fmha_fwd_appendkv_traits
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
rope_enum rope_type;
|
||||
};
|
||||
|
||||
@@ -26,7 +26,6 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
@@ -58,8 +57,7 @@ struct FmhaFwdAppendKVKernel
|
||||
if (kPadHeadDimV) n += "dv";
|
||||
return n.empty() ? n : std::string("p") + n; }();
|
||||
return
|
||||
_SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) +
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_"
|
||||
_SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) + "_"
|
||||
"b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
|
||||
_TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
|
||||
@@ -79,7 +77,7 @@ struct FmhaFwdAppendKVKernel
|
||||
// kargs use aggregate initializer, so no constructor will be provided
|
||||
// use inheritance to minimize karg size
|
||||
// users need to use the MakeKargs() function to create kargs.
|
||||
struct CommonKargs
|
||||
struct BasicKargs
|
||||
{
|
||||
void* q_ptr;
|
||||
void* k_ptr;
|
||||
@@ -87,6 +85,8 @@ struct FmhaFwdAppendKVKernel
|
||||
void* v_ptr;
|
||||
const void* vnew_ptr;
|
||||
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t seqlen_knew;
|
||||
@@ -114,47 +114,32 @@ struct FmhaFwdAppendKVKernel
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_vnew;
|
||||
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_knew;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_vnew;
|
||||
};
|
||||
|
||||
struct CommonRoPEKargs
|
||||
struct RoPEKargs
|
||||
{
|
||||
const void* rotary_cos_ptr;
|
||||
const void* rotary_sin_ptr;
|
||||
ck_tile::index_t rotary_dim;
|
||||
};
|
||||
|
||||
struct BatchModeKargs : CommonKargs,
|
||||
std::conditional_t<kApplyRoPE, CommonRoPEKargs, EmptyKargs<0>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
};
|
||||
struct Kargs : BasicKargs,
|
||||
std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>
|
||||
{};
|
||||
|
||||
struct GroupModeKargs : CommonKargs,
|
||||
std::conditional_t<kApplyRoPE, CommonRoPEKargs, EmptyKargs<0>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
__host__ static constexpr Kargs
|
||||
MakeKargs(void* q_ptr,
|
||||
void* k_ptr,
|
||||
const void* knew_ptr,
|
||||
void* v_ptr,
|
||||
const void* vnew_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t seqlen_knew,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
@@ -187,8 +172,9 @@ struct FmhaFwdAppendKVKernel
|
||||
knew_ptr,
|
||||
v_ptr,
|
||||
vnew_ptr,
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
-1, // seqlen_k will be updated by content of seqlen_k_ptr
|
||||
seqlen_knew,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
@@ -207,92 +193,13 @@ struct FmhaFwdAppendKVKernel
|
||||
nhead_stride_knew,
|
||||
nhead_stride_v,
|
||||
nhead_stride_vnew,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_knew,
|
||||
batch_stride_v,
|
||||
batch_stride_vnew}, // args for common karg
|
||||
{}, // placeholder for rope
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v};
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
kargs.rotary_cos_ptr = rotary_cos_ptr;
|
||||
kargs.rotary_sin_ptr = rotary_sin_ptr;
|
||||
kargs.rotary_dim = rotary_dim;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(void* q_ptr,
|
||||
void* k_ptr,
|
||||
const void* knew_ptr,
|
||||
void* v_ptr,
|
||||
const void* vnew_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t seqlen_knew,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
const void* rotary_cos_ptr,
|
||||
const void* rotary_sin_ptr,
|
||||
ck_tile::index_t rotary_dim,
|
||||
const void* block_table_ptr,
|
||||
ck_tile::index_t batch_stride_block_table,
|
||||
ck_tile::index_t page_block_size,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_knew,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_vnew,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_knew,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_vnew,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_knew,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_vnew)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
knew_ptr,
|
||||
v_ptr,
|
||||
vnew_ptr,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
-1, //
|
||||
seqlen_knew,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
block_table_ptr,
|
||||
batch_stride_block_table,
|
||||
page_block_size,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_knew,
|
||||
stride_v,
|
||||
stride_vnew,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_knew,
|
||||
nhead_stride_v,
|
||||
nhead_stride_vnew,
|
||||
batch_stride_knew,
|
||||
batch_stride_vnew}, // args for common karg
|
||||
{}, // placeholder for rope
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
batch_stride_k,
|
||||
batch_stride_v};
|
||||
{} // placeholder for rope
|
||||
};
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
@@ -322,51 +229,15 @@ struct FmhaFwdAppendKVKernel
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
|
||||
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_knew =
|
||||
const long_index_t batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
const long_index_t batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
const long_index_t batch_offset_knew =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_vnew =
|
||||
const long_index_t batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
const long_index_t batch_offset_vnew =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
|
||||
if(kargs.seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
|
||||
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
}
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
|
||||
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
|
||||
@@ -19,11 +19,11 @@ struct FmhaFwdAppendKVTilePartitioner
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_knew)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(std::max(ck_tile::integer_divide_ceil(max_seqlen_q, kM0),
|
||||
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0),
|
||||
ck_tile::integer_divide_ceil(seqlen_knew, kN0)),
|
||||
nhead,
|
||||
batch_size);
|
||||
|
||||
@@ -108,7 +108,6 @@ struct FmhaFwdSplitKVKernel
|
||||
void* o_acc_ptr;
|
||||
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -186,6 +185,8 @@ struct FmhaFwdSplitKVKernel
|
||||
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
|
||||
{
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
@@ -220,9 +221,9 @@ struct FmhaFwdSplitKVKernel
|
||||
void* lse_acc_ptr,
|
||||
void* o_acc_ptr,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
|
||||
const void* seqlen_k_ptr, // only used for (paged-) kvcache
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -262,7 +263,6 @@ struct FmhaFwdSplitKVKernel
|
||||
lse_acc_ptr,
|
||||
o_acc_ptr,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
@@ -294,6 +294,7 @@ struct FmhaFwdSplitKVKernel
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v};
|
||||
@@ -333,7 +334,6 @@ struct FmhaFwdSplitKVKernel
|
||||
void* lse_acc_ptr,
|
||||
void* o_acc_ptr,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
@@ -374,9 +374,8 @@ struct FmhaFwdSplitKVKernel
|
||||
lse_acc_ptr,
|
||||
o_acc_ptr,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
-1, //
|
||||
-1, // seqlen_q will be updated by another pointer
|
||||
-1, // seqlen_k will be updated by another pointer
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -496,8 +495,7 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
@@ -512,8 +510,7 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
|
||||
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -526,6 +523,11 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
}
|
||||
|
||||
if(kargs.seqlen_k_ptr != nullptr)
|
||||
{
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
|
||||
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
|
||||
|
||||
@@ -27,7 +27,6 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
static constexpr index_t kK0 = Problem::kK0;
|
||||
static constexpr index_t kN1 = Problem::kN1;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
|
||||
@@ -139,7 +139,6 @@ template <typename QDataType_,
|
||||
index_t kK0_,
|
||||
index_t kN1_,
|
||||
bool IsVLayoutRowMajor_,
|
||||
bool kIsGroupMode_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaFwdAppendKVPipelineProblem
|
||||
{
|
||||
@@ -149,7 +148,6 @@ struct BlockFmhaFwdAppendKVPipelineProblem
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
static constexpr index_t kM0 = kM0_;
|
||||
static constexpr index_t kN0 = kN0_;
|
||||
|
||||
Reference in New Issue
Block a user