mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
Add batched attention special kernel instances (#424)
* sanity check * add attribution * add irregular k tile size for batched attention * format
This commit is contained in:
@@ -649,6 +649,9 @@ struct BlockwiseGemmXdlops_v2
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
static_assert(KPerThread % KPack == 0,
|
||||
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
FloatAcc,
|
||||
MRepeat * NRepeat,
|
||||
|
||||
@@ -881,9 +881,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
FloatGemmAcc c_new =
|
||||
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
|
||||
math::exp(max[iM] - running_max_new[iM]) * acc1) /
|
||||
running_sum_new[iM]; // O_new
|
||||
running_sum_new[iM]; // Formula by Dao et al.,
|
||||
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
|
||||
|
||||
c_thread_buf(I) = c_new;
|
||||
c_thread_buf(I) = c_new; // O_new
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user