mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 17:00:05 +00:00
Add dense_gemm_fp8_2xacc.py, a CuTeDSL port of CUTLASS Example 54 (54_hopper_fp8_warp_specialized_gemm.cu) for NVIDIA Hopper (SM90).

The kernel computes D = scale_a * scale_b * (A @ B), where A and B are FP8 E4M3FN tensors. It uses the 2xAcc (double accumulation) technique: MMA results are collected in a temporary accumulator. Every mma_promotion_interval MMA instructions, that accumulator is promoted (added) into the main Float32 accumulator. This reduces the precision loss that builds up over long chains of FP8 MMAs.

Features:
- FP8 E4M3FN inputs with Float32 accumulation
- 2xAcc for improved numerical accuracy
- TMA for A/B loads (multicast across the cluster) and for D stores
- Warp-specialized WGMMA mainloop with persistent tile scheduling
- Configurable output dtype: Float16, Float32, Float8E4M3FN
- Scalar scale_a / scale_b factors applied in the epilogue
- Cluster shapes up to 2x2

Add a pytest test suite covering:
- L0 compile tests: all tile shapes, cluster shapes, output dtypes, and mma_promotion_interval values
- L1 correctness tests: numerical validation against a torch.einsum reference for all configs, non-trivial scale factors, and batched GEMM (L > 1)
- Benchmark tests (pytest -m bench -s): representative problem sizes with warmup, cold-L2 runs, and TFLOPS reporting

Also fix conftest.py to import cutlass before adding examples/python/CuTeDSL to sys.path. Otherwise the jax/ examples subdirectory is picked up as a namespace package named jax, which breaks cutlass's JAX availability check.
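A small pure-Python simulation can illustrate the 2xAcc promotion scheme described above. This is a sketch under stated assumptions, not the kernel's implementation:

- The MMA's internal accumulator is modeled as a float rounded to `acc_bits` mantissa bits. The value used here is 10, an arbitrary width chosen for illustration, not Hopper's actual one.
- Each loop iteration stands in for one MMA instruction's contribution.
- The inputs are ordinary floats, not FP8 values.
- The helper names `round_to_bits`, `dot_single_acc`, and `dot_2xacc` are hypothetical.

```python
import math
import random


def round_to_bits(x: float, mantissa_bits: int) -> float:
    """Round x to the nearest float with `mantissa_bits` fraction bits (exponent range ignored)."""
    if x == 0.0 or not math.isfinite(x):
        return x
    m, e = math.frexp(x)                  # x == m * 2**e with 0.5 <= |m| < 1
    scale = 2.0 ** (mantissa_bits + 1)    # +1 for the implicit leading bit
    return math.ldexp(round(m * scale) / scale, e)


def dot_single_acc(a, b, acc_bits):
    """One reduced-precision accumulator for the whole K reduction (no promotion)."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = round_to_bits(acc + x * y, acc_bits)
    return acc


def dot_2xacc(a, b, acc_bits, interval):
    """2xAcc sketch: a reduced-precision temporary accumulator is promoted into a
    full-precision main accumulator every `interval` steps (hypothetical model)."""
    main = 0.0  # main accumulator; a Python float (double) stands in for FP32
    temp = 0.0  # reduced-precision temporary accumulator
    for i, (x, y) in enumerate(zip(a, b), start=1):
        temp = round_to_bits(temp + x * y, acc_bits)
        if i % interval == 0:
            main += temp  # promotion: fold the partial sum into the main accumulator
            temp = 0.0
    return main + temp  # promote whatever is left over


random.seed(0)
K = 4096
a = [random.uniform(0.5, 1.5) for _ in range(K)]
b = [random.uniform(0.5, 1.5) for _ in range(K)]
exact = math.fsum(x * y for x, y in zip(a, b))

err_single = abs(dot_single_acc(a, b, acc_bits=10) - exact) / exact
err_2xacc = abs(dot_2xacc(a, b, acc_bits=10, interval=4) - exact) / exact
print(f"single accumulator relative error: {err_single:.2e}")
print(f"2xAcc relative error:              {err_2xacc:.2e}")
```

With positive inputs, the single reduced-precision accumulator stops absorbing small products correctly once its value grows large. The 2xAcc version only ever rounds small partial sums, so its relative error should be far lower. The interval is a trade-off: promoting more often improves accuracy at the cost of more promotion adds.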
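For the TFLOPS figures reported by the benchmark tests, a common convention counts 2 FLOPs (one multiply, one add) per multiply-accumulate of the batched GEMM, giving 2 * M * N * K * L. The hypothetical helper below shows that arithmetic. It assumes `seconds` is the mean time per call after warmup, and it ignores the scale_a * scale_b epilogue. The exact formula used by the test suite is an assumption.

```python
def gemm_tflops(m: int, n: int, k: int, l: int, seconds: float) -> float:
    """TFLOPS for an (M, N, K) GEMM batched L times that took `seconds` per call.

    Counts 2 FLOPs (multiply + add) per multiply-accumulate; the epilogue
    scaling is not counted.
    """
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return 2 * m * n * k * l / seconds / 1e12


# Example: a 4096 x 4096 x 4096 GEMM with L = 1 measured at 1 ms per call.
print(f"{gemm_tflops(4096, 4096, 4096, 1, 1e-3):.1f} TFLOPS")
```

For that example the helper prints 137.4 TFLOPS.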
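The conftest.py ordering issue comes from how Python resolves imports. A directory on sys.path that has no __init__.py can be imported as a namespace package whenever no regular package of that name is found. The stdlib-only sketch below reproduces the effect. It uses two stand-ins:

- A hypothetical package name (`demo_shadow_pkg_7f3a`) in place of jax.
- A hypothetical `is_available` check in place of cutlass's JAX check.

It illustrates the mechanism only; it is not the conftest.py code.

```python
import importlib.util
import pathlib
import sys
import tempfile

# Hypothetical name standing in for "jax"; chosen so nothing installed can match it.
NAME = "demo_shadow_pkg_7f3a"


def is_available(name: str) -> bool:
    """Hypothetical availability check: True if the import system can locate `name`."""
    return importlib.util.find_spec(name) is not None


# Check run before the examples directory is on sys.path.
checked_early = is_available(NAME)

with tempfile.TemporaryDirectory() as root:
    # An examples-style subdirectory named like a library, with no __init__.py.
    (pathlib.Path(root) / NAME).mkdir()
    sys.path.insert(0, root)
    importlib.invalidate_caches()
    try:
        # The same check run after the directory is on sys.path.
        checked_late = is_available(NAME)
        spec = importlib.util.find_spec(NAME)
        # A package spec with no __init__.py behind it: a namespace package.
        is_package = spec is not None and spec.submodule_search_locations is not None
    finally:
        sys.path.remove(root)

print(f"before sys.path change: {checked_early}")
print(f"after sys.path change:  {checked_late} (package: {is_package})")
```

A check evaluated before the directory is on sys.path sees the package as missing. The same check evaluated afterwards is fooled by the empty directory. This is presumably why importing cutlass first, before the examples path is added, avoids the problem.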