mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 17:00:05 +00:00
[Bug Fix] Bypass launch grids for SM120 Kernel with SM90 Mainloop & SM100 TileScheduler (#2865)
* Delete unused #ifdef/#endif. Bypass sm120 case.
* Add todo.
* Fix pingpong.
* Revert "Add todo."
This reverts commit 246cb42091.
* Refine name.
Refine name again.
* Apply suggestions from code review
Skip `is_last_tile` for all sm120 kernels.
Co-authored-by: Junkai-Wu <junkaiw@nvidia.com>
* Skip early stop for sm120 kernel.
* Fix typo.
---------
Co-authored-by: Junkai-Wu <junkaiw@nvidia.com>
This commit is contained in:
@@ -136,16 +136,16 @@ public:
|
||||
// Detect if this is SM120 blockscaled kernel which hits high register pressure
|
||||
// on smaller tiles (e.g. 256x128 registers per thread)
|
||||
template <typename T>
|
||||
struct is_blockscaled : cute::false_type {};
|
||||
struct IsSm120BlockScaled : cute::false_type {};
|
||||
|
||||
template <int Stages, int SchedStages, class ClusterShape, class KernelSchedule>
|
||||
struct is_blockscaled<MainloopSm120TmaWarpSpecializedBlockScaled<Stages, SchedStages, ClusterShape, KernelSchedule>>
|
||||
struct IsSm120BlockScaled<MainloopSm120TmaWarpSpecializedBlockScaled<Stages, SchedStages, ClusterShape, KernelSchedule>>
|
||||
: cute::true_type {};
|
||||
|
||||
static constexpr bool IsBlockScaled = is_blockscaled<DispatchPolicy>::value;
|
||||
static constexpr bool IsSm120Family = cute::is_same_v<typename DispatchPolicy::ArchTag, arch::Sm120>;
|
||||
|
||||
static constexpr bool HeavyRegisterPressure =
|
||||
IsBlockScaled ? (RegsPerThread >= 128) : (RegsPerThread >= 208);
|
||||
IsSm120BlockScaled<DispatchPolicy>::value ? (RegsPerThread >= 128) : (RegsPerThread >= 208);
|
||||
|
||||
static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24;
|
||||
static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240;
|
||||
@@ -804,15 +804,16 @@ public:
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(work_k_tile_count);
|
||||
}
|
||||
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
|
||||
if (scheduler.is_last_tile(work_tile_info)) {
|
||||
// Hint on an early release of global memory resources.
|
||||
// The timing of calling this function only influences performance,
|
||||
// not functional correctness.
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
|
||||
if constexpr (!IsSm120Family) {
|
||||
if (scheduler.is_last_tile(work_tile_info)) {
|
||||
// Hint on an early release of global memory resources.
|
||||
// The timing of calling this function only influences performance,
|
||||
// not functional correctness.
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Index of warp group within consumer warp groups
|
||||
int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups;
|
||||
|
||||
@@ -147,6 +147,8 @@ public:
|
||||
static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24;
|
||||
static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240;
|
||||
|
||||
static constexpr bool IsSm120Family = cute::is_same_v<typename DispatchPolicy::ArchTag, arch::Sm120>;
|
||||
|
||||
// 1 stage ordered sequence between mainloop and epilogue producer load threads
|
||||
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>;
|
||||
|
||||
@@ -800,18 +802,18 @@ public:
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) {
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
|
||||
// It is possible to have work tiles start off invalid,
|
||||
// so we have to check that first.
|
||||
if (not work_tile_info.is_valid()) {
|
||||
// Hint on an early release of global memory resources.
|
||||
// The timing of calling this function only influences performance,
|
||||
// not functional correctness.
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
if constexpr (!IsSm120Family) {
|
||||
// It is possible to have work tiles start off invalid,
|
||||
// so we have to check that first.
|
||||
if (not work_tile_info.is_valid()) {
|
||||
// Hint on an early release of global memory resources.
|
||||
// The timing of calling this function only influences performance,
|
||||
// not functional correctness.
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
|
||||
return;
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr (IsSchedDynamicPersistent) {
|
||||
// Consumer0's initial tile is static. It starts consuming the 2nd tile.
|
||||
@@ -868,15 +870,15 @@ public:
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups);
|
||||
|
||||
#ifdef CUTLASS_ENABLE_GDC_FOR_SM90
|
||||
if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups)) {
|
||||
// Hint on an early release of global memory resources.
|
||||
// The timing of calling this function only influences performance,
|
||||
// not functional correctness.
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
if constexpr (!IsSm120Family) {
|
||||
if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups)) {
|
||||
// Hint on an early release of global memory resources.
|
||||
// The timing of calling this function only influences performance,
|
||||
// not functional correctness.
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Order two Math WG's Epilogue one after the other
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
Reference in New Issue
Block a user