Code drop for 2 warp ping pong scheduler along K dimension. (#2276)

* Code drop for 2 warp ping pong scheduler along K dimension.

* Addressing code review comments.

* Addressing Clang formatting issues.

* Addressing build issues.

* Addressing build issues of other GEMM pipelines with ping pong scheduler code drop.

* Fix for LDS memory size for GEMM pipelines.

* Addressing code review feedback comments.

* Change log update.

* Addressing code review comments and build issues.

* Added new policy for pipeline specific logic about LDS needs.

* Clang Fix during build.
This commit is contained in:
kylasa
2025-06-12 18:24:02 -07:00
committed by GitHub
parent e5ece14467
commit 5f1ad09b61
17 changed files with 727 additions and 110 deletions

View File

@@ -644,6 +644,7 @@ struct GemmKernel
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
template <bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
@@ -671,11 +672,15 @@ struct GemmKernel
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, smem_ptr_0);
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, smem_ptr_0);
}
}
/**
@@ -772,7 +777,9 @@ struct GemmKernel
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
RunGemm<scheduler_type>(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
}
}