Remove group mode from appendkv kernel

This commit is contained in:
PoYen, Chen
2024-08-16 10:04:48 +00:00
parent 9de0f35ebc
commit 5805f5aa73
8 changed files with 95 additions and 281 deletions

View File

@@ -26,7 +26,6 @@ struct FmhaFwdAppendKVKernel
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
@@ -58,8 +57,7 @@ struct FmhaFwdAppendKVKernel
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_"
_SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) + "_"
"b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
_TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
@@ -79,7 +77,7 @@ struct FmhaFwdAppendKVKernel
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
struct CommonKargs
struct BasicKargs
{
void* q_ptr;
void* k_ptr;
@@ -87,6 +85,8 @@ struct FmhaFwdAppendKVKernel
void* v_ptr;
const void* vnew_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t seqlen_knew;
@@ -114,47 +114,32 @@ struct FmhaFwdAppendKVKernel
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_vnew;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_knew;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_vnew;
};
struct CommonRoPEKargs
struct RoPEKargs
{
const void* rotary_cos_ptr;
const void* rotary_sin_ptr;
ck_tile::index_t rotary_dim;
};
struct BatchModeKargs : CommonKargs,
std::conditional_t<kApplyRoPE, CommonRoPEKargs, EmptyKargs<0>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
};
struct Kargs : BasicKargs,
std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>
{};
struct GroupModeKargs : CommonKargs,
std::conditional_t<kApplyRoPE, CommonRoPEKargs, EmptyKargs<0>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
__host__ static constexpr Kargs
MakeKargs(void* q_ptr,
void* k_ptr,
const void* knew_ptr,
void* v_ptr,
const void* vnew_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_knew,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
@@ -187,8 +172,9 @@ struct FmhaFwdAppendKVKernel
knew_ptr,
v_ptr,
vnew_ptr,
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
seqlen_q,
seqlen_k,
-1, // seqlen_k will be updated by content of seqlen_k_ptr
seqlen_knew,
hdim_q,
hdim_v,
@@ -207,92 +193,13 @@ struct FmhaFwdAppendKVKernel
nhead_stride_knew,
nhead_stride_v,
nhead_stride_vnew,
batch_stride_q,
batch_stride_k,
batch_stride_knew,
batch_stride_v,
batch_stride_vnew}, // args for common karg
{}, // placeholder for rope
batch_stride_q,
batch_stride_k,
batch_stride_v};
if constexpr(kApplyRoPE)
{
kargs.rotary_cos_ptr = rotary_cos_ptr;
kargs.rotary_sin_ptr = rotary_sin_ptr;
kargs.rotary_dim = rotary_dim;
}
return kargs;
}
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(void* q_ptr,
void* k_ptr,
const void* knew_ptr,
void* v_ptr,
const void* vnew_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_knew,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
ck_tile::index_t stride_v,
ck_tile::index_t stride_vnew,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_knew,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_vnew,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_knew,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_vnew)
{
Kargs kargs{{q_ptr,
k_ptr,
knew_ptr,
v_ptr,
vnew_ptr,
-1, // seqlen will be updated by another pointer
-1, //
seqlen_knew,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
block_table_ptr,
batch_stride_block_table,
page_block_size,
stride_q,
stride_k,
stride_knew,
stride_v,
stride_vnew,
nhead_stride_q,
nhead_stride_k,
nhead_stride_knew,
nhead_stride_v,
nhead_stride_vnew,
batch_stride_knew,
batch_stride_vnew}, // args for common karg
{}, // placeholder for rope
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_k,
batch_stride_v};
{} // placeholder for rope
};
if constexpr(kApplyRoPE)
{
@@ -322,51 +229,15 @@ struct FmhaFwdAppendKVKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_knew =
const long_index_t batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
const long_index_t batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
const long_index_t batch_offset_knew =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_vnew =
const long_index_t batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
const long_index_t batch_offset_vnew =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
}
else
{
batch_offset_v = key_start;
}
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
if(kargs.seqlen_k_ptr != nullptr)
{
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
}
else
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
}
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
}
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)

View File

@@ -19,11 +19,11 @@ struct FmhaFwdAppendKVTilePartitioner
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_knew)
{
// TODO: this may need tuning
return dim3(std::max(ck_tile::integer_divide_ceil(max_seqlen_q, kM0),
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0),
ck_tile::integer_divide_ceil(seqlen_knew, kN0)),
nhead,
batch_size);

View File

@@ -108,7 +108,6 @@ struct FmhaFwdSplitKVKernel
void* o_acc_ptr;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -186,6 +185,8 @@ struct FmhaFwdSplitKVKernel
std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
{
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
@@ -220,9 +221,9 @@ struct FmhaFwdSplitKVKernel
void* lse_acc_ptr,
void* o_acc_ptr,
ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
const void* seqlen_k_ptr, // only used for (paged-) kvcache
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -262,7 +263,6 @@ struct FmhaFwdSplitKVKernel
lse_acc_ptr,
o_acc_ptr,
batch,
max_seqlen_q,
seqlen_q,
seqlen_k,
hdim_q,
@@ -294,6 +294,7 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for fp8_static_quant args
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_q,
batch_stride_k,
batch_stride_v};
@@ -333,7 +334,6 @@ struct FmhaFwdSplitKVKernel
void* lse_acc_ptr,
void* o_acc_ptr,
ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
@@ -374,9 +374,8 @@ struct FmhaFwdSplitKVKernel
lse_acc_ptr,
o_acc_ptr,
batch,
max_seqlen_q,
-1, // seqlen will be updated by another pointer
-1, //
-1, // seqlen_q will be updated by another pointer
-1, // seqlen_k will be updated by another pointer
hdim_q,
hdim_v,
num_head_q,
@@ -496,8 +495,7 @@ struct FmhaFwdSplitKVKernel
}
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
@@ -512,8 +510,7 @@ struct FmhaFwdSplitKVKernel
}
else
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
}
}
else
@@ -526,6 +523,11 @@ struct FmhaFwdSplitKVKernel
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
}
if(kargs.seqlen_k_ptr != nullptr)
{
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
}
}
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {

View File

@@ -27,7 +27,6 @@ struct BlockFmhaFwdAppendKVPipeline
static constexpr index_t kK0 = Problem::kK0;
static constexpr index_t kN1 = Problem::kN1;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;

View File

@@ -139,7 +139,6 @@ template <typename QDataType_,
index_t kK0_,
index_t kN1_,
bool IsVLayoutRowMajor_,
bool kIsGroupMode_,
typename Traits_>
struct BlockFmhaFwdAppendKVPipelineProblem
{
@@ -149,7 +148,6 @@ struct BlockFmhaFwdAppendKVPipelineProblem
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = 256;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;