Joye/revise wp pipeline (#3493)

* [CK_TILE] unify double and single lds implementation (#108)

Unify LDS buffer management API for single and double buffering modes

This change consolidates Local Data Share (LDS) buffer management by:

- Merging the single- and double-buffer LDS APIs into a unified interface
- Implementing ping-pong address calculation in the pipeline when double LDS is enabled
- Computing pong buffer addresses dynamically from base-address offsets (sketched below)
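As a rough illustration of the addressing scheme this enables, a minimal C++ sketch follows; the buffer size constant, helper name, and parity-based offset are assumptions for illustration, not the actual CK_TILE implementation:

```cpp
#include <cstddef>

// Hypothetical sketch: one LDS allocation of 2 * kSingleLdsBytes backs both
// stages, and the pong address is derived from the ping base by a fixed
// offset, so the pipeline receives a single pointer instead of two.
constexpr std::size_t kSingleLdsBytes = 16384; // assumed size of one stage

inline char* lds_stage(char* smem_base, int iteration)
{
    // Even iterations use the ping buffer, odd iterations the pong buffer.
    return smem_base + ((iteration & 1) ? kSingleLdsBytes : 0);
}
```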

---------

Co-authored-by: joye <joye@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* update wp_pipeline

* fix a C++17 issue

* update for ci errors

* fix ci issues

* include a header to fix ci errors

* fix some rebase issues

* update with rebase

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Author: joyeamd
Date: 2026-01-06 05:49:26 +08:00
Committed by: GitHub
Commit: 2b563ad048 (parent: 1224bc0a82)
13 changed files with 766 additions and 929 deletions


@@ -303,24 +303,15 @@ struct GroupedGemmKernel
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr[GetSmemSize()];
// TODO:
// Can we simplify this branching logic?
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,
kargs.ds_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
RunGemmWithPipelineSelection2LDS(
a_ptr, b_ptr, c_ptr, kargs.ds_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
else // SingleSmemBuffer
{
@@ -331,7 +322,7 @@ struct GroupedGemmKernel
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
@@ -343,7 +334,7 @@ struct GroupedGemmKernel
{b_ptr},
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
@@ -425,9 +416,7 @@ struct GroupedGemmKernel
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param ds_ptr input Ds pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param smem_ptr The starting pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
@@ -439,8 +428,7 @@ struct GroupedGemmKernel
const BDataType* b_ptr,
CDataType* c_ptr,
const std::array<const void*, NumDTensor_>& ds_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
void* __restrict__ smem_ptr,
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
@@ -460,8 +448,8 @@ struct GroupedGemmKernel
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
const auto& c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
// Run Epilogue Pipeline
if(kargs.k_batch == 1)
@@ -469,7 +457,7 @@ struct GroupedGemmKernel
auto c_block_window = Base::template MakeCBlockWindows<memory_operation_enum::set>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr);
}
else
{
@@ -477,7 +465,7 @@ struct GroupedGemmKernel
Base::template MakeCBlockWindows<memory_operation_enum::atomic_add>(
c_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr);
}
}


@@ -978,7 +978,7 @@ struct UniversalGemmKernel
* @param bs_ptr input Bs pointer
* @param ds_ptr input Ds pointer
* @param e_ptr output E pointer
* @param smem_ptr_0 The starting pointer of the shared memory block.
* @param smem_ptr The starting pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
@@ -990,7 +990,7 @@ struct UniversalGemmKernel
const std::array<const BDataType*, NumBTensor>& bs_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_0,
void* smem_ptr,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
@@ -1008,7 +1008,7 @@ struct UniversalGemmKernel
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr);
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
// Run Epilogue Pipeline
@@ -1016,77 +1016,63 @@ struct UniversalGemmKernel
{
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
}
else
{
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
}
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note RunGemm2LDS runs with two shared memory buffers using the ping-pong buffer mechanism.
*
* @param as_ptr input As pointer
* @param bs_ptr input Bs pointer
* @param ds_ptr input Ds pointer
* @param e_ptr output E pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void RunGemm2LDS(const std::array<const ADataType*, NumATensor>& as_ptr,
const std::array<const BDataType*, NumBTensor>& bs_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
CK_TILE_DEVICE static auto
GetTileCoordinates(const KernelArgs& kargs) -> tuple<index_t, index_t>
{
// Create block windows using specialized methods
const auto& as_block_window =
MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& bs_block_window =
MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
index_t iM, iN;
const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
// Regular launch: use 1D block indexing
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
const auto [tile_m, tile_n] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
iM = tile_m;
iN = tile_n;
// Run GEMM cooperatively by whole workgroup.
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
AElementWise{},
bs_block_window,
BElementWise{},
num_loop,
smem_ptr_0,
smem_ptr_1);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
// Run Epilogue Pipeline
if(kargs.k_batch == 1)
return make_tuple(i_m, i_n);
}
// Helper functions
CK_TILE_DEVICE static auto GetBlockId() -> index_t
{
// For 1D regular launch
return amd_wave_read_first_lane(get_block_id());
}
CK_TILE_DEVICE static auto GetGridSize() -> index_t
{
// For 1D regular launch
return amd_wave_read_first_lane(get_grid_size());
}
// Helper to get total number of tiles, handling both dim3 and index_t return types
template <typename... Args>
CK_TILE_HOST_DEVICE static auto GetNumTiles(Args&&... args) -> index_t
{
auto grid_size = TilePartitioner::GridSize(std::forward<Args>(args)...);
using GridSizeType = decltype(grid_size);
if constexpr(std::is_same_v<GridSizeType, dim3>)
{
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
// GridSize returns dim3: compute total tiles as x * y * z
return amd_wave_read_first_lane(grid_size.x * grid_size.y * grid_size.z);
}
else
{
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
// GridSize returns scalar (index_t): use directly
return amd_wave_read_first_lane(grid_size);
}
}
@@ -1123,36 +1109,12 @@ struct UniversalGemmKernel
}
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr[GetSmemSize()];
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemm2LDS(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
RunGemm<scheduler_type>(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
constexpr auto scheduler_type =
GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1);
RunGemm<scheduler_type>(
as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
// Persistent kernel entry point
@@ -1199,34 +1161,19 @@ struct UniversalGemmKernel
}
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr[GetSmemSize()];
// Run the GEMM
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemm2LDS(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
RunGemm(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
RunGemm(as_ptr,
bs_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
i_n);
// Advance to the next work item
block_id += grid_size;
if(block_id >= num_work)