mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Hotfix LDS data hazard in fused attention (#360)
* avoid LDS data hazard in gemm_softmax_gemm pipeline * trivial refactors * comments * shrink blockwise gemm v2 thread buffer size * reclaim A block lds space when during 2nd gemm * amend * amend
This commit is contained in:
@@ -701,9 +701,7 @@ struct BlockwiseGemmXdlops_v2
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto tmp = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
|
||||
const auto blk_idx =
|
||||
TransposeC ? make_multi_index(tmp[I1], tmp[I0]) : make_multi_index(tmp[I0], tmp[I1]);
|
||||
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
|
||||
|
||||
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
|
||||
@@ -922,13 +920,13 @@ struct BlockwiseGemmXdlops_v2
|
||||
}
|
||||
|
||||
protected:
|
||||
// A[M0, M1, M2, KPerThread]
|
||||
// A[M0, M1, M2, KPack]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
|
||||
|
||||
// B[N0, N1, N2, KPerThread]
|
||||
// B[N0, N1, N2, KPack]
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
|
||||
|
||||
// C[M, N, NumRegXdlops]
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
|
||||
@@ -181,36 +181,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
|
||||
SharedMemTrait::b_block_space_size_aligned) *
|
||||
sizeof(FloatAB);
|
||||
const index_t gemm1_bytes_end =
|
||||
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
|
||||
sizeof(FloatAB);
|
||||
const index_t c_block_bytes_end =
|
||||
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
|
||||
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
|
||||
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size_aligned =
|
||||
math::max(b0_block_space_size_aligned.value, b1_block_space_size_aligned.value);
|
||||
|
||||
// LDS allocation for C shuffle in LDS
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
|
||||
constexpr auto c_block_size =
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
|
||||
sizeof(FloatAB),
|
||||
c_block_size * sizeof(FloatCShuffle));
|
||||
return math::max(gemm0_bytes_end, gemm1_bytes_end, c_block_bytes_end);
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
@@ -312,6 +292,36 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
|
||||
|
||||
struct SharedMemTrait
|
||||
{
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
static constexpr auto a_block_desc_ak0_m_ak1 =
|
||||
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
static constexpr auto b_block_desc_bk0_n_bk1 =
|
||||
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
static constexpr auto b1_block_desc_bk0_n_bk1 =
|
||||
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
|
||||
|
||||
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
|
||||
|
||||
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
|
||||
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
|
||||
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
static constexpr auto a_block_space_offset = 0;
|
||||
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
|
||||
static constexpr auto b1_block_space_offset = 0;
|
||||
|
||||
// LDS allocation for C shuffle in LDS
|
||||
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
static constexpr auto c_block_space_size =
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
};
|
||||
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
@@ -358,9 +368,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
|
||||
|
||||
@@ -464,14 +471,12 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
|
||||
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
|
||||
@@ -588,7 +593,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
|
||||
// reuse LDS space for gemm0's b_block_buf
|
||||
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
|
||||
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
|
||||
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr index_t Gemm1KPack = math::max(
|
||||
@@ -611,10 +616,11 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
MXdlPerWave,
|
||||
Gemm1NXdlPerWave,
|
||||
Gemm1KPack,
|
||||
false,
|
||||
false, // TransposeC
|
||||
Gemm1KPack, // AMmaKStride
|
||||
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
|
||||
make_tuple(0, 0, 0, 0)}; // TransposeC
|
||||
// BMmaKStride
|
||||
make_tuple(0, 0, 0, 0)}; // A_origin
|
||||
|
||||
auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -699,6 +705,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
a1_thread_desc_k0_m_k1,
|
||||
make_tuple(I0, I0, I0),
|
||||
a1_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
|
||||
|
||||
@@ -182,11 +182,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
return math::max((SharedMemTrait::a_block_space_size_aligned +
|
||||
SharedMemTrait::b_block_space_size_aligned) *
|
||||
sizeof(FloatAB) +
|
||||
SharedMemTrait::reduction_workspace * sizeof(FloatGemmAcc),
|
||||
SharedMemTrait::c_block_size * sizeof(FloatCShuffle));
|
||||
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
|
||||
SharedMemTrait::b_block_space_size_aligned) *
|
||||
sizeof(FloatAB);
|
||||
const index_t gemm1_bytes_end =
|
||||
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
|
||||
sizeof(FloatAB);
|
||||
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
|
||||
SharedMemTrait::reduction_space_size_aligned) *
|
||||
sizeof(FloatGemmAcc);
|
||||
const index_t c_block_bytes_end =
|
||||
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
|
||||
|
||||
return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
@@ -302,22 +310,25 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
|
||||
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
|
||||
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
|
||||
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
|
||||
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
// B1 can reuse B's LDS
|
||||
static constexpr auto b_block_space_size_aligned =
|
||||
math::max(b0_block_space_size_aligned.value, b1_block_space_size_aligned.value);
|
||||
static constexpr auto a_block_space_offset = 0;
|
||||
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
|
||||
static constexpr auto b1_block_space_offset = 0;
|
||||
|
||||
// LDS allocation for reduction
|
||||
static constexpr index_t reduction_workspace = BlockSize;
|
||||
static constexpr index_t reduction_space_size_aligned =
|
||||
math::integer_least_multiple(BlockSize, max_lds_align);
|
||||
|
||||
static constexpr auto reduction_space_offset = 0;
|
||||
|
||||
// LDS allocation for C shuffle in LDS
|
||||
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
static constexpr auto c_block_size =
|
||||
static constexpr auto c_block_space_size =
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
};
|
||||
|
||||
@@ -471,10 +482,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
|
||||
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
|
||||
@@ -591,7 +603,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
|
||||
// reuse LDS space for gemm0's b_block_buf
|
||||
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
|
||||
static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
|
||||
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr index_t Gemm1KPack = math::max(
|
||||
@@ -617,7 +629,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
true, // TransposeC
|
||||
Gemm1KPack, // AMmaKStride
|
||||
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
|
||||
make_tuple(0, 0, 0, 0)}; // TransposeC
|
||||
// BMmaKStride
|
||||
make_tuple(0, 0, 0, 0)}; // A_origin
|
||||
|
||||
auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -625,10 +638,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
// Blockwise softmax
|
||||
//
|
||||
auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatGemmAcc*>(p_shared) +
|
||||
SharedMemTrait::a_block_space_size_aligned * sizeof(FloatAB) / 4 +
|
||||
SharedMemTrait::b_block_space_size_aligned * sizeof(FloatAB) / 4,
|
||||
SharedMemTrait::reduction_workspace);
|
||||
static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
|
||||
SharedMemTrait::reduction_space_size_aligned);
|
||||
|
||||
// get acc0 8D thread cluster
|
||||
constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
|
||||
@@ -717,7 +728,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
|
||||
mathext::exp(max - running_max_new) * sum;
|
||||
|
||||
block_sync_lds();
|
||||
// gemm1
|
||||
{
|
||||
// TODO: explore using dynamic buffer for a1 thread buffer
|
||||
@@ -736,12 +746,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
|
||||
b1_block_slice_copy_step);
|
||||
|
||||
block_sync_lds(); // wait for reduction LDS read
|
||||
|
||||
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
|
||||
|
||||
// main body
|
||||
if constexpr(num_gemm1_k_block_inner_loop > 1)
|
||||
{
|
||||
|
||||
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
|
||||
a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
|
||||
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
|
||||
@@ -749,6 +760,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
a1_thread_desc_k0_m_k1,
|
||||
make_tuple(I0, I0, I0),
|
||||
a1_thread_buf);
|
||||
|
||||
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
|
||||
|
||||
block_sync_lds();
|
||||
@@ -773,6 +785,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
a1_thread_desc_k0_m_k1,
|
||||
make_tuple(I0, I0, I0),
|
||||
a1_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
|
||||
@@ -817,6 +830,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
running_max = running_max_new;
|
||||
running_sum = running_sum_new;
|
||||
|
||||
block_sync_lds(); // wait for gemm1 LDS read
|
||||
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
|
||||
|
||||
// shuffle C and write out
|
||||
|
||||
@@ -819,7 +819,7 @@ struct XdlopsGemm
|
||||
index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td;
|
||||
index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size;
|
||||
|
||||
return CIndex{m_offset, n_offset};
|
||||
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
|
||||
}
|
||||
|
||||
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};
|
||||
|
||||
Reference in New Issue
Block a user