Pass RoPE kernel args

This commit is contained in:
PoYen, Chen
2024-07-14 23:18:32 +00:00
parent b5ad1411b0
commit 391210ed9e
7 changed files with 83 additions and 25 deletions

View File

@@ -28,6 +28,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_rope},
{F_occupancy}>;
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
@@ -50,7 +51,7 @@ using fmha_kernel_{F_idx} =
fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}>;
#include <iostream>
@@ -77,8 +78,8 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.apply_rope == {F_rope})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}>;
return fmha_fwd_appendkv_<trait_>(s, a);
}}
"""
@@ -98,11 +99,12 @@ class FmhaFwdAppendKVApiTrait:
skpad : str
dpad : str
dvpad : str
rope : str # t/f, apply RoPE to Q/K or not
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-'+\
f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}'
@property
def scheck(self) -> str:
@@ -133,6 +135,7 @@ class FmhaFwdAppendKVPipeline:
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_rope : str # t/f, apply RoPE to Q/K or not
@property
def name(self) -> str:
@@ -147,6 +150,7 @@ class FmhaFwdAppendKVPipeline:
pn = pad_name()
n = f'v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_rope == 't': n += '_rope'
return n
class FmhaFwdAppendKVApiPool:
@@ -176,7 +180,7 @@ class FmhaFwdAppendKVApiPool:
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
F_rope=BOOL_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
@@ -222,6 +226,7 @@ class FmhaFwdAppendKVKernel:
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_rope = BOOL_MAP[self.F_pipeline.F_rope],
F_occupancy = self.F_tile.F_occupancy,
F_mode = MODE_MAP[self.F_mode])
@@ -248,7 +253,8 @@ class FmhaFwdAppendKVKernel:
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
dvpad=self.F_pipeline.F_dvpad,
rope=self.F_pipeline.F_rope)
# TODO: design a more practical way to do it
# this is the currently supported tile size per hdim
@@ -280,15 +286,15 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
# pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f'))
# pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f'))
pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't'))
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't'))
for rope in ["t", "f"]:
# pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f', rope))
# pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f', rope))
pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't', rope))
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', rope))
elif dtype in ['fp8', 'bf8']:
# no need for lse/dropout kernels
pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f'))
pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f', 'f'))
else:
assert False
return pipelines

View File

@@ -676,7 +676,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(0 < seqlen_knew)
{
auto appendkv_traits = fmha_fwd_appendkv_traits{
hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor};
hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor, 0 < rotary_dim};
auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() {
// setup stride_* arguments

View File

@@ -479,6 +479,10 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.is_rotary_interleaved,
args.stride_q,
args.stride_k,
args.stride_knew,
@@ -506,6 +510,10 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.is_rotary_interleaved,
args.stride_q,
args.stride_k,
args.stride_knew,
@@ -623,11 +631,11 @@ template <ck_tile::index_t HDim_,
ck_tile::index_t kTileSizeD_,
ck_tile::index_t kTileSizeDv_,
bool kIsVLayoutRowMajor_,
// bool kApplyRotray_,
bool kPadS_,
bool kPadSk_,
bool kPadD_,
bool kPadDv_>
bool kPadDv_,
bool kApplyRoPE_>
struct fmha_fwd_appendkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -638,11 +646,11 @@ struct fmha_fwd_appendkv_traits_
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
// static constexpr bool kApplyRotray = kApplyRotray_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSk = kPadSk_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSk = kPadSk_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kApplyRoPE = kApplyRoPE_;
};
template <typename Traits_>
@@ -673,6 +681,7 @@ struct fmha_fwd_appendkv_traits
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
bool apply_rope;
};
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,