optimize grid calculation in contiguous mode

This commit is contained in:
Feng Shijie
2025-07-09 08:24:16 +00:00
parent fae4ebac66
commit ff4d6434d9

View File

@@ -115,8 +115,9 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
return concat(
'_', "grouped_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
}
template <class KernelArgs>
CK_TILE_HOST_DEVICE static auto GridSizeImpl(const KernelArgs& kernelArgs)
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const GroupedFlatmmHostArgs& kernelArgs)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
@@ -129,7 +130,8 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(kentry2<block_size, GroupedFlatmmKernel, KernelArgs>),
reinterpret_cast<void*>(
kentry2<block_size, GroupedFlatmmKernel, GroupedFlatmmHostArgs>),
block_size,
dync_smem_size);
@@ -142,15 +144,34 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
}
// Grid-size overload for GroupedFlatmmHostArgs.
// Forwards kernelArgs unchanged to GridSizeImpl, explicitly instantiated with
// GroupedFlatmmHostArgs, and returns whatever that call returns.
// NOTE(review): kernelArgs is used here, so [[maybe_unused]] is redundant on
// this overload.
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const GroupedFlatmmHostArgs& kernelArgs)
{
return GridSizeImpl<GroupedFlatmmHostArgs>(kernelArgs);
}
// Computes the launch grid for the contiguous grouped-flatmm host arguments.
//
// x = min(prop.multiProcessorCount * maxActiveBlocksPerCU,
//         TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N))
// y = 1
// z = kernelArgs.k_batch (asserted to be 1)
//
// maxActiveBlocksPerCU is the occupancy HIP reports for
// kentry2<block_size, GroupedFlatmmKernel, ContiguousGroupedFlatmmHostArgs>
// with zero dynamic shared memory. The clamp keeps the grid from exceeding
// total_work_tile_cnt (presumably the number of output tiles for an M x N
// problem; confirm against TilePartitioner).
//
// The function previously returned GridSizeImpl<...>(kernelArgs) before any of
// this ran, which left the clamped computation unreachable; that early return
// is removed here.
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const ContiguousGroupedFlatmmHostArgs& kernelArgs)
{
    // Value-initialised so that a failing HIP query yields zeros rather than
    // indeterminate values in the arithmetic below.
    hipDeviceProp_t prop{};
    int deviceId             = 0; // default device
    constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
    int dync_smem_size       = 0;
    int maxActiveBlocksPerCU = 0;

    // NOTE(review): the HIP status codes are captured but not acted upon.
    [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);

    e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &maxActiveBlocksPerCU,
        reinterpret_cast<void*>(
            kentry2<block_size, GroupedFlatmmKernel, ContiguousGroupedFlatmmHostArgs>),
        block_size,
        dync_smem_size);

    const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
    const int total_work_tile_cnt   = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);

    std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
              << ", persistent_block_size: " << persistent_block_size
              << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;

    assert(kernelArgs.k_batch == 1);
    return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kernelArgs.k_batch);
}
CK_TILE_HOST static constexpr auto MakeKernelArgs(const GroupedFlatmmHostArgs& hostArgs)