WIP: Double buffer implementation.

This commit is contained in:
Ville Pietilä
2026-02-10 10:03:41 -05:00
parent 59cbe19c83
commit afdd6a84a7
2 changed files with 86 additions and 65 deletions

View File

@@ -357,7 +357,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
if(ck::get_device_name() == "gfx950")
{
return Base::GetSharedMemoryNumberOfByte(gfx950_t{});
// Double buffering -> 2 times shared memory
return 2*Base::GetSharedMemoryNumberOfByte(gfx950_t{});
}
else
#endif
@@ -755,12 +756,24 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
// Double buffers for A and B in LDS
auto a_block_buf_0 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
auto a_block_buf_1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AComputeDataType*>(p_shared) + a_block_space_size_aligned,
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf_0 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BComputeDataType*>(p_shared) + 2*a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_buf_1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BComputeDataType*>(p_shared) + 2*a_block_space_size_aligned +
b_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
@@ -778,13 +791,15 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_buf_0,
a_block_buf_1,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_buf_0,
b_block_buf_1,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,

View File

@@ -113,10 +113,9 @@ struct GridwiseGemmPipeline_v1<2, true, true>
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
__host__ __device__ static constexpr bool IsSupported(index_t)
{
// TODO: improve applicability
return num_loop % 2 == 0;
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
@@ -124,6 +123,11 @@ struct GridwiseGemmPipeline_v1<2, true, true>
return (num_loop / 2) > 1;
}
__host__ __device__ static constexpr bool CalculateIsOddLoop(index_t num_loop)
{
return (num_loop % 2) == 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
@@ -143,31 +147,33 @@ struct GridwiseGemmPipeline_v1<2, true, true>
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
ABlockBuffer& a_block_buf_0,
ABlockBuffer& a_block_buf_1,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
BBlockBuffer& b_block_buf_0,
BBlockBuffer& b_block_buf_1,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// preload data into LDS
// Prologue - load data into buffer 0
{
// Read 0
// Read from global mem to registers (I0)
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
// Move
// Move source slice window for next read (I1)
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Read 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
// Write from registers to LDS buffer 0
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_0);
}
// Initialize C
@@ -180,76 +186,76 @@ struct GridwiseGemmPipeline_v1<2, true, true>
do
{
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Write i
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
// Read i+2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
// Sync
block_sync_lds();
// Gemm i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// Sync
block_sync_lds();
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Write i+1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
// Read i+3
// Read from global mem to registers (I1)
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
// Sync
// Move source slice window for next read (I0)
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Sync LDS to ensure buffer 0 is ready
block_sync_lds();
// Gemm i+1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// Run GEMM on buffer 0 while buffer 1 is loading
blockwise_gemm.Run(a_block_buf_0, b_block_buf_0, c_thread_buf);
// Sync
// Write from registers to LDS buffer 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_1);
// Read from global mem to registers (I0)
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
// Move source slice window for next read (I1)
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Sync LDS to ensure buffer 1 is ready
block_sync_lds();
// Run GEMM on buffer 1 while buffer 0 is loading
blockwise_gemm.Run(a_block_buf_1, b_block_buf_1, c_thread_buf);
// Write from registers to LDS buffer 0
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_0);
i += 2;
} while(i < (num_loop - 2));
}
// tail
if (num_loop % 2 == 0)
{
// Write num_loop - 2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
// Read from global mem to registers (I1)
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
// Sync
// Sync LDS to ensure buffer 0 is ready
block_sync_lds();
// Gemm num_loop - 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// Run GEMM on buffer 0
blockwise_gemm.Run(a_block_buf_0, b_block_buf_0, c_thread_buf);
// Sync
// Write from registers to LDS buffer 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_1);
// Sync LDS to ensure buffer 1 is ready
block_sync_lds();
// Write num_loop - 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);
// Sync
// Run GEMM on buffer 1
blockwise_gemm.Run(a_block_buf_1, b_block_buf_1, c_thread_buf);
}
else
{
// Sync LDS to ensure buffer 0 is ready
block_sync_lds();
// Gemm num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// Run GEMM on buffer 0
blockwise_gemm.Run(a_block_buf_0, b_block_buf_0, c_thread_buf);
}
}
};