mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Pass RoPE kernel args
This commit is contained in:
@@ -28,6 +28,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_rope},
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
|
||||
@@ -50,7 +51,7 @@ using fmha_kernel_{F_idx} =
|
||||
fmha_pipeline_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
|
||||
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -77,8 +78,8 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co
|
||||
"""
|
||||
|
||||
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.apply_rope == {F_rope})) {{
|
||||
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}>;
|
||||
return fmha_fwd_appendkv_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -98,11 +99,12 @@ class FmhaFwdAppendKVApiTrait:
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
rope : str # t/f, apply RoPE to Q/K or not
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-'+\
|
||||
f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
@@ -133,6 +135,7 @@ class FmhaFwdAppendKVPipeline:
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_rope : str # t/f, apply RoPE to Q/K or not
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -147,6 +150,7 @@ class FmhaFwdAppendKVPipeline:
|
||||
pn = pad_name()
|
||||
n = f'v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_rope == 't': n += '_rope'
|
||||
return n
|
||||
|
||||
class FmhaFwdAppendKVApiPool:
|
||||
@@ -176,7 +180,7 @@ class FmhaFwdAppendKVApiPool:
|
||||
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
F_rope=BOOL_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
@@ -222,6 +226,7 @@ class FmhaFwdAppendKVKernel:
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_rope = BOOL_MAP[self.F_pipeline.F_rope],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_mode = MODE_MAP[self.F_mode])
|
||||
|
||||
@@ -248,7 +253,8 @@ class FmhaFwdAppendKVKernel:
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
rope=self.F_pipeline.F_rope)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# these are the currently supported tile sizes per hdim
|
||||
@@ -280,15 +286,15 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
# pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f'))
|
||||
# pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f'))
|
||||
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't'))
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't'))
|
||||
for rope in ["t", "f"]:
|
||||
# pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f', rope))
|
||||
# pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f', rope))
|
||||
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't', rope))
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', rope))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need for lse/dropout kernels
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f'))
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f', 'f'))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
@@ -676,7 +676,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(0 < seqlen_knew)
|
||||
{
|
||||
auto appendkv_traits = fmha_fwd_appendkv_traits{
|
||||
hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor};
|
||||
hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor, 0 < rotary_dim};
|
||||
|
||||
auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() {
|
||||
// setup stride_* arguments
|
||||
|
||||
@@ -479,6 +479,10 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.rotary_cos_ptr,
|
||||
args.rotary_sin_ptr,
|
||||
args.rotary_dim,
|
||||
args.is_rotary_interleaved,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_knew,
|
||||
@@ -506,6 +510,10 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.rotary_cos_ptr,
|
||||
args.rotary_sin_ptr,
|
||||
args.rotary_dim,
|
||||
args.is_rotary_interleaved,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_knew,
|
||||
@@ -623,11 +631,11 @@ template <ck_tile::index_t HDim_,
|
||||
ck_tile::index_t kTileSizeD_,
|
||||
ck_tile::index_t kTileSizeDv_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
// bool kApplyRotray_,
|
||||
bool kPadS_,
|
||||
bool kPadSk_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
bool kPadDv_,
|
||||
bool kApplyRoPE_>
|
||||
struct fmha_fwd_appendkv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -638,11 +646,11 @@ struct fmha_fwd_appendkv_traits_
|
||||
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
|
||||
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
// static constexpr bool kApplyRotray = kApplyRotray_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSk = kPadSk_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSk = kPadSk_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kApplyRoPE = kApplyRoPE_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
@@ -673,6 +681,7 @@ struct fmha_fwd_appendkv_traits
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
bool apply_rope;
|
||||
};
|
||||
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
|
||||
fmha_fwd_appendkv_args,
|
||||
|
||||
@@ -31,6 +31,7 @@ struct FmhaFwdAppendKVKernel
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr bool kApplyRoPE = FmhaPipeline::kApplyRoPE;
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
@@ -43,8 +44,9 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
__host__ static std::string GetName()
|
||||
{
|
||||
// sync with generate.py
|
||||
// clang-format off
|
||||
// sync with generate.py
|
||||
// clang-format off
|
||||
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
auto pn = [&] () {
|
||||
@@ -59,7 +61,8 @@ struct FmhaFwdAppendKVKernel
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_"
|
||||
"b" + _TS_(FmhaPipeline::kTileSizeS) + "x" + _TS_(FmhaPipeline::kTileSizeSk) + "x" + _TS_(FmhaPipeline::kTileSizeD) + "x" +
|
||||
_TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn);
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
|
||||
+ (kApplyRoPE ? "_rope" : "");
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -109,14 +112,24 @@ struct FmhaFwdAppendKVKernel
|
||||
ck_tile::index_t batch_stride_vnew;
|
||||
};
|
||||
|
||||
struct BatchModeKargs : CommonKargs
|
||||
// Kernel-argument fields used for RoPE. BatchModeKargs and GroupModeKargs
// inherit this struct only when kApplyRoPE is true (otherwise they inherit
// EmptyKargs<0>). The fields are assigned from the same-named function
// parameters inside an `if constexpr(kApplyRoPE)` branch after the kargs
// object has been aggregate-initialized.
struct CommonRoPEKargs
|
||||
{
|
||||
// Type-erased pointer; presumably the cosine table -- TODO confirm element type and shape.
const void* rotary_cos_ptr;
|
||||
// Type-erased pointer; presumably the sine table -- TODO confirm element type and shape.
const void* rotary_sin_ptr;
|
||||
// NOTE(review): the host example enables RoPE when `0 < rotary_dim`; exact meaning of the value not shown here.
ck_tile::index_t rotary_dim;
|
||||
// Presumably selects interleaved vs. half-split rotation layout -- confirm against the pipeline.
bool is_rotary_interleaved;
|
||||
};
|
||||
|
||||
// Batch-mode kernel arguments. Bases, in order: CommonKargs, then either
// CommonRoPEKargs (when kApplyRoPE is true) or EmptyKargs<0>. The code that
// builds this struct uses positional aggregate initialization and passes `{}`
// for the second base, so the order of bases and members must not change.
struct BatchModeKargs : CommonKargs,
|
||||
std::conditional_t<kApplyRoPE, CommonRoPEKargs, EmptyKargs<0>>
|
||||
{
|
||||
// Per-batch strides for the Q, K and V tensors -- presumably in elements; TODO confirm units.
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
};
|
||||
|
||||
struct GroupModeKargs : CommonKargs
|
||||
struct GroupModeKargs : CommonKargs,
|
||||
std::conditional_t<kApplyRoPE, CommonRoPEKargs, EmptyKargs<0>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -139,6 +152,10 @@ struct FmhaFwdAppendKVKernel
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
const void* rotary_cos_ptr,
|
||||
const void* rotary_sin_ptr,
|
||||
ck_tile::index_t rotary_dim,
|
||||
bool is_rotary_interleaved,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_knew,
|
||||
@@ -179,10 +196,19 @@ struct FmhaFwdAppendKVKernel
|
||||
nhead_stride_vnew,
|
||||
batch_stride_knew,
|
||||
batch_stride_vnew}, // args for common karg
|
||||
{}, // placeholder for rope
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v};
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
kargs.rotary_cos_ptr = rotary_cos_ptr;
|
||||
kargs.rotary_sin_ptr = rotary_sin_ptr;
|
||||
kargs.rotary_dim = rotary_dim;
|
||||
kargs.is_rotary_interleaved = is_rotary_interleaved;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -201,6 +227,10 @@ struct FmhaFwdAppendKVKernel
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
const void* rotary_cos_ptr,
|
||||
const void* rotary_sin_ptr,
|
||||
ck_tile::index_t rotary_dim,
|
||||
bool is_rotary_interleaved,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_knew,
|
||||
@@ -238,10 +268,19 @@ struct FmhaFwdAppendKVKernel
|
||||
nhead_stride_vnew,
|
||||
batch_stride_knew,
|
||||
batch_stride_vnew}, // args for common karg
|
||||
{}, // placeholder for rope
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
kargs.rotary_cos_ptr = rotary_cos_ptr;
|
||||
kargs.rotary_sin_ptr = rotary_sin_ptr;
|
||||
kargs.rotary_dim = rotary_dim;
|
||||
kargs.is_rotary_interleaved = is_rotary_interleaved;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kApplyRoPE = Problem::kApplyRoPE;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should be able to overwrite this
|
||||
|
||||
@@ -41,6 +41,7 @@ struct BlockFmhaFwdAppendKVPipelineProblem
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kApplyRoPE = Traits::kApplyRoPE;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
|
||||
@@ -80,6 +80,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* padding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* padding for hdim_v */,
|
||||
bool kApplyRoPE_ /* apply RoPE to Q/K or not */,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaFwdAppendKVTraits
|
||||
{
|
||||
@@ -87,6 +88,7 @@ struct TileFmhaFwdAppendKVTraits
|
||||
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr bool kApplyRoPE = kApplyRoPE_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user