mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
Improve gridDim calculation in persistent mode
This commit is contained in:
@@ -20,6 +20,12 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
Kernel{}(args...);
|
||||
}
|
||||
|
||||
template <int MaxThreadPerBlock, typename Kernel, typename... Args>
|
||||
__launch_bounds__(MaxThreadPerBlock) __global__ void kentry2(Args... args)
|
||||
{
|
||||
Kernel{}(args...);
|
||||
}
|
||||
|
||||
//
|
||||
// return an anonymous functor (lambda) to be called later
|
||||
// the KernelImpl should be a class without non-static data member, or let's say
|
||||
|
||||
@@ -53,6 +53,16 @@ struct GroupedFlatmmHostArgs
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
namespace persist {
|
||||
|
||||
template <int MaxThreadPerBlock, typename Kernel, typename... Args>
|
||||
__launch_bounds__(MaxThreadPerBlock) __global__ void persist_kernel(Args... args)
|
||||
{
|
||||
Kernel{}(args...);
|
||||
}
|
||||
|
||||
} // namespace persist
|
||||
|
||||
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
|
||||
struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
|
||||
{
|
||||
@@ -76,14 +86,32 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
|
||||
'_', "grouped_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
|
||||
}
|
||||
|
||||
// Computes the launch grid used when the grouped flatmm kernel runs in
// persistent mode.
//
// The x-dimension is
//     (compute units reported for the device) *
//     (max active blocks per compute unit reported by the HIP occupancy query
//      for the kentry2 instantiation of this kernel at its block size, with no
//      dynamic shared memory).
//
// Not constexpr: the value depends on runtime device queries.
//
// Failure handling: every HIP status code is checked. If a query fails or
// reports a non-positive value, that factor falls back to 1 so the returned
// grid is never zero-sized (a zero-sized grid is an invalid launch).
// NOTE(review): the fallback assumes the persistent kernel still covers all
// work with fewer blocks than ideal - confirm against the kernel's tile loop.
//
// @param kernelArgs  Kernel arguments; only `k_batch` is read, and it is
//                    asserted to be 1.
// @return dim3(persistent block count, 1, kernelArgs.k_batch)
CK_TILE_HOST_DEVICE static auto GridSize(const GroupedFlatmmKernelArgs& kernelArgs)
{
    constexpr int block_size     = UnderlyingGemmKernel::BlockSize().x;
    constexpr int dync_smem_size = 0; // occupancy is queried without dynamic shared memory

    // Use the device the caller has made current; fall back to the default
    // device (0) if the query itself fails.
    int deviceId = 0;
    if(hipGetDevice(&deviceId) != hipSuccess)
    {
        deviceId = 0;
    }

    // Value-initialize so a failed query cannot leave garbage in the CU count.
    hipDeviceProp_t prop{};
    const bool has_props = (hipGetDeviceProperties(&prop, deviceId) == hipSuccess);
    const int cu_count =
        (has_props && prop.multiProcessorCount > 0) ? prop.multiProcessorCount : 1;

    // Initialized so the value is well-defined even when the query fails.
    int maxActiveBlocksPerCU    = 0;
    const auto occupancy_status = hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &maxActiveBlocksPerCU,
        reinterpret_cast<void*>(
            kentry2<block_size, GroupedFlatmmKernel, GroupedFlatmmKernelArgs>),
        block_size,
        dync_smem_size);
    if(occupancy_status != hipSuccess || maxActiveBlocksPerCU < 1)
    {
        maxActiveBlocksPerCU = 1;
    }

    const int persistent_block_size = cu_count * maxActiveBlocksPerCU;

    assert(kernelArgs.k_batch == 1);
    return dim3(persistent_block_size, 1, kernelArgs.k_batch);
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user