mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Support cache_batch_idx in example
This commit is contained in:
@@ -249,14 +249,26 @@ struct FmhaFwdAppendKVKernel
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
|
||||
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
|
||||
|
||||
+ const index_t i_cache_batch = [&, i_batch_ = i_batch] {
|
||||
+ if constexpr(kIsPagedKV)
|
||||
+ {
|
||||
+ return i_batch_;
|
||||
+ }
|
||||
+ else
|
||||
+ {
|
||||
+ return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
|
||||
+ : i_batch_);
|
||||
+ }
|
||||
+ }();
|
||||
|
||||
const long_index_t batch_offset_q =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
const long_index_t batch_offset_k =
|
||||
- static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
+ static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
|
||||
const long_index_t batch_offset_knew =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
|
||||
const long_index_t batch_offset_v =
|
||||
- static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
+ static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
|
||||
const long_index_t batch_offset_vnew =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;
|
||||
|
||||
|
||||
@@ -529,9 +529,21 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
+ const index_t i_cache_batch = [&, i_batch_ = i_batch] {
|
||||
+ if constexpr(kIsPagedKV)
|
||||
+ {
|
||||
+ return i_batch_;
|
||||
+ }
|
||||
+ else
|
||||
+ {
|
||||
+ return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
|
||||
+ : i_batch_);
|
||||
+ }
|
||||
+ }();
|
||||
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
- batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
- batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
+ batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
|
||||
+ batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user