mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Hotfix LDS data hazard in fused attention (#360)
* avoid LDS data hazard in gemm_softmax_gemm pipeline * trivial refactors * comments * shrink blockwise gemm v2 thread buffer size * reclaim A block lds space when during 2nd gemm * amend * amend
This commit is contained in:
@@ -701,9 +701,7 @@ struct BlockwiseGemmXdlops_v2
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto tmp = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
|
||||
const auto blk_idx =
|
||||
TransposeC ? make_multi_index(tmp[I1], tmp[I0]) : make_multi_index(tmp[I0], tmp[I1]);
|
||||
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
|
||||
|
||||
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
|
||||
@@ -922,13 +920,13 @@ struct BlockwiseGemmXdlops_v2
|
||||
}
|
||||
|
||||
protected:
|
||||
// A[M0, M1, M2, KPerThread]
|
||||
// A[M0, M1, M2, KPack]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
|
||||
|
||||
// B[N0, N1, N2, KPerThread]
|
||||
// B[N0, N1, N2, KPack]
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
|
||||
|
||||
// C[M, N, NumRegXdlops]
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
|
||||
Reference in New Issue
Block a user