cutlass/examples/python/CuTeDSL
Johnsonms f74fea9ce3 [Hopper CuTeDSL] Add FP8 GEMM with 2xAcc (#3149)
Add dense_gemm_fp8_2xacc.py — a CuTeDSL port of CUTLASS Example 54
(54_hopper_fp8_warp_specialized_gemm.cu) for NVIDIA Hopper (SM90).

Implements D = scale_a * scale_b * (A @ B), where A and B are FP8 E4M3FN,
using the 2xAcc (double accumulation) technique: MMA results collect in a
temporary accumulator, which is added (promoted) into the main accumulator
every mma_promotion_interval MMA instructions to limit precision loss
during FP8 accumulation.
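
The promotion schedule described above can be sketched in NumPy. This is a
host-side illustration of the control flow only, not the kernel code:
gemm_2xacc, mma_k, and the array shapes are hypothetical names chosen for
the example, and everything is computed in float32, so the reduced
precision of the hardware accumulator is not modeled.

```python
import numpy as np

def gemm_2xacc(a, b, scale_a=1.0, scale_b=1.0, mma_k=32, mma_promotion_interval=4):
    """Illustrative 2xAcc loop: D = scale_a * scale_b * (A @ B).

    a: (M, K), b: (K, N). mma_k is a hypothetical K-extent of one MMA step.
    """
    m, k = a.shape
    k_b, n = b.shape
    assert k == k_b, "inner dimensions must match"

    main_acc = np.zeros((m, n), dtype=np.float32)  # long-lived accumulator
    temp_acc = np.zeros((m, n), dtype=np.float32)  # short-lived accumulator
    mma_count = 0

    for k0 in range(0, k, mma_k):
        # One "MMA instruction": accumulate a K-slice into the temporary.
        a_blk = a[:, k0:k0 + mma_k].astype(np.float32)
        b_blk = b[k0:k0 + mma_k, :].astype(np.float32)
        temp_acc += a_blk @ b_blk
        mma_count += 1

        # Promote the temporary into the main accumulator, then reset it.
        if mma_count == mma_promotion_interval:
            main_acc += temp_acc
            temp_acc[:] = 0.0
            mma_count = 0

    # Flush any MMAs left over since the last promotion.
    if mma_count > 0:
        main_acc += temp_acc

    # Scalar scale factors are applied once, after accumulation.
    return np.float32(scale_a * scale_b) * main_acc

rng = np.random.default_rng(0)
a = rng.standard_normal((16, 256)).astype(np.float32)
b = rng.standard_normal((256, 8)).astype(np.float32)
d = gemm_2xacc(a, b, scale_a=0.5, scale_b=2.0)
```

A smaller mma_promotion_interval promotes more often, which costs extra
additions; a larger one lets more partial sums collect in the temporary
accumulator before each promotion.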

Features:
- FP8 E4M3FN inputs with Float32 accumulation
- 2xAcc for improved numerical accuracy
- TMA with multicast for A/B/D transfers
- Warp-specialized WGMMA kernel with persistent tile scheduling
- Configurable output dtype: Float16, Float32, Float8E4M3FN
- Scalar scale_a / scale_b epilogue factors
- Cluster shapes up to 2x2

Add pytest test suite covering:
- L0 compile tests: all tile shapes, cluster shapes, output dtypes,
  mma_promotion_interval values
- L1 correctness tests: numerical validation vs torch.einsum reference
  for all configs, non-trivial scale factors, and batched GEMM (L>1)
- Benchmark tests (pytest -m bench -s): representative problem sizes
  with warmup, cold-L2, and TFLOPS reporting
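
A batched einsum reference of the kind used for the L1 checks might look
like the sketch below. It uses NumPy as a stand-in for torch.einsum, and
the (M, K, L) / (N, K, L) / (M, N, L) tensor layout and the name
reference_gemm are assumptions made for illustration, not taken from the
test suite.

```python
import numpy as np

def reference_gemm(a, b, scale_a, scale_b):
    """Batched reference: D[m, n, l] = scale_a * scale_b * sum_k A[m, k, l] * B[n, k, l].

    Assumed layouts: a is (M, K, L), b is (N, K, L), result is (M, N, L).
    """
    return scale_a * scale_b * np.einsum("mkl,nkl->mnl", a, b)

rng = np.random.default_rng(0)
a = rng.standard_normal((8, 16, 3)).astype(np.float32)  # L = 3 batches
b = rng.standard_normal((4, 16, 3)).astype(np.float32)
ref = reference_gemm(a, b, scale_a=0.5, scale_b=2.0)
```

A kernel result would then be compared against ref with a tolerance that
suits the output dtype, since a Float8E4M3FN output needs a much looser
bound than Float32.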

Also fix conftest.py to import cutlass before adding examples/python/CuTeDSL
to sys.path, preventing the jax/ examples subdirectory from being detected
as a namespace package and breaking cutlass's JAX availability check.
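
The namespace-package effect can be reproduced in isolation. The sketch
below assumes the availability check relies on something like
importlib.util.find_spec, which is a guess about the mechanism rather than
a statement about cutlass's code; hypothetical_optional_dep stands in for
jax, and a temporary directory stands in for examples/python/CuTeDSL.

```python
import importlib.util
import sys
import tempfile
from pathlib import Path

# Hypothetical stand-in for an optional dependency that is not installed.
name = "hypothetical_optional_dep"

with tempfile.TemporaryDirectory() as root:
    # A bare directory with no __init__.py, like an examples subdirectory.
    (Path(root) / name).mkdir()

    # Before the directory is on sys.path, the package is not found.
    before = importlib.util.find_spec(name)

    sys.path.insert(0, root)
    importlib.invalidate_caches()
    try:
        # Now the bare directory is picked up as a namespace package,
        # so a find_spec-based check reports the package as available.
        after = importlib.util.find_spec(name)
    finally:
        sys.path.remove(root)
        importlib.invalidate_caches()

found_before = before is not None
found_after = after is not None
```

Running any availability check before the sys.path change, as the
conftest.py fix does by importing cutlass first, avoids the false positive.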
2026-04-25 16:10:33 -04:00