Hotfix LDS data hazard in fused attention (#360)

* avoid LDS data hazard in gemm_softmax_gemm pipeline

* trivial refactors

* comments

* shrink blockwise gemm v2 thread buffer size

* reclaim A block lds space when during 2nd gemm

* amend

* amend
This commit is contained in:
Anthony Chang
2022-08-16 01:04:20 +08:00
committed by GitHub
parent 53ea4713af
commit c961ce9226
4 changed files with 88 additions and 69 deletions

View File

@@ -701,9 +701,7 @@ struct BlockwiseGemmXdlops_v2
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto tmp = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
const auto blk_idx =
TransposeC ? make_multi_index(tmp[I1], tmp[I0]) : make_multi_index(tmp[I0], tmp[I1]);
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
@@ -922,13 +920,13 @@ struct BlockwiseGemmXdlops_v2
}
protected:
// A[M0, M1, M2, KPerThread]
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
// B[N0, N1, N2, KPerThread]
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(