Add batched attention special kernel instances (#424)

* sanity check

* add attribution

* add irregular K tile size for batched attention

* format
This commit is contained in:
Anthony Chang
2022-09-20 08:20:54 +08:00
committed by GitHub
parent c6b8b472a7
commit 7c788e10ce
5 changed files with 51 additions and 10 deletions

View File

@@ -649,6 +649,9 @@ struct BlockwiseGemmXdlops_v2
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,

View File

@@ -881,9 +881,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatGemmAcc c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
running_sum_new[iM]; // O_new
running_sum_new[iM]; // Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
c_thread_buf(I) = c_new;
c_thread_buf(I) = c_new; // O_new
});
});