Add batched attention special kernel instances (#424)

* sanity check

* add attribution

* add irregular k tile size for batched attention

* format
Author: Anthony Chang
Date: 2022-09-20 08:20:54 +08:00
Committed by: GitHub
Parent: c6b8b472a7
Commit: 7c788e10ce
5 changed files with 51 additions and 10 deletions


@@ -881,9 +881,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     FloatGemmAcc c_new =
         (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
          math::exp(max[iM] - running_max_new[iM]) * acc1) /
-        running_sum_new[iM]; // O_new
+        running_sum_new[iM]; // Formula by Dao et al.,
+                             // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
-    c_thread_buf(I) = c_new;
+    c_thread_buf(I) = c_new; // O_new
     });
 });
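The `c_new` expression in the hunk is the online-softmax output rescaling from the cited paper: the previously normalized output `c` is un-normalized by the old running sum, rescaled to the new running max, combined with the new block's unnormalized contribution `acc1`, and re-normalized. A minimal standalone C++ sketch of that merge step follows; the names `OnlineSoftmaxState` and `merge_block` are hypothetical illustrations, not kernel code, and it assumes `acc1` holds the new block's exp-weighted value sum relative to the block max `m`:

```cpp
#include <algorithm>
#include <cmath>

// Hedged sketch of the running-softmax rescaling in the hunk above
// (Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf, section 3.1).
// The kernel tracks these quantities per output row iM; here we keep
// a single row's state in a plain struct for illustration.
struct OnlineSoftmaxState
{
    float running_max = -INFINITY; // max score seen so far
    float running_sum = 0.0f;      // sum of exp(score - running_max)
    float c           = 0.0f;      // normalized running output (O)
};

// Merge one attention block into the running state:
//   m    = block's max score
//   s    = block's sum of exp(score - m)
//   acc1 = block's unnormalized exp-weighted value sum
inline void merge_block(OnlineSoftmaxState& st, float m, float s, float acc1)
{
    const float running_max_new = std::max(st.running_max, m);
    const float running_sum_new =
        st.running_sum * std::exp(st.running_max - running_max_new) +
        s * std::exp(m - running_max_new);
    // O_new: rescale the old normalized output to the new max, fold in
    // the new block, and re-normalize -- the c_new expression above.
    st.c = (st.running_sum * std::exp(st.running_max - running_max_new) * st.c +
            std::exp(m - running_max_new) * acc1) /
           running_sum_new;
    st.running_max = running_max_new;
    st.running_sum = running_sum_new;
}
```

Merging the blocks in any split yields the same result as one softmax over all scores at once, which is what lets the kernel process K/V tiles of irregular size without a second normalization pass.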