Improve gridDim calculation in persistent mode

This commit is contained in:
Feng Shijie
2025-07-04 07:09:58 +00:00
parent 10fe5ab7b5
commit bce9c22bcd
2 changed files with 37 additions and 3 deletions

View File

@@ -20,6 +20,12 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
Kernel{}(args...);
}
// Generic __global__ entry point: value-initializes a Kernel functor and invokes it with
// the by-value launch arguments. Only the maximum thread count per block is passed to
// __launch_bounds__; no minimum-blocks-per-CU hint is given.
//
// Template parameters:
//   MaxThreadPerBlock - upper bound on threads per block handed to __launch_bounds__.
//   Kernel            - default-constructible functor type whose operator() accepts Args.
//   Args              - types of the arguments passed through to the functor.
template <int MaxThreadPerBlock, typename Kernel, typename... Args>
__launch_bounds__(MaxThreadPerBlock) __global__ void kentry2(Args... args)
{
    Kernel kernel_functor{};
    kernel_functor(args...);
}
//
// return an anonymous functor (lambda) to be called later
// the KernelImpl should be a class without non-static data member, or let's say

View File

@@ -53,6 +53,16 @@ struct GroupedFlatmmHostArgs
index_t k_batch;
};
namespace persist {

// __global__ entry point: value-initializes a Kernel functor and calls it with the
// by-value launch arguments. __launch_bounds__ receives only the maximum thread count
// per block.
//
// Template parameters:
//   MaxThreadPerBlock - upper bound on threads per block handed to __launch_bounds__.
//   Kernel            - default-constructible functor type whose operator() accepts Args.
//   Args              - types of the arguments passed through to the functor.
template <int MaxThreadPerBlock, typename Kernel, typename... Args>
__launch_bounds__(MaxThreadPerBlock) __global__ void persist_kernel(Args... args)
{
    Kernel kernel_functor{};
    kernel_functor(args...);
}

} // namespace persist
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
{
@@ -76,14 +86,32 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
'_', "grouped_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
}
// Computes the launch grid for persistent mode and returns it as a dim3:
//   x = multiprocessor count * max active blocks per multiprocessor
//       (per the later declaration of persistent_block_size),
//   y = 1,
//   z = kernelArgs.k_batch.
// NOTE(review): this span appears to be a rendered diff hunk whose +/- markers were
// stripped, so superseded and replacement lines sit side by side (two signatures, two
// declarations each of `e` and `persistent_block_size`); read each such pair as
// old line / new line rather than as one compilable body.
CK_TILE_HOST_DEVICE static constexpr auto GridSize(const GroupedFlatmmKernelArgs& kernelArgs)
CK_TILE_HOST_DEVICE static auto GridSize(const GroupedFlatmmKernelArgs& kernelArgs)
{
hipDeviceProp_t prop;
// NOTE(review): device 0 is hard-coded rather than queried from the runtime; confirm
// this is intended on multi-GPU hosts.
int deviceId = 0; // default device
auto e = hipGetDeviceProperties(&prop, deviceId);
// Threads per block: x component of the underlying GEMM kernel's block size.
constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
// Dynamic shared-memory size handed to the occupancy query below (zero).
int dync_smem_size = 0;
// Receives the occupancy query result. NOTE(review): left uninitialized here, and the
// query's status is never checked before this value is multiplied below.
int maxActiveBlocksPerCU;
const int persistent_block_size = prop.multiProcessorCount;
// Fills `prop` for the chosen device; the returned status is stored in `e` but never
// inspected in this function.
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
// Query how many blocks of the kentry2 instantiation for this kernel can be active on
// one multiprocessor at `block_size` threads and `dync_smem_size` bytes of dynamic
// shared memory.
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
// reinterpret_cast<void*>(GroupedFlatmmKernel::Kernel),
reinterpret_cast<void*>(
kentry2<block_size, GroupedFlatmmKernel, GroupedFlatmmKernelArgs>),
block_size,
dync_smem_size);
// Grid x-dimension: one block per occupancy slot across all multiprocessors.
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
// print maxActiveBlocksPerCU and persistent_block_size
// NOTE(review): this writes to stdout (and flushes via std::endl) on every call; it
// looks like debug output — confirm whether it should remain.
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
<< ", persistent_block_size: " << persistent_block_size << std::endl;
// Only k_batch == 1 is accepted; this check disappears when asserts are compiled out.
assert(kernelArgs.k_batch == 1);
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
}