From ade6376fa075df782e012457751169fe2a5a53b4 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 21 Apr 2025 00:02:30 -0400 Subject: [PATCH] [SM90] Change register allocation for TileN=208 to avoid spills (#2219) With the usual register allocation (producer 40, consumer 232) compiling Gemm with tile shape 256 x 208 (cooperative) or 128 x 208 (pingpong) show lots of register spilling (e.g. ~3000 bytes spill). For this case we can change the register allocation to producer 24, consumer 240, which avoids spills. --- .../kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp | 8 ++++++-- .../kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index c2ba8f890..c781d2fee 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -128,8 +128,12 @@ public: static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents; /// Register requirement for Load and Math WGs - static constexpr uint32_t LoadRegisterRequirement = 40; - static constexpr uint32_t MmaRegisterRequirement = 232; + static constexpr int RegsPerThread = + size<0>(TileShape{}) * size<1>(TileShape{}) / NumMMAThreads * + sizeof(ElementAccumulator) / sizeof(uint32_t); + static constexpr bool HeavyRegisterPressure = RegsPerThread >= 208; + static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24; + static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240; // 1 stage ordered sequence between mainloop and epilogue producer load threads using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index aead199bb..587456d28 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -138,8 +138,12 @@ public: static_assert(MaxThreadsPerBlock == 384, "Pingpong kernel must have 384 threads in total."); /// Register requirement for Load and Math WGs - static constexpr uint32_t LoadRegisterRequirement = 40; - static constexpr uint32_t MmaRegisterRequirement = 232; + static constexpr int RegsPerThread = + size<0>(TileShape{}) * size<1>(TileShape{}) / NumMMAThreads * + sizeof(ElementAccumulator) / sizeof(uint32_t); + static constexpr bool HeavyRegisterPressure = RegsPerThread >= 208; + static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24; + static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240; // 1 stage ordered sequence between mainloop and epilogue producer load threads using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>;