mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Code drop for 2 warp ping pong scheduler along K dimension. (#2276)
* Code drop for 2 warp ping pong scheduler along K dimension. * Addressing code review comments. * Addressing Clang formatting issues. * Addressing build issues. * Addressing build issues of other GEMM pipelines with ping pong scheduler code drop. * Fix for LDS memory size for GEMM pipelines. * Addressing code review feedback comments. * Change log update. * Addressing code review comments and build issues. * Added new policy for pipeline specific logic about LDS needs. * Clang Fix during build.
This commit is contained in:
@@ -644,6 +644,7 @@ struct GemmKernel
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
@@ -671,11 +672,15 @@ struct GemmKernel
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, smem_ptr_0);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -772,7 +777,9 @@ struct GemmKernel
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user