Implement grouped GEMM (C_g = A_g x B_g for each group g) on Hopper using
the CuTe DSL, extending the dense persistent GEMM with per-group TMA
descriptor management.

Kernel design (grouped_gemm.py):
- Warp-specialized pipeline: DMA warp group handles TMA loads and
per-group tensormap updates; MMA warp group runs WGMMA and stores C
- StaticPersistentGroupTileScheduler for cross-group tile scheduling
- Per-group TMA descriptor updates via GMEM or SMEM mode
- Supports fp16, fp8 (E4M3FN/E5M2), int8 with mixed A/B dtypes
- Configurable tile shapes (128x128, 128x256) and cluster shapes
- Fix base TensorMapManager: hoist uniform_smem_ptrs outside predicated
block to avoid illegal @P0 R2UR on sm_90a
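
As a rough illustration of cross-group tile scheduling, the sketch below maps a persistent CTA's linear tile indices to (group, tile_row, tile_col) triples. It is a plain-Python model under simple assumptions (row-major tile order within a group, round-robin assignment across CTAs), not the StaticPersistentGroupTileScheduler implementation; the names `persistent_tiles`, `cta_id`, and `num_ctas` are hypothetical.

```python
from typing import Iterator, List, Tuple


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def persistent_tiles(
    problem_mn: List[Tuple[int, int]],  # (M, N) of each group's output C_g
    tile_mn: Tuple[int, int],           # CTA tile shape, e.g. (128, 128)
    cta_id: int,                        # hypothetical persistent CTA index
    num_ctas: int,                      # hypothetical number of persistent CTAs
) -> Iterator[Tuple[int, int, int]]:
    """Yield (group, tile_row, tile_col) for the tiles one CTA processes."""
    tile_m, tile_n = tile_mn
    counts = [ceil_div(m, tile_m) * ceil_div(n, tile_n) for m, n in problem_mn]
    total = sum(counts)

    group, group_start = 0, 0
    # Assumed policy: CTA i takes linear tiles i, i + num_ctas, i + 2*num_ctas, ...
    for linear in range(cta_id, total, num_ctas):
        # Advance to the group that owns this linear index. Each advance is
        # where a real kernel would switch to another group's TMA descriptor.
        while linear >= group_start + counts[group]:
            group_start += counts[group]
            group += 1
        local = linear - group_start
        tiles_n = ceil_div(problem_mn[group][1], tile_n)
        yield group, local // tiles_n, local % tiles_n


if __name__ == "__main__":
    problems = [(256, 128), (128, 384)]  # two groups with different sizes
    for cta in range(2):
        print(cta, list(persistent_tiles(problems, (128, 128), cta, 2)))
```

Because the group index only ever increases along one CTA's walk, a descriptor update is needed only when the yielded group differs from the previous one.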
Tests (test/examples/CuTeDSL/hopper/test_grouped_gemm.py):
- L0 compile and L1 correctness pytest suite covering tile shapes,
dtypes, major modes, cluster shapes, group counts, and mixed sizes
- Move tests to test/examples/CuTeDSL/hopper/, following the sm_100a convention
- Fix deprecated startdir arg in test_sharding.py pytest hook
The notebook uses float16 tensors, but the vectorized kernel documentation
incorrectly described the elements as 32-bit and used 4-element
vectorization. Update it to state 16-bit elements with 8-element
vectorization, which is what 128-bit loads/stores require.
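
The arithmetic behind the correction can be checked with a short sketch. The helper name `elems_per_access` is hypothetical and not part of the notebook.

```python
def elems_per_access(access_bits: int, elem_bits: int) -> int:
    """Number of elements that fit in one vectorized load/store."""
    if elem_bits <= 0 or access_bits % elem_bits != 0:
        raise ValueError("access width must be a positive multiple of element width")
    return access_bits // elem_bits


if __name__ == "__main__":
    print(elems_per_access(128, 16))  # float16: 8 elements per 128-bit access
    print(elems_per_access(128, 32))  # float32: 4 elements per 128-bit access
```

A 4-element vector of float16 values covers only 64 bits, which is why the old description did not match 128-bit accesses.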
Signed-off-by: Blake Ledden <bledden@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Before this fix, combining two Boolean (i1) DSL values with Python `and`
triggered a verbose i1→i32→i1 round-trip in __dsl_and__:
arith.extui (×3), arith.select, and arith.cmpi ne (×2), i.e. 6 extra MLIR ops.
Add a fast path: when both operands are Boolean, delegate directly to
__and__, which emits a single arith.andi %a, %b : i1, identical to `&`.
For Boolean operands the two operators were already semantically
equivalent; this fix makes the generated MLIR identical as well.
Includes:
- repro_dsl_and_bool.py — minimal standalone reproducer / bug-report script
- test_dsl_and_fix.py — pytest tests verifying the fixed behaviour
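
The shape of the fast path can be modelled with a toy example. This is not the DSL's code: `ToyValue`, `generic_and`, `toy_dsl_and`, and the `EMITTED` log are hypothetical stand-ins, and the slow path only reproduces the op counts quoted above.

```python
from dataclasses import dataclass
from typing import List

EMITTED: List[str] = []  # toy stand-in for the MLIR ops a DSL would emit


@dataclass(frozen=True)
class ToyValue:
    """Hypothetical DSL value: a name plus an integer bit width."""
    name: str
    bits: int

    def __and__(self, other: "ToyValue") -> "ToyValue":
        # Bitwise AND on same-width values costs a single op.
        assert self.bits == other.bits
        EMITTED.append("arith.andi")
        return ToyValue(f"({self.name} & {other.name})", self.bits)


def generic_and(a: ToyValue, b: ToyValue) -> ToyValue:
    # Stand-in for the slow path. Only the op counts (extui x3, select x1,
    # cmpi ne x2) come from the commit message; order and result width are
    # placeholders chosen for this sketch.
    EMITTED.extend(["arith.extui"] * 3 + ["arith.select"] + ["arith.cmpi ne"] * 2)
    return ToyValue(f"({a.name} and {b.name})", b.bits)


def toy_dsl_and(a: ToyValue, b: ToyValue) -> ToyValue:
    # Fast path: two Booleans need no widening, so reuse bitwise AND.
    if a.bits == 1 and b.bits == 1:
        return a & b
    return generic_and(a, b)


if __name__ == "__main__":
    toy_dsl_and(ToyValue("a", 1), ToyValue("b", 1))
    print(EMITTED)
```

The check on both operand widths matters: if only one operand is Boolean, the sketch still falls through to the generic path.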
Add subtraction operation for packed f32x2 values, following the same
pattern as the existing add_packed_f32x2 and mul_packed_f32x2 operations.
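
For reference, the intended lane-wise semantics can be sketched on the host in plain Python. This assumes the operation subtracts the two f32 lanes independently with round-to-nearest; the name `sub_packed_f32x2_ref` is hypothetical and this is not the DSL operation itself.

```python
import struct
from typing import Tuple

F32x2 = Tuple[float, float]


def to_f32(x: float) -> float:
    """Round a Python float (f64) to the nearest f32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def sub_packed_f32x2_ref(a: F32x2, b: F32x2) -> F32x2:
    """Assumed semantics: independent f32 subtraction on each of the two lanes."""
    return (
        to_f32(to_f32(a[0]) - to_f32(b[0])),
        to_f32(to_f32(a[1]) - to_f32(b[1])),
    )


if __name__ == "__main__":
    print(sub_packed_f32x2_ref((1.5, 2.0), (0.25, 4.0)))
```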
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>