From ebf3165efba92012b4f0eff001bc22d9a4e2af39 Mon Sep 17 00:00:00 2001 From: Qi Yuhang <45795032+HydraQYH@users.noreply.github.com> Date: Thu, 18 Dec 2025 08:51:38 +0800 Subject: [PATCH] [Bug Fix]Bypass launch grids for SM120 Kernel with SM90 Mainloop & SM100 TileScheduler (#2865) * Delete unused #ifdef/#endif. Bypass sm120 case. * Add todo. * Fix pingpong. * Revert "Add todo." This reverts commit 246cb42091b1ed1b89c1eac2312eed6625a754fd. * Refine name. Refine name again. * Apply suggestions from code review Skip `is_last_tile` for all sm120 kernels. Co-authored-by: Junkai-Wu * Skip early stop for sm120 kernel. * Fix typo. --------- Co-authored-by: Junkai-Wu --- ...0_gemm_tma_warpspecialized_cooperative.hpp | 23 ++++++------ ...sm90_gemm_tma_warpspecialized_pingpong.hpp | 36 ++++++++++--------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 5c04259bd..255eb8ef7 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -136,16 +136,16 @@ public: // Detect if this is SM120 blockscaled kernel which hits high register pressure // on smaller tiles (e.g. 256x128 registers per thread) template - struct is_blockscaled : cute::false_type {}; + struct IsSm120BlockScaled : cute::false_type {}; template - struct is_blockscaled> + struct IsSm120BlockScaled> : cute::true_type {}; - static constexpr bool IsBlockScaled = is_blockscaled::value; + static constexpr bool IsSm120Family = cute::is_same_v; static constexpr bool HeavyRegisterPressure = - IsBlockScaled ? (RegsPerThread >= 128) : (RegsPerThread >= 208); + IsSm120BlockScaled::value ? (RegsPerThread >= 128) : (RegsPerThread >= 208); static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24; static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240; @@ -804,15 +804,16 @@ public: // Update starting mainloop pipeline state for the next tile mainloop_pipe_consumer_state.advance(work_k_tile_count); } - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 - if (scheduler.is_last_tile(work_tile_info)) { - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. - cutlass::arch::launch_dependent_grids(); + if constexpr (!IsSm120Family) { + if (scheduler.is_last_tile(work_tile_info)) { + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + } } - #endif // Index of warp group within consumer warp groups int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index 073f3a50e..6e46217c3 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -147,6 +147,8 @@ public: static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24; static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240; + static constexpr bool IsSm120Family = cute::is_same_v; + // 1 stage ordered sequence between mainloop and epilogue producer load threads using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; @@ -800,18 +802,18 @@ public: else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { cutlass::arch::warpgroup_reg_alloc(); - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 - // It is possible to have work tiles start off invalid, - // so we have to check that first. - if (not work_tile_info.is_valid()) { - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. - cutlass::arch::launch_dependent_grids(); + if constexpr (!IsSm120Family) { + // It is possible to have work tiles start off invalid, + // so we have to check that first. + if (not work_tile_info.is_valid()) { + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); - return; + return; + } } - #endif if constexpr (IsSchedDynamicPersistent) { // Consumer0's initial tile is static. It starts consuming the 2nd tile. @@ -868,15 +870,15 @@ public: // Update starting mainloop pipeline state for the next tile mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); - #ifdef CUTLASS_ENABLE_GDC_FOR_SM90 - if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups)) { - // Hint on an early release of global memory resources. - // The timing of calling this function only influences performance, - // not functional correctness. - cutlass::arch::launch_dependent_grids(); + if constexpr (!IsSm120Family) { + if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups)) { + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + } } - #endif // Order two Math WG's Epilogue one after the other math_wg_order_barrier.wait();