From 7d211aa652d1410a6cc4956d3771e15783e21d72 Mon Sep 17 00:00:00 2001 From: Anthony Chang Date: Tue, 16 Aug 2022 01:04:20 +0800 Subject: [PATCH] Hotfix LDS data hazard in fused attention (#360) * avoid LDS data hazard in gemm_softmax_gemm pipeline * trivial refactors * comments * shrink blockwise gemm v2 thread buffer size * reclaim A block lds space when during 2nd gemm * amend * amend [ROCm/composable_kernel commit: c961ce9226dd263af1d898c02c0afae0ed702f7d] --- .../gpu/block/blockwise_gemm_xdlops.hpp | 12 ++- ...wise_batched_gemm_gemm_xdl_cshuffle_v1.hpp | 87 ++++++++++--------- ...ched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 56 +++++++----- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 2 +- 4 files changed, 88 insertions(+), 69 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 69a00c8e54..67332929ff 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -701,9 +701,7 @@ struct BlockwiseGemmXdlops_v2 const auto waveId_m = wave_idx[I0]; const auto waveId_n = wave_idx[I1]; - const auto tmp = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - const auto blk_idx = - TransposeC ? make_multi_index(tmp[I1], tmp[I0]) : make_multi_index(tmp[I0], tmp[I1]); + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), @@ -922,13 +920,13 @@ struct BlockwiseGemmXdlops_v2 } protected: - // A[M0, M1, M2, KPerThread] + // A[M0, M1, M2, KPack] static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); - // B[N0, N1, N2, KPerThread] + // B[N0, N1, N2, KPack] static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); // C[M, N, NumRegXdlops] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index 0ab92e8fac..4fbf576f99 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -181,36 +181,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + + SharedMemTrait::b_block_space_size_aligned) * + sizeof(FloatAB); + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * + sizeof(FloatAB); + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); - // lds max alignment - constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b0_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( - b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = - math::max(b0_block_space_size_aligned.value, b1_block_space_size_aligned.value); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), - c_block_size * sizeof(FloatCShuffle)); + return math::max(gemm0_bytes_end, gemm1_bytes_end, c_block_bytes_end); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} @@ -312,6 +292,36 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle using DefaultBlock2CTileMap = remove_cvref_t; + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + static constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + static constexpr auto b1_block_desc_bk0_n_bk1 = + GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); + + static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( + b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned.value; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + static constexpr auto c_block_space_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + }; + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -358,9 +368,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); - // lds max alignment - constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); - // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -464,14 +471,12 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer(); // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); @@ -588,7 +593,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle // reuse LDS space for gemm0's b_block_buf auto b1_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr index_t Gemm1KPack = math::max( @@ -611,10 +616,11 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle MXdlPerWave, Gemm1NXdlPerWave, Gemm1KPack, - false, + false, // TransposeC Gemm1KPack, // AMmaKStride Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ - make_tuple(0, 0, 0, 0)}; // TransposeC + // BMmaKStride + make_tuple(0, 0, 0, 0)}; // A_origin auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer(); @@ -699,6 +705,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle a1_thread_desc_k0_m_k1, make_tuple(I0, I0, I0), a1_thread_buf); + block_sync_lds(); gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index 7e0fbb7989..db6f7cbb50 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -182,11 +182,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { - return math::max((SharedMemTrait::a_block_space_size_aligned + - SharedMemTrait::b_block_space_size_aligned) * - sizeof(FloatAB) + - SharedMemTrait::reduction_workspace * sizeof(FloatGemmAcc), - SharedMemTrait::c_block_size * sizeof(FloatCShuffle)); + const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + + SharedMemTrait::b_block_space_size_aligned) * + sizeof(FloatAB); + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * + sizeof(FloatAB); + const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset + + SharedMemTrait::reduction_space_size_aligned) * + sizeof(FloatGemmAcc); + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} @@ -302,22 +310,25 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple( + static constexpr auto b_block_space_size_aligned = math::integer_least_multiple( b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - // B1 can reuse B's LDS - static constexpr auto b_block_space_size_aligned = - math::max(b0_block_space_size_aligned.value, b1_block_space_size_aligned.value); + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned.value; + static constexpr auto b1_block_space_offset = 0; // LDS allocation for reduction - static constexpr index_t reduction_workspace = BlockSize; + static constexpr index_t reduction_space_size_aligned = + math::integer_least_multiple(BlockSize, max_lds_align); + + static constexpr auto reduction_space_offset = 0; // LDS allocation for C shuffle in LDS static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - static constexpr auto c_block_size = + static constexpr auto c_block_space_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); }; @@ -471,10 +482,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // LDS allocation for A and B: be careful of alignment auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + SharedMemTrait::a_block_space_size_aligned, + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); @@ -591,7 +603,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // reuse LDS space for gemm0's b_block_buf auto b1_block_buf = make_dynamic_buffer( - static_cast(p_shared) + SharedMemTrait::a_block_space_size_aligned, + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr index_t Gemm1KPack = math::max( @@ -617,7 +629,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle true, // TransposeC Gemm1KPack, // AMmaKStride Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ - make_tuple(0, 0, 0, 0)}; // TransposeC + // BMmaKStride + make_tuple(0, 0, 0, 0)}; // A_origin auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer(); @@ -625,10 +638,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // Blockwise softmax // auto workspace_buf = make_dynamic_buffer( - static_cast(p_shared) + - SharedMemTrait::a_block_space_size_aligned * sizeof(FloatAB) / 4 + - SharedMemTrait::b_block_space_size_aligned * sizeof(FloatAB) / 4, - SharedMemTrait::reduction_workspace); + static_cast(p_shared) + SharedMemTrait::reduction_space_offset, + SharedMemTrait::reduction_space_size_aligned); // get acc0 8D thread cluster constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 = @@ -717,7 +728,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + mathext::exp(max - running_max_new) * sum; - block_sync_lds(); // gemm1 { // TODO: explore using dynamic buffer for a1 thread buffer @@ -736,12 +746,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, b1_block_slice_copy_step); + block_sync_lds(); // wait for reduction LDS read + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); // main body if constexpr(num_gemm1_k_block_inner_loop > 1) { - static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, make_tuple(Number{}, I0, I0), @@ -749,6 +760,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle a1_thread_desc_k0_m_k1, make_tuple(I0, I0, I0), a1_thread_buf); + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); block_sync_lds(); @@ -773,6 +785,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle a1_thread_desc_k0_m_k1, make_tuple(I0, I0, I0), a1_thread_buf); + block_sync_lds(); gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); @@ -817,6 +830,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle running_max = running_max_new; running_sum = running_sum_new; + block_sync_lds(); // wait for gemm1 LDS read } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop // shuffle C and write out diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index b4885ad3fc..0748ffbce5 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -819,7 +819,7 @@ struct XdlopsGemm index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td; index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size; - return CIndex{m_offset, n_offset}; + return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; } static constexpr auto mfma = MfmaSelector{};