[Bug Fix] Bypass launch grids for SM120 Kernel with SM90 Mainloop & SM100 TileScheduler (#2865)

* Delete unused #ifdef/#endif. Bypass sm120 case.

* Add todo.

* Fix pingpong.

* Revert "Add todo."

This reverts commit 246cb42091.

* Refine name.

Refine name again.

* Apply suggestions from code review

Skip `is_last_tile` for all sm120 kernels.

Co-authored-by: Junkai-Wu <junkaiw@nvidia.com>

* Skip early stop for sm120 kernel.

* Fix typo.

---------

Co-authored-by: Junkai-Wu <junkaiw@nvidia.com>
This commit is contained in:
Qi Yuhang
2025-12-18 08:51:38 +08:00
committed by GitHub
parent d4e16f5d4e
commit ebf3165efb
2 changed files with 31 additions and 28 deletions

View File

@@ -136,16 +136,16 @@ public:
// Detect whether this is an SM120 blockscaled kernel, which hits high register pressure
// on smaller tiles (e.g. 256x128 registers per thread)
template <typename T>
struct is_blockscaled : cute::false_type {};
struct IsSm120BlockScaled : cute::false_type {};
template <int Stages, int SchedStages, class ClusterShape, class KernelSchedule>
struct is_blockscaled<MainloopSm120TmaWarpSpecializedBlockScaled<Stages, SchedStages, ClusterShape, KernelSchedule>>
struct IsSm120BlockScaled<MainloopSm120TmaWarpSpecializedBlockScaled<Stages, SchedStages, ClusterShape, KernelSchedule>>
: cute::true_type {};
static constexpr bool IsBlockScaled = is_blockscaled<DispatchPolicy>::value;
static constexpr bool IsSm120Family = cute::is_same_v<typename DispatchPolicy::ArchTag, arch::Sm120>;
static constexpr bool HeavyRegisterPressure =
IsBlockScaled ? (RegsPerThread >= 128) : (RegsPerThread >= 208);
IsSm120BlockScaled<DispatchPolicy>::value ? (RegsPerThread >= 128) : (RegsPerThread >= 208);
static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24;
static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240;
@@ -804,15 +804,16 @@ public:
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(work_k_tile_count);
}
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
if (scheduler.is_last_tile(work_tile_info)) {
// Hint on an early release of global memory resources.
// The timing of calling this function only influences performance,
// not functional correctness.
cutlass::arch::launch_dependent_grids();
if constexpr (!IsSm120Family) {
if (scheduler.is_last_tile(work_tile_info)) {
// Hint on an early release of global memory resources.
// The timing of calling this function only influences performance,
// not functional correctness.
cutlass::arch::launch_dependent_grids();
}
}
#endif
// Index of warp group within consumer warp groups
int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups;

View File

@@ -147,6 +147,8 @@ public:
static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24;
static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240;
static constexpr bool IsSm120Family = cute::is_same_v<typename DispatchPolicy::ArchTag, arch::Sm120>;
// 1 stage ordered sequence between mainloop and epilogue producer load threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>;
@@ -800,18 +802,18 @@ public:
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
// It is possible to have work tiles start off invalid,
// so we have to check that first.
if (not work_tile_info.is_valid()) {
// Hint on an early release of global memory resources.
// The timing of calling this function only influences performance,
// not functional correctness.
cutlass::arch::launch_dependent_grids();
if constexpr (!IsSm120Family) {
// It is possible to have work tiles start off invalid,
// so we have to check that first.
if (not work_tile_info.is_valid()) {
// Hint on an early release of global memory resources.
// The timing of calling this function only influences performance,
// not functional correctness.
cutlass::arch::launch_dependent_grids();
return;
return;
}
}
#endif
if constexpr (IsSchedDynamicPersistent) {
// Consumer0's initial tile is static. It starts consuming the 2nd tile.
@@ -868,15 +870,15 @@ public:
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups);
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups)) {
// Hint on an early release of global memory resources.
// The timing of calling this function only influences performance,
// not functional correctness.
cutlass::arch::launch_dependent_grids();
if constexpr (!IsSm120Family) {
if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups)) {
// Hint on an early release of global memory resources.
// The timing of calling this function only influences performance,
// not functional correctness.
cutlass::arch::launch_dependent_grids();
}
}
#endif
// Order two Math WG's Epilogue one after the other
math_wg_order_barrier.wait();