mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
WIP: Double buffer implementation.
This commit is contained in:
@@ -357,7 +357,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
return Base::GetSharedMemoryNumberOfByte(gfx950_t{});
|
||||
// Double buffering -> 2 times shared memory
|
||||
return 2*Base::GetSharedMemoryNumberOfByte(gfx950_t{});
|
||||
}
|
||||
else
|
||||
#endif
|
||||
@@ -755,12 +756,24 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
// Double buffers for A and B in LDS
|
||||
auto a_block_buf_0 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
|
||||
auto a_block_buf_1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<AComputeDataType*>(p_shared) + a_block_space_size_aligned,
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf_0 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<BComputeDataType*>(p_shared) + 2*a_block_space_size_aligned,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf_1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<BComputeDataType*>(p_shared) + 2*a_block_space_size_aligned +
|
||||
b_block_space_size_aligned,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
|
||||
@@ -778,13 +791,15 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_buf_0,
|
||||
a_block_buf_1,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_buf_0,
|
||||
b_block_buf_1,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
|
||||
@@ -113,10 +113,9 @@ struct GridwiseGemmPipeline_v1<2, true, true>
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t)
|
||||
{
|
||||
// TODO: improve applicability
|
||||
return num_loop % 2 == 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
|
||||
@@ -124,6 +123,11 @@ struct GridwiseGemmPipeline_v1<2, true, true>
|
||||
return (num_loop / 2) > 1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateIsOddLoop(index_t num_loop)
|
||||
{
|
||||
return (num_loop % 2) == 1;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
@@ -143,31 +147,33 @@ struct GridwiseGemmPipeline_v1<2, true, true>
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
ABlockBuffer& a_block_buf_0,
|
||||
ABlockBuffer& a_block_buf_1,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
BBlockBuffer& b_block_buf_0,
|
||||
BBlockBuffer& b_block_buf_1,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
// preload data into LDS
|
||||
// Prologue - load data into buffer 0
|
||||
{
|
||||
// Read 0
|
||||
// Read from global mem to registers (I0)
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
|
||||
|
||||
// Move
|
||||
// Move source slice window for next read (I1)
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Read 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
|
||||
// Write from registers to LDS buffer 0
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_0);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_0);
|
||||
}
|
||||
|
||||
// Initialize C
|
||||
@@ -180,76 +186,76 @@ struct GridwiseGemmPipeline_v1<2, true, true>
|
||||
|
||||
do
|
||||
{
|
||||
// Move
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Write i
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
|
||||
|
||||
// Read i+2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
|
||||
|
||||
// Sync
|
||||
block_sync_lds();
|
||||
|
||||
// Gemm i
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
// Sync
|
||||
block_sync_lds();
|
||||
|
||||
// Move
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Write i+1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
|
||||
|
||||
// Read i+3
|
||||
// Read from global mem to registers (I1)
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
|
||||
|
||||
// Sync
|
||||
// Move source slice window for next read (I0)
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Sync LDS to ensure buffer 0 is ready
|
||||
block_sync_lds();
|
||||
|
||||
// Gemm i+1
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
// Run GEMM on buffer 0 while buffer 1 is loading
|
||||
blockwise_gemm.Run(a_block_buf_0, b_block_buf_0, c_thread_buf);
|
||||
|
||||
// Sync
|
||||
// Write from registers to LDS buffer 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_1);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_1);
|
||||
|
||||
// Read from global mem to registers (I0)
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
|
||||
|
||||
// Move source slice window for next read (I1)
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Sync LDS to ensure buffer 1 is ready
|
||||
block_sync_lds();
|
||||
|
||||
// Run GEMM on buffer 1 while buffer 0 is loading
|
||||
blockwise_gemm.Run(a_block_buf_1, b_block_buf_1, c_thread_buf);
|
||||
|
||||
// Write from registers to LDS buffer 0
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_0);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if (num_loop % 2 == 0)
|
||||
{
|
||||
// Write num_loop - 2
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
|
||||
// Read from global mem to registers (I1)
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
|
||||
|
||||
// Sync
|
||||
// Sync LDS to ensure buffer 0 is ready
|
||||
block_sync_lds();
|
||||
|
||||
// Gemm num_loop - 2
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
// Run GEMM on buffer 0
|
||||
blockwise_gemm.Run(a_block_buf_0, b_block_buf_0, c_thread_buf);
|
||||
|
||||
// Sync
|
||||
// Write from registers to LDS buffer 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_1);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_1);
|
||||
|
||||
// Sync LDS to ensure buffer 1 is ready
|
||||
block_sync_lds();
|
||||
|
||||
// Write num_loop - 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
|
||||
|
||||
// Sync
|
||||
// Run GEMM on buffer 1
|
||||
blockwise_gemm.Run(a_block_buf_1, b_block_buf_1, c_thread_buf);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Sync LDS to ensure buffer 0 is ready
|
||||
block_sync_lds();
|
||||
|
||||
// Gemm num_loop - 1
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
// Run GEMM on buffer 0
|
||||
blockwise_gemm.Run(a_block_buf_0, b_block_buf_0, c_thread_buf);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user