mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[CK_TILE] fmha fwd splitkv optimization for decode (seqlen_q=1) (#1789)
* Update license year * Add initial code to override decode problem * Fix splitkv traits/args overriding error * Reshape and transpose lse for decode * Remove debug code * Prettify example code * Use better function name * Add kMergeNumHeadGroupsSeqLenQ flag Kernel user can use this switch to turn on/off optimization for some problem sizes * Add missing flag declarations * Default turn off kMergeNumHeadGroupsSeqLenQ in codegen * Group similar statements together * Remove assumption of seqlen_q=1 * Remove kMergeNumHeadGroupsSeqLenQ from splitkv combine kernel * Support kMergeNumHeadGroupsSeqLenQ=true in fmha splitkv kernel * Run kMergeNumHeadGroupsSeqLenQ=true kernels when need * Fix group mode block skip logics * Undo changes of normal fwd kernel * Update in GridSize() and using GridSize() for splitkv kernel (#1799) --------- Co-authored-by: Qianfeng <qianfeng.zhang@amd.com>
This commit is contained in:
@@ -47,10 +47,16 @@ struct FmhaFwdSplitKVKernel
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
|
||||
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
|
||||
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
static_assert(!kMergeNumHeadGroupsSeqLenQ ||
|
||||
(kMergeNumHeadGroupsSeqLenQ && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
|
||||
!kHasMask));
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
@@ -476,15 +482,20 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t nhead_q,
|
||||
ck_tile::index_t nhead_kv,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
{
|
||||
ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q;
|
||||
ck_tile::index_t max_seqlen_q_ =
|
||||
max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1);
|
||||
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
|
||||
nhead,
|
||||
nhead_,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
@@ -562,7 +573,7 @@ struct FmhaFwdSplitKVKernel
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -617,30 +628,60 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const index_t i_nhead_k =
|
||||
(kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
|
||||
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
static_cast<long_index_t>(i_nhead) *
|
||||
(kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
|
||||
kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
const KDataType* k_ptr =
|
||||
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
const VDataType* v_ptr =
|
||||
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
|
||||
ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
|
||||
static_cast<long_index_t>(i_nhead) *
|
||||
(kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
|
||||
kargs.nhead_stride_o_acc +
|
||||
batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_q, 1),
|
||||
number<FmhaPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
const auto q_dram = [&] {
|
||||
const auto q_dram_naive = [&] {
|
||||
if constexpr(kMergeNumHeadGroupsSeqLenQ)
|
||||
{
|
||||
// reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
|
||||
// hdim_q)
|
||||
const auto view = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
|
||||
number<FmhaPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
|
||||
return transform_tensor_view(
|
||||
view,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
|
||||
make_pass_through_transform(kargs.hdim_q)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_q, 1),
|
||||
number<FmhaPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(FmhaPipeline::kQLoadOnce)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
@@ -729,7 +770,7 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
|
||||
auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
const auto* block_indices =
|
||||
@@ -739,8 +780,7 @@ struct FmhaFwdSplitKVKernel
|
||||
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_k;
|
||||
static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;
|
||||
|
||||
return make_page_block_navigator<const KDataType, 0>(
|
||||
kargs.k_ptr,
|
||||
@@ -760,7 +800,7 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
|
||||
auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
const auto* block_indices =
|
||||
@@ -770,8 +810,7 @@ struct FmhaFwdSplitKVKernel
|
||||
integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
|
||||
|
||||
const long_index_t fixed_offset =
|
||||
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
|
||||
kargs.nhead_stride_v;
|
||||
static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;
|
||||
|
||||
return make_page_block_navigator<const VDataType, 1>(
|
||||
kargs.v_ptr,
|
||||
@@ -842,19 +881,40 @@ struct FmhaFwdSplitKVKernel
|
||||
// lse acc
|
||||
auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
|
||||
constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
|
||||
LSEDataType* lse_acc_ptr =
|
||||
reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse_acc +
|
||||
batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;
|
||||
LSEDataType* lse_acc_ptr = reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) *
|
||||
(kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
|
||||
kargs.nhead_stride_lse_acc +
|
||||
batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;
|
||||
|
||||
const auto lse_acc_dram = [&]() {
|
||||
const auto lse_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global>(lse_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q),
|
||||
make_tuple(1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
const auto lse_acc_dram = [&] {
|
||||
const auto lse_acc_dram_naive = [&] {
|
||||
if constexpr(kMergeNumHeadGroupsSeqLenQ)
|
||||
{
|
||||
// reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q)
|
||||
const auto view = make_naive_tensor_view<address_space_enum::global>(
|
||||
lse_acc_ptr,
|
||||
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q),
|
||||
make_tuple(kargs.nhead_stride_lse_acc, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return transform_tensor_view(view,
|
||||
make_tuple(make_merge_transform(make_tuple(
|
||||
kargs.nhead_ratio_qk, kargs.seqlen_q))),
|
||||
make_tuple(sequence<0, 1>{}),
|
||||
make_tuple(sequence<0>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
lse_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q),
|
||||
make_tuple(1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
return pad_tensor_view(
|
||||
lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
|
||||
}();
|
||||
@@ -953,13 +1013,37 @@ struct FmhaFwdSplitKVKernel
|
||||
}();
|
||||
|
||||
// Oacc DRAM and Oacc DRAM window
|
||||
auto o_acc_dram = [&]() {
|
||||
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_o_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentOacc>{},
|
||||
number<1>{});
|
||||
auto o_acc_dram = [&] {
|
||||
const auto o_acc_dram_naive = [&] {
|
||||
if constexpr(kMergeNumHeadGroupsSeqLenQ)
|
||||
{
|
||||
// reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
|
||||
// hdim_v)
|
||||
const auto view = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentOacc>{},
|
||||
number<1>{});
|
||||
|
||||
return transform_tensor_view(
|
||||
view,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
|
||||
make_pass_through_transform(kargs.hdim_v)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_o_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentOacc>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return pad_tensor_view(
|
||||
o_acc_dram_naive,
|
||||
|
||||
Reference in New Issue
Block a user