Remove group mode from appendkv kernel

This commit is contained in:
PoYen, Chen
2024-08-16 10:04:48 +00:00
parent 9de0f35ebc
commit 5805f5aa73
8 changed files with 95 additions and 281 deletions

View File

@@ -41,7 +41,6 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl
{F_bd},
{F_bdv},
{F_vlayout},
{F_mode},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
@@ -51,7 +50,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
#include <iostream>
@@ -78,10 +77,10 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co
}}
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) &&
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
return fmha_fwd_appendkv_<trait_>(s, a);
}}
"""
@@ -91,7 +90,6 @@ class FmhaFwdAppendKVApiTrait:
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bs : int # tile size along q seqlen
bsk : int # tile size along k seqlen
bd : int # tile size along qk gemm unroll
@@ -106,13 +104,11 @@ class FmhaFwdAppendKVApiTrait:
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-'+\
f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-'+\
f'{self.pagedkv}'
return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\
f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
else : return f'a.seqlen_q % {self.bs} == 0'
@@ -183,7 +179,7 @@ class FmhaFwdAppendKVApiPool:
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
@@ -210,7 +206,6 @@ class FmhaFwdAppendKVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdAppendKVTileSize
F_pipeline : FmhaFwdAppendKVPipeline
mask_impl : str
@@ -234,13 +229,12 @@ class FmhaFwdAppendKVKernel:
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy = self.F_tile.F_occupancy,
F_mode = MODE_MAP[self.F_mode])
F_occupancy = self.F_tile.F_occupancy)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
@@ -251,7 +245,6 @@ class FmhaFwdAppendKVKernel:
return FmhaFwdAppendKVApiTrait(
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bs=self.F_tile.F_bs,
bsk=self.F_tile.F_bsk,
bd=self.F_tile.F_bd,
@@ -320,19 +313,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
for hdim_str in d.keys():
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
k = FmhaFwdAppendKVKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)

View File

@@ -258,7 +258,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
@@ -278,7 +278,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
#if !CK_TILE_FMHA_FWD_APPENDKV_API
#if !CK_TILE_FMHA_FWD_APPENDKV_API || !CK_TILE_FMHA_FWD_SPLITKV_API
if(seqlen_knew != 0)
{
std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
@@ -290,6 +290,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
seqlen_knew = randint<ck_tile::index_t>(1, arg_parser.get_int("s"), seed);
}
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
#if !CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < page_block_size)
{
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
<< std::endl;
page_block_size = 0;
}
#endif
if(!(page_block_size % 128 == 0))
{
std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
<< std::endl;
return false;
}
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
if ((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch) {
std::cerr << "kvcache enabled. ignoring the 'mode' option"
<< std::endl;
mode = mode_enum::batch;
}
auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
batch,
arg_parser.get_str("s"),
@@ -420,21 +443,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_splits = 1;
}
#endif
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
#if !CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < page_block_size)
{
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
<< std::endl;
page_block_size = 0;
}
#endif
if(!(page_block_size % 128 == 0))
{
std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
<< std::endl;
return false;
}
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
@@ -581,11 +589,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
generate_rotary_cos_sin<KDataType>(shape_seqlen_k, rotary_dim, seed);
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits || 0 < page_block_size
1 < num_splits || 0 < seqlen_knew || 0 < page_block_size
? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits || 0 < page_block_size
1 < num_splits || 0 < seqlen_knew || 0 < page_block_size
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
@@ -762,7 +770,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
traits.hdim_q = hdim_q;
traits.hdim_v = hdim_v;
traits.data_type = data_type;
traits.is_group_mode = (mode == mode_enum::group);
traits.is_v_rowmajor = is_v_rowmajor;
if constexpr(std::is_same_v<fmha_fwd_appendkv_traits, std::decay_t<decltype(traits)>>)
@@ -773,6 +780,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else // fmha_fwd_traits or fmha_splitkv_traits
{
traits.is_group_mode = (mode == mode_enum::group);
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_lse = lse;
@@ -863,12 +871,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.k_ptr = k_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();
args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer();
args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer();
args.seqlen_q = shape_seqlen_q;
args.batch = batch;
args.max_seqlen_q = max_seqlen_q;
args.seqlen_q = shape_seqlen_q;
args.hdim_q = hdim_q;
args.hdim_v = hdim_v;
args.nhead_q = nhead;
@@ -891,7 +895,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.seqlen_knew = seqlen_knew;
args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
args.seqlen_k = shape_seqlen_k - seqlen_knew; // kvcache seqlen for batch mode
args.rotary_cos_ptr = rotary_cos_buf.GetDeviceBuffer();
args.rotary_sin_ptr = rotary_sin_buf.GetDeviceBuffer();
@@ -916,8 +919,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.lse_ptr = lse_buf.GetDeviceBuffer();
args.o_ptr = o_buf.GetDeviceBuffer();
args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer();
args.seqlen_k = shape_seqlen_k;
args.seqstart_q_ptr = (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr = (0 < seqlen_knew || 0 < page_block_size || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
args.seqlen_k = (args.seqlen_k_ptr == nullptr ? shape_seqlen_k : -1);
args.max_seqlen_q = max_seqlen_q;
args.scale_s = scale_s;
args.scale_p = scale_p;
@@ -990,7 +997,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const float fwd_ave_time = [&] {
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits || 0 < page_block_size)
if(1 < num_splits || 0 < seqlen_knew || 0 < page_block_size)
{
fmha_fwd_splitkv_traits fmha_splitkv_traits;
init_traits(fmha_splitkv_traits);

View File

@@ -165,10 +165,11 @@ struct fmha_fwd_splitkv_args
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr, or
// kvcache is used
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t seqlen_k; // only used if 'seqlen_k_ptr' is nullptr
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
@@ -219,16 +220,11 @@ struct fmha_fwd_appendkv_args
void* v_ptr;
const void* vnew_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t seqlen_knew;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
@@ -371,7 +367,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
@@ -415,9 +410,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.max_seqlen_q,
args.seqlen_q,
args.seqlen_k,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
@@ -525,53 +520,13 @@ template <typename Kernel>
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.knew_ptr,
args.v_ptr,
args.vnew_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.seqlen_knew,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.stride_q,
args.stride_k,
args.stride_knew,
args.stride_v,
args.stride_vnew,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_knew,
args.nhead_stride_v,
args.nhead_stride_vnew,
args.batch_stride_k,
args.batch_stride_knew,
args.batch_stride_v,
args.batch_stride_vnew);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.q_ptr,
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.knew_ptr,
args.v_ptr,
args.vnew_ptr,
args.seqlen_q,
args.seqlen_k,
args.seqlen_k_ptr,
args.seqlen_knew,
args.hdim_q,
args.hdim_v,
@@ -598,10 +553,8 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
args.batch_stride_knew,
args.batch_stride_v,
args.batch_stride_vnew);
}
}();
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.seqlen_knew);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
return ck_tile::make_tuple(kargs, grids);
}
@@ -735,7 +688,6 @@ std::string fmha_fwd_splitkv_combine_get_name_();
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kTileSizeS_,
ck_tile::index_t kTileSizeSk_,
ck_tile::index_t kTileSizeD_,
@@ -751,7 +703,6 @@ struct fmha_fwd_appendkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
@@ -807,7 +758,6 @@ struct fmha_fwd_appendkv_traits
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
rope_enum rope_type;
};

View File

@@ -26,7 +26,6 @@ struct FmhaFwdAppendKVKernel
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
@@ -58,8 +57,7 @@ struct FmhaFwdAppendKVKernel
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_"
_SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) + "_"
"b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
_TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
@@ -79,7 +77,7 @@ struct FmhaFwdAppendKVKernel
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct CommonKargs
struct BasicKargs
{
void* q_ptr;
void* k_ptr;
@@ -87,6 +85,8 @@ struct FmhaFwdAppendKVKernel
void* v_ptr;
const void* vnew_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t seqlen_knew;
@@ -114,47 +114,32 @@ struct FmhaFwdAppendKVKernel
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_vnew;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_knew;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_vnew;
};
struct CommonRoPEKargs
struct RoPEKargs
{
const void* rotary_cos_ptr;
const void* rotary_sin_ptr;
ck_tile::index_t rotary_dim;
};
struct BatchModeKargs : CommonKargs,
std::conditional_t<kApplyRoPE, CommonRoPEKargs, EmptyKargs<0>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
};
struct Kargs : BasicKargs,
std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>
{};
struct GroupModeKargs : CommonKargs,
std::conditional_t<kApplyRoPE, CommonRoPEKargs, EmptyKargs<0>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
__host__ static constexpr Kargs
MakeKargs(void* q_ptr,
void* k_ptr,
const void* knew_ptr,
void* v_ptr,
const void* vnew_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_knew,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
@@ -187,8 +172,9 @@ struct FmhaFwdAppendKVKernel
knew_ptr,
v_ptr,
vnew_ptr,
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
seqlen_q,
seqlen_k,
-1, // seqlen_k will be updated by content of seqlen_k_ptr
seqlen_knew,
hdim_q,
hdim_v,
@@ -207,92 +193,13 @@ struct FmhaFwdAppendKVKernel
nhead_stride_knew,
nhead_stride_v,
nhead_stride_vnew,
batch_stride_q,
batch_stride_k,
batch_stride_knew,
batch_stride_v,
batch_stride_vnew}, // args for common karg
{}, // placeholder for rope
batch_stride_q,
batch_stride_k,
batch_stride_v};
if constexpr(kApplyRoPE)
{
kargs.rotary_cos_ptr = rotary_cos_ptr;
kargs.rotary_sin_ptr = rotary_sin_ptr;
kargs.rotary_dim = rotary_dim;
}
return kargs;
}
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(void* q_ptr,
void* k_ptr,
const void* knew_ptr,
void* v_ptr,
const void* vnew_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_knew,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
ck_tile::index_t stride_v,
ck_tile::index_t stride_vnew,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_knew,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_vnew,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_knew,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_vnew)
{
Kargs kargs{{q_ptr,
k_ptr,
knew_ptr,
v_ptr,
vnew_ptr,
-1, // seqlen will be updated by another pointer
-1, //
seqlen_knew,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
block_table_ptr,
batch_stride_block_table,
page_block_size,
stride_q,
stride_k,
stride_knew,
stride_v,
stride_vnew,
nhead_stride_q,
nhead_stride_k,
nhead_stride_knew,
nhead_stride_v,
nhead_stride_vnew,
batch_stride_knew,
batch_stride_vnew}, // args for common karg
{}, // placeholder for rope
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_k,
batch_stride_v};
{} // placeholder for rope
};
if constexpr(kApplyRoPE)
{
@@ -322,51 +229,15 @@ struct FmhaFwdAppendKVKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_knew =
const long_index_t batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
const long_index_t batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
const long_index_t batch_offset_knew =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_vnew =
const long_index_t batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
const long_index_t batch_offset_vnew =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
}
else
{
batch_offset_v = key_start;
}
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
if(kargs.seqlen_k_ptr != nullptr)
{
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
}
else
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
}
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
}
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)

View File

@@ -19,11 +19,11 @@ struct FmhaFwdAppendKVTilePartitioner
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_knew)
{
// TODO: this may need tuning
return dim3(std::max(ck_tile::integer_divide_ceil(max_seqlen_q, kM0),
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0),
ck_tile::integer_divide_ceil(seqlen_knew, kN0)),
nhead,
batch_size);

View File

@@ -108,7 +108,6 @@ struct FmhaFwdSplitKVKernel
void* o_acc_ptr;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -186,6 +185,8 @@ struct FmhaFwdSplitKVKernel
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
{
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
@@ -220,9 +221,9 @@ struct FmhaFwdSplitKVKernel
void* lse_acc_ptr,
void* o_acc_ptr,
ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
const void* seqlen_k_ptr, // only used for (paged-) kvcache
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -262,7 +263,6 @@ struct FmhaFwdSplitKVKernel
lse_acc_ptr,
o_acc_ptr,
batch,
max_seqlen_q,
seqlen_q,
seqlen_k,
hdim_q,
@@ -294,6 +294,7 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for fp8_static_quant args
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_q,
batch_stride_k,
batch_stride_v};
@@ -333,7 +334,6 @@ struct FmhaFwdSplitKVKernel
void* lse_acc_ptr,
void* o_acc_ptr,
ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
@@ -374,9 +374,8 @@ struct FmhaFwdSplitKVKernel
lse_acc_ptr,
o_acc_ptr,
batch,
max_seqlen_q,
-1, // seqlen will be updated by another pointer
-1, //
-1, // seqlen_q will be updated by another pointer
-1, // seqlen_k will be updated by another pointer
hdim_q,
hdim_v,
num_head_q,
@@ -496,8 +495,7 @@ struct FmhaFwdSplitKVKernel
}
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
@@ -512,8 +510,7 @@ struct FmhaFwdSplitKVKernel
}
else
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
}
}
else
@@ -526,6 +523,11 @@ struct FmhaFwdSplitKVKernel
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
}
if(kargs.seqlen_k_ptr != nullptr)
{
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
}
}
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {

View File

@@ -27,7 +27,6 @@ struct BlockFmhaFwdAppendKVPipeline
static constexpr index_t kK0 = Problem::kK0;
static constexpr index_t kN1 = Problem::kN1;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;

View File

@@ -139,7 +139,6 @@ template <typename QDataType_,
index_t kK0_,
index_t kN1_,
bool IsVLayoutRowMajor_,
bool kIsGroupMode_,
typename Traits_>
struct BlockFmhaFwdAppendKVPipelineProblem
{
@@ -149,7 +148,6 @@ struct BlockFmhaFwdAppendKVPipelineProblem
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = 256;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;