Add batch_stride_k/batch_stride_v in group mode

This commit is contained in:
PoYen, Chen
2024-08-06 08:02:43 +00:00
parent faf6b0e8ab
commit bd0d2f3975

View File

@@ -237,6 +237,9 @@ struct FmhaFwdSplitKVKernel
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
@@ -408,6 +411,8 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc,
@@ -460,7 +465,9 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for dropout
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_k,
batch_stride_v};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -582,16 +589,8 @@ struct FmhaFwdSplitKVKernel
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
if constexpr(kIsPagedKV)
{
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
}
else
{
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
}
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{