mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Joye/revise wp pipeline (#3493)
* [CK_TILE] unify double and single lds implementation (#108) Unify LDS buffer management API for single and double buffering modes This change consolidates the Local Data Store (LDS) buffer management by: Merging single and double LDS buffer APIs into a unified interface Implementing ping-pong address calculation in pipeline when double LDS is enabled Computing pong buffer addresses dynamically using base address offsets --------- Co-authored-by: joye <joye@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * update wp_pipeline * fix a c++17 issue * update for ci errors * fix ci issues * include a header to fix ci errors * fix some rebase issues * update with rebase --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -303,24 +303,15 @@ struct GroupedGemmKernel
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// TODO:
|
||||
// Can we simplify this branching logic?
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
kargs.ds_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
RunGemmWithPipelineSelection2LDS(
|
||||
a_ptr, b_ptr, c_ptr, kargs.ds_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
else // SingleSmemBuffer
|
||||
{
|
||||
@@ -331,7 +322,7 @@ struct GroupedGemmKernel
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
@@ -343,7 +334,7 @@ struct GroupedGemmKernel
|
||||
{b_ptr},
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
@@ -425,9 +416,7 @@ struct GroupedGemmKernel
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param smem_ptr The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
@@ -439,8 +428,7 @@ struct GroupedGemmKernel
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const std::array<const void*, NumDTensor_>& ds_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
void* __restrict__ smem_ptr,
|
||||
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
|
||||
const typename Base::SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
@@ -460,8 +448,8 @@ struct GroupedGemmKernel
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
const auto& c_block_tile =
|
||||
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if(kargs.k_batch == 1)
|
||||
@@ -469,7 +457,7 @@ struct GroupedGemmKernel
|
||||
auto c_block_window = Base::template MakeCBlockWindows<memory_operation_enum::set>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -477,7 +465,7 @@ struct GroupedGemmKernel
|
||||
Base::template MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -978,7 +978,7 @@ struct UniversalGemmKernel
|
||||
* @param bs_ptr input Bs pointer
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param e_ptr output E pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param smem_ptr The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
@@ -990,7 +990,7 @@ struct UniversalGemmKernel
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_0,
|
||||
void* smem_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
@@ -1008,7 +1008,7 @@ struct UniversalGemmKernel
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
|
||||
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr);
|
||||
|
||||
const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
|
||||
// Run Epilogue Pipeline
|
||||
@@ -1016,77 +1016,63 @@ struct UniversalGemmKernel
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note RunGemm2LDS runs with two shared memory buffers using the ping-pong buffer mechanism.
|
||||
*
|
||||
* @param as_ptr input As pointer
|
||||
* @param bs_ptr input Bs pointer
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param e_ptr output E pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
CK_TILE_DEVICE static auto
|
||||
GetTileCoordinates(const KernelArgs& kargs) -> tuple<index_t, index_t>
|
||||
{
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& bs_block_window =
|
||||
MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
|
||||
const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
index_t iM, iN;
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
// Regular launch: use 1D block indexing
|
||||
const auto blockId = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto [tile_m, tile_n] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
iM = tile_m;
|
||||
iN = tile_n;
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
|
||||
AElementWise{},
|
||||
bs_block_window,
|
||||
BElementWise{},
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if(kargs.k_batch == 1)
|
||||
return make_tuple(i_m, i_n);
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
// Returns the result of get_block_id() after passing it through
// amd_wave_read_first_lane(), as an index_t.
// NOTE(review): presumably the read-first-lane call makes the value uniform
// across the wave -- confirm against the helper's definition.
CK_TILE_DEVICE static auto GetBlockId() -> index_t
|
||||
{
|
||||
// For 1D regular launch
|
||||
return amd_wave_read_first_lane(get_block_id());
|
||||
}
|
||||
|
||||
// Returns the result of get_grid_size() after passing it through
// amd_wave_read_first_lane(), as an index_t.
// NOTE(review): presumably the read-first-lane call makes the value uniform
// across the wave -- confirm against the helper's definition.
CK_TILE_DEVICE static auto GetGridSize() -> index_t
|
||||
{
|
||||
// For 1D regular launch
|
||||
return amd_wave_read_first_lane(get_grid_size());
|
||||
}
|
||||
|
||||
// Helper to get total number of tiles, handling both dim3 and index_t return types
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST_DEVICE static auto GetNumTiles(Args&&... args) -> index_t
|
||||
{
|
||||
auto grid_size = TilePartitioner::GridSize(std::forward<Args>(args)...);
|
||||
|
||||
using GridSizeType = decltype(grid_size);
|
||||
|
||||
if constexpr(std::is_same_v<GridSizeType, dim3>)
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
// GridSize returns dim3: compute total tiles as x * y * z
|
||||
return amd_wave_read_first_lane(grid_size.x * grid_size.y * grid_size.z);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
|
||||
e_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
// GridSize returns scalar (index_t): use directly
|
||||
return amd_wave_read_first_lane(grid_size);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1123,36 +1109,12 @@ struct UniversalGemmKernel
|
||||
}
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
constexpr auto scheduler_type =
|
||||
GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(
|
||||
as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
|
||||
// Persistent kernel entry point
|
||||
@@ -1199,34 +1161,19 @@ struct UniversalGemmKernel
|
||||
}
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
// Run the GEMM
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
|
||||
RunGemm(as_ptr,
|
||||
bs_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
|
||||
// Advance to the next work item
|
||||
block_id += grid_size;
|
||||
if(block_id >= num_work)
|
||||
|
||||
Reference in New Issue
Block a user