mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#5050 (commit 033dad7)
[CK TILE] Skip work if any of Grouped GEMM groups M/N/K are zero. (#5050)

## Motivation

It's common in MoE workloads that some experts receive zero tokens, which results in some of the GEMM dimensions being equal to zero. Currently we handle such a case only for non-persistent kernels, where we have all the GEMM information beforehand on the host - we validate this during creation of the kernel arguments. However, for the "dynamic" input path (persistent kernel) this information is not available before kernel launch. Thus we have to validate this during kernel execution. The goal is to add this validation.

## Technical Details

Skip work if any of the Grouped GEMM groups' M/N/K are zero for the persistent kernel path.

## Test Plan

Add unit tests which cover "dynamic" inputs with zero dims for the persistent kernel execution path.

## Test Result

All tests pass.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
2c3f9bfa52
commit
b09ce811d5
@@ -507,6 +507,12 @@ struct GroupedGemmKernel
|
||||
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
|
||||
const auto& kargs = gemm_desc_ptr[group_id];
|
||||
|
||||
// Early exit if no work to do.
|
||||
if(kargs.group_karg.M == 0 || kargs.group_karg.N == 0 || kargs.group_karg.K == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
|
||||
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
0,
|
||||
@@ -534,6 +540,22 @@ struct GroupedGemmKernel
|
||||
const auto& k_batch = kargs.k_batch;
|
||||
const auto block_start = cum_grid_size;
|
||||
cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
|
||||
|
||||
// Early exit if no work to do.
|
||||
// If M or N is zero, TilePartitioner::GridSize(kargs.M, kargs.N) returns zero,
|
||||
// so this group contributes no blocks and cum_grid_size is unchanged. The group
|
||||
// is naturally skipped by the block_id < cum_grid_size check below.
|
||||
if(kargs.K == 0)
|
||||
{
|
||||
// Advance only if this workgroup was assigned to this group's range,
|
||||
// matching the pattern of the normal while loop below.
|
||||
while(block_id < cum_grid_size)
|
||||
{
|
||||
block_id += grid_size;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
while(block_id < cum_grid_size)
|
||||
{
|
||||
const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
Reference in New Issue
Block a user