mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
optimize grid calculation in contiguous mode
This commit is contained in:
@@ -115,8 +115,9 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
|
||||
return concat(
|
||||
'_', "grouped_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
|
||||
}
|
||||
template <class KernelArgs>
|
||||
CK_TILE_HOST_DEVICE static auto GridSizeImpl(const KernelArgs& kernelArgs)
|
||||
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
GridSize([[maybe_unused]] const GroupedFlatmmHostArgs& kernelArgs)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
@@ -129,7 +130,8 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
|
||||
|
||||
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(kentry2<block_size, GroupedFlatmmKernel, KernelArgs>),
|
||||
reinterpret_cast<void*>(
|
||||
kentry2<block_size, GroupedFlatmmKernel, GroupedFlatmmHostArgs>),
|
||||
block_size,
|
||||
dync_smem_size);
|
||||
|
||||
@@ -142,15 +144,34 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
|
||||
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
|
||||
}
|
||||
|
||||
/// Launch-grid query for the non-contiguous grouped flatmm host arguments.
///
/// Pure forwarding overload: the result is whatever
/// GridSizeImpl<GroupedFlatmmHostArgs> returns for the same arguments.
///
/// NOTE(review): an overload header with this exact signature also appears earlier in
/// this listing; this one looks like the pre-commit definition. Confirm which of the
/// two is meant to survive before applying.
///
/// @param kernelArgs Host-side arguments, passed through unchanged.
/// @return The value produced by GridSizeImpl for these arguments.
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const GroupedFlatmmHostArgs& kernelArgs)
{
    using HostArgsType = GroupedFlatmmHostArgs;
    return GridSizeImpl<HostArgsType>(kernelArgs);
}
|
||||
/// Launch-grid query for the contiguous grouped flatmm host arguments.
///
/// The x-extent is the smaller of
///   - prop.multiProcessorCount * maxActiveBlocksPerCU, where the second factor is the
///     value hipOccupancyMaxActiveBlocksPerMultiprocessor reports for this kernel entry
///     at the given block size and dynamic shared-memory size, and
///   - the tile count returned by TilePartitioner::GridSize(M, N),
/// so the grid never has more blocks than there are work tiles.
///
/// A leading `return GridSizeImpl<ContiguousGroupedFlatmmHostArgs>(kernelArgs);` used to
/// sit at the top of this body and made everything below it unreachable; it has been
/// dropped so the tile-count clamp actually takes effect.
///
/// @param kernelArgs Host-side arguments; M, N and k_batch are read. k_batch is asserted
///                   to be 1.
/// @return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kernelArgs.k_batch).
CK_TILE_HOST_DEVICE static auto
GridSize([[maybe_unused]] const ContiguousGroupedFlatmmHostArgs& kernelArgs)
{
    // Value-initialised so that a failing (unchecked) HIP query below yields zeros
    // instead of indeterminate values.
    hipDeviceProp_t prop{};
    // NOTE(review): device 0 is hard-coded rather than queried; confirm this is intended
    // on multi-GPU hosts.
    int deviceId = 0; // default device

    constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
    int dync_smem_size       = 0;
    int maxActiveBlocksPerCU = 0;

    // NOTE(review): the HIP status codes are stored but never inspected.
    [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);

    e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &maxActiveBlocksPerCU,
        reinterpret_cast<void*>(
            kentry2<block_size, GroupedFlatmmKernel, ContiguousGroupedFlatmmHostArgs>),
        block_size,
        dync_smem_size);

    const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
    const int total_work_tile_cnt   = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);

    // Diagnostic output kept exactly as in the original listing.
    std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
              << ", persistent_block_size: " << persistent_block_size
              << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;

    // The z-extent is taken from k_batch, which is asserted to be exactly 1 here.
    assert(kernelArgs.k_batch == 1);
    return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kernelArgs.k_batch);
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKernelArgs(const GroupedFlatmmHostArgs& hostArgs)
|
||||
|
||||
Reference in New Issue
Block a user