diff --git a/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp index d67ef6f33b..c68ff0483d 100644 --- a/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp @@ -115,8 +115,9 @@ struct GroupedFlatmmKernel : FlatmmKernel, FlatmmPipeline::GetName()); } - template - CK_TILE_HOST_DEVICE static auto GridSizeImpl(const KernelArgs& kernelArgs) + + CK_TILE_HOST_DEVICE static auto + GridSize([[maybe_unused]] const GroupedFlatmmHostArgs& kernelArgs) { hipDeviceProp_t prop; int deviceId = 0; // default device @@ -129,7 +130,8 @@ struct GroupedFlatmmKernel : FlatmmKernel(kentry2), + reinterpret_cast( + kentry2), block_size, dync_smem_size); @@ -142,15 +144,34 @@ struct GroupedFlatmmKernel : FlatmmKernel(kernelArgs); - } CK_TILE_HOST_DEVICE static auto GridSize([[maybe_unused]] const ContiguousGroupedFlatmmHostArgs& kernelArgs) { - return GridSizeImpl(kernelArgs); + hipDeviceProp_t prop; + int deviceId = 0; // default device + + constexpr int block_size = UnderlyingGemmKernel::BlockSize().x; + int dync_smem_size = 0; + int maxActiveBlocksPerCU; + + [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId); + + e = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &maxActiveBlocksPerCU, + reinterpret_cast( + kentry2), + block_size, + dync_smem_size); + + const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU; + const int total_work_tile_cnt = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N); + + std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU + << ", persistent_block_size: " << persistent_block_size + << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl; + + assert(kernelArgs.k_batch == 1); + return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kernelArgs.k_batch); } CK_TILE_HOST static constexpr auto MakeKernelArgs(const GroupedFlatmmHostArgs& hostArgs)