Mirror of https://github.com/NVIDIA/cutlass.git, synced 2026-03-23 16:47:32 +00:00
Implement grouped GEMM (C_g = A_g x B_g for g groups) on Hopper using CuTe DSL, extending the dense persistent GEMM with per-group TMA descriptor management.

Kernel design (grouped_gemm.py):
- Warp-specialized pipeline: the DMA warp group handles TMA loads and per-group tensormap updates; the MMA warp group runs WGMMA and stores C
- StaticPersistentGroupTileScheduler for cross-group tile scheduling
- Per-group TMA descriptor updates via GMEM or SMEM mode
- Supports fp16, fp8 (E4M3FN/E5M2), and int8, with mixed A/B dtypes
- Configurable tile shapes (128x128, 128x256) and cluster shapes
- Fix base TensorMapManager: hoist uniform_smem_ptrs outside the predicated block to avoid an illegal @P0 R2UR on sm_90a

Tests (test/examples/CuTeDSL/hopper/test_grouped_gemm.py):
- L0 compile and L1 correctness pytest suites covering tile shapes, dtypes, major modes, cluster shapes, group counts, and mixed sizes
- Move to test/examples/CuTeDSL/hopper/ following the sm_100a convention
- Fix deprecated startdir arg in the test_sharding.py pytest hook
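For reference, the grouped GEMM semantics the kernel implements can be sketched with a minimal NumPy model: each group g solves an independent C_g = A_g x B_g with its own problem shape, which is what makes per-group TMA descriptor updates necessary. This is a hypothetical host-side reference (names and shapes are illustrative), not the CuTe DSL kernel itself.

```python
import numpy as np

def grouped_gemm_reference(As, Bs):
    """Compute C_g = A_g @ B_g independently for each group g.

    As: list of (M_g, K_g) arrays; Bs: list of (K_g, N_g) arrays.
    Groups may have different M, N, K, as in the mixed-sizes tests.
    """
    assert len(As) == len(Bs), "one B per A"
    return [a @ b for a, b in zip(As, Bs)]

# Two groups with different problem sizes (illustrative shapes).
rng = np.random.default_rng(0)
As = [rng.standard_normal((m, k)) for m, k in [(128, 64), (256, 32)]]
Bs = [rng.standard_normal((k, n)) for k, n in [(64, 128), (32, 256)]]
Cs = grouped_gemm_reference(As, Bs)
print([c.shape for c in Cs])  # [(128, 128), (256, 256)]
```

A correctness test like the L1 suite would compare the kernel's per-group outputs against such a reference, group by group.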