mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Add support for double buffering in direct load GEMM kernel (#1052)
This PR introduces support for double buffering in LDS into GEMM kernels that use direct load instructions. Direct loads now use inline asm instead of intrinsics. Usage of intrinsics results in the compiler adding additional waitcnt instructions, which breaks the possible load/compute overlap in the case of double buffering. Usage of inline asm requires a sched_barrier to make sure the compiler cannot incorrectly reschedule instructions, since it does not know the data dependencies between global->LDS and LDS->registers transfers.
This commit is contained in:
committed by
GitHub
parent
c7d5c7727b
commit
bc4bf9bd03
@@ -236,9 +236,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
constexpr auto c_block_size =
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
|
||||
b_block_space_size_aligned * sizeof(BComputeDataType),
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
return math::max(
|
||||
NumGemmKPrefetchStage * a_block_space_size_aligned * sizeof(AComputeDataType) +
|
||||
NumGemmKPrefetchStage * b_block_space_size_aligned * sizeof(BComputeDataType),
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
@@ -491,6 +492,22 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
|
||||
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
|
||||
|
||||
template <typename DataType>
|
||||
__device__ static auto AllocateBlockBuffers(void* p_shared,
|
||||
int32_t num_elems,
|
||||
int32_t offset_elems,
|
||||
int32_t max_lds_align)
|
||||
{
|
||||
const int32_t single_buffer_offset = math::integer_least_multiple(num_elems, max_lds_align);
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
const int32_t local_offset = i * single_buffer_offset;
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<DataType*>(p_shared) + local_offset + offset_elems, num_elems);
|
||||
},
|
||||
Number<NumGemmKPrefetchStage>{});
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
@@ -624,12 +641,14 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto a_block_buffers = AllocateBlockBuffers<AComputeDataType>(
|
||||
p_shared, a_block_desc_ak0_m_ak1.GetElementSpaceSize(), 0, max_lds_align);
|
||||
const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage;
|
||||
auto b_block_buffers =
|
||||
AllocateBlockBuffers<BComputeDataType>(p_shared,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize(),
|
||||
b_buffers_offset,
|
||||
max_lds_align);
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
|
||||
@@ -645,13 +664,13 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_buffers,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_buffers,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
|
||||
@@ -7,6 +7,20 @@
|
||||
#include "ck/utility/loop_scheduler.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
|
||||
namespace lds_direct_load {
|
||||
|
||||
__device__ void sched_barrier()
|
||||
{
|
||||
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
|
||||
// When direct loads and `waitcnt` instructions are submitted using inline asm, the usage of
|
||||
// `sched_barrier` is necessary to make sure no instructions that use the loaded memory
|
||||
// are scheduled by the compiler before the `waitcnt` instruction.
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace lds_direct_load
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t NumPrefetch>
|
||||
@@ -17,7 +31,6 @@ template <>
|
||||
struct GridwiseGemmPipeline_v4<1>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
|
||||
|
||||
@@ -31,13 +44,13 @@ struct GridwiseGemmPipeline_v4<1>
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockBuffers,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockBuffers,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
@@ -45,18 +58,22 @@ struct GridwiseGemmPipeline_v4<1>
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
ABlockBuffers& a_block_bufs,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
BBlockBuffers& b_block_bufs,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
static_assert(ABlockBuffers::Size() == 1 && BBlockBuffers::Size() == 1);
|
||||
auto& a_block_buf = a_block_bufs.At(I0);
|
||||
auto& b_block_buf = b_block_bufs.At(I0);
|
||||
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
|
||||
|
||||
@@ -74,10 +91,12 @@ struct GridwiseGemmPipeline_v4<1>
|
||||
do
|
||||
{
|
||||
block_sync_lds_direct_load();
|
||||
lds_direct_load::sched_barrier();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds_direct_load();
|
||||
lds_direct_load::sched_barrier();
|
||||
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
|
||||
@@ -92,10 +111,128 @@ struct GridwiseGemmPipeline_v4<1>
|
||||
// tail
|
||||
{
|
||||
block_sync_lds_direct_load();
|
||||
lds_direct_load::sched_barrier();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// 2-stages prefetch
|
||||
template <>
|
||||
struct GridwiseGemmPipeline_v4<2>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
|
||||
{
|
||||
return (num_loop / 2) > 1;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffers,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffers,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
__device__ static void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffers& a_block_bufs,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffers& b_block_bufs,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
static_assert(ABlockBuffers::Size() == 2 && BBlockBuffers::Size() == 2);
|
||||
auto& a_block_buf1 = a_block_bufs.At(I0);
|
||||
auto& a_block_buf2 = a_block_bufs.At(I1);
|
||||
auto& b_block_buf1 = b_block_bufs.At(I0);
|
||||
auto& b_block_buf2 = b_block_bufs.At(I1);
|
||||
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1);
|
||||
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds_direct_load();
|
||||
lds_direct_load::sched_barrier();
|
||||
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2);
|
||||
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf);
|
||||
|
||||
block_sync_lds_direct_load();
|
||||
lds_direct_load::sched_barrier();
|
||||
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1);
|
||||
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds_direct_load();
|
||||
lds_direct_load::sched_barrier();
|
||||
|
||||
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2);
|
||||
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf);
|
||||
|
||||
block_sync_lds_direct_load();
|
||||
lds_direct_load::sched_barrier();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user