From bce9c22bcdeafd448f2ed4db78f3c14737001bf5 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Fri, 4 Jul 2025 07:09:58 +0000 Subject: [PATCH] Improve gridDim calculation in persistent mode --- include/ck_tile/host/kernel_launch.hpp | 6 ++++ .../flatmm/kernel/grouped_flatmm_kernel.hpp | 34 +++++++++++++++++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index d159787387..8ce108aa5a 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -20,6 +20,12 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) Kernel{}(args...); } +template +__launch_bounds__(MaxThreadPerBlock) __global__ void kentry2(Args... args) +{ + Kernel{}(args...); +} + // // return a anonymous functor(lambda) to be called later // the KernelImpl should be a class without non-static data member, or let's say diff --git a/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp index 31a7d03798..4e30035622 100644 --- a/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp @@ -53,6 +53,16 @@ struct GroupedFlatmmHostArgs index_t k_batch; }; +namespace persist { + +template +__launch_bounds__(MaxThreadPerBlock) __global__ void persist_kernel(Args... args) +{ + Kernel{}(args...); +} + +} // namespace persist + template struct GroupedFlatmmKernel : FlatmmKernel { @@ -76,14 +86,32 @@ struct GroupedFlatmmKernel : FlatmmKernel, FlatmmPipeline::GetName()); } - CK_TILE_HOST_DEVICE static constexpr auto GridSize(const GroupedFlatmmKernelArgs& kernelArgs) + CK_TILE_HOST_DEVICE static auto GridSize(const GroupedFlatmmKernelArgs& kernelArgs) { hipDeviceProp_t prop; int deviceId = 0; // default device - auto e = hipGetDeviceProperties(&prop, deviceId); + constexpr int block_size = UnderlyingGemmKernel::BlockSize().x; + int dync_smem_size = 0; + int maxActiveBlocksPerCU; - const int persistent_block_size = prop.multiProcessorCount; + [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId); + + e = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &maxActiveBlocksPerCU, + // reinterpret_cast(GroupedFlatmmKernel::Kernel), + reinterpret_cast( + kentry2), + block_size, + dync_smem_size); + + const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU; + + // print maxActiveBlocksPerCU and persistent_block_size + std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU + << ", persistent_block_size: " << persistent_block_size << std::endl; + + assert(kernelArgs.k_batch == 1); return dim3(persistent_block_size, 1, kernelArgs.k_batch); }