Commit Graph

11 Commits

Author SHA1 Message Date
TungtungQia
1d9e1f6d7a [CuTeDSL] Fix loop carried target scope (#3200)
* [CuTeDSL] Bug fix for scf.for's write_args analysis

* [CuTeDSL] Add for loop test
2026-05-11 16:02:26 +08:00
Johnsonms
f74fea9ce3 [Hopper CuTeDSL] Add FP8 GEMM with 2xAcc (#3149)
Add dense_gemm_fp8_2xacc.py — a CuTeDSL port of CUTLASS Example 54
(54_hopper_fp8_warp_specialized_gemm.cu) for NVIDIA Hopper (SM90).

Implements D = scale_a * scale_b * (A @ B) where A/B are FP8 E4M3FN using
the 2xAcc (double accumulation) technique: a temporary accumulator is
periodically promoted into the main accumulator every mma_promotion_interval
MMA instructions to prevent FP8 precision loss.

Features:
- FP8 E4M3FN inputs with Float32 accumulation
- 2xAcc for improved numerical accuracy
- TMA with multicast for A/B/D transfers
- WGMMA warp-specialized persistent tile scheduling
- Configurable output dtype: Float16, Float32, Float8E4M3FN
- Scalar scale_a / scale_b epilogue factors
- Cluster shapes up to 2x2

Add pytest test suite covering:
- L0 compile tests: all tile shapes, cluster shapes, output dtypes,
  mma_promotion_interval values
- L1 correctness tests: numerical validation vs torch.einsum reference
  for all configs, non-trivial scale factors, and batched GEMM (L>1)
- Benchmark tests (pytest -m bench -s): representative problem sizes
  with warmup, cold-L2, and TFLOPS reporting

Also fix conftest.py to import cutlass before adding examples/python/CuTeDSL
to sys.path, preventing the jax/ examples subdirectory from being detected
as a namespace package and breaking cutlass's JAX availability check.
2026-04-25 16:10:33 -04:00
Nandor Licker
ea46e277d2 Add absf and floor to cute.math (#3156)
The ops are already exposed by the underlying dialect.
2026-04-17 08:54:24 +08:00
Nandor Licker
3f3db08a0a Add support for empty dataclass arguments (#3152)
A dataclass with no fields exposed a bug in `extract_dataclass_members`:

```
@dataclass
class Dummy:
  pass
```

The type/return path was inconsistent. This PR fixes the function to
support empty dataclasses, which are useful in unions.
2026-04-17 08:47:47 +08:00
Johnsonms
982748aa73 [Hopper CuTeDSL] Add grouped GEMM persistent kernel and tests (#3091)
Implement grouped GEMM (C_g = A_g x B_g for g groups) on Hopper using
CuTe DSL, extending the dense persistent GEMM with per-group TMA
descriptor management.

Kernel design (grouped_gemm.py):
- Warp-specialized pipeline: DMA warp group handles TMA loads and
  per-group tensormap updates; MMA warp group runs WGMMA and stores C
- StaticPersistentGroupTileScheduler for cross-group tile scheduling
- Per-group TMA descriptor updates via GMEM or SMEM mode
- Supports fp16, fp8 (E4M3FN/E5M2), int8 with mixed A/B dtypes
- Configurable tile shapes (128x128, 128x256) and cluster shapes
- Fix base TensorMapManager: hoist uniform_smem_ptrs outside predicated
  block to avoid illegal @P0 R2UR on sm_90a

Tests (test/examples/CuTeDSL/hopper/test_grouped_gemm.py):
- L0 compile and L1 correctness pytest suite covering tile shapes,
  dtypes, major modes, cluster shapes, group counts, and mixed sizes
- Move to test/examples/CuTeDSL/hopper/ following sm_100a convention
- Fix deprecated startdir arg in test_sharding.py pytest hook
2026-03-18 00:40:15 -04:00
Linfeng Zheng
772fbb264e [CLI] add cutedsl fp16 gemm tutorial from 2 to 6 (#3106)
* [CLI] add fp16 gemm tutorial from 2 to 6

* [CLI] refine comments
2026-03-17 10:11:55 +08:00
Brian K. Ryu
147f5673d0 New RMS Norm example with unit tests (#2917)
* Add rmsnorm example

* Address reviewer comments. (1) use the cute.runtime definition directly. (2) use the nvvm_wrapper's warp reduce directly

* Separate out reduce.py

* Change copyright notice years
2026-01-13 09:05:31 +08:00
Junkai-Wu
0d2b201e8c v4.3.5 update. (#2934)
* v4.3.5 update.

* Update copyright to 2026
2026-01-08 15:02:56 -05:00
questa-quan-wang
2aee73922c Minor fix for testing of blockscaled dense GEMM with TMA prefetch (#2930)
* new example with TMA prefetch feature targeting for DRAM latency bound cases

* minor fix to resitrct as 100a arch

* typo

* apply arch for whole pytest

---------

Co-authored-by: Questa Wang <questaw@computelab-frontend-7.nvidia.com>
Co-authored-by: Questa Wang <questaw@umbriel-b200-145.ipp4a1.colossus.nvidia.com>
2026-01-05 16:36:03 +08:00
questa-quan-wang
3f4c086d09 new example with TMA prefetch feature targeting for DRAM latency bound cases (#2881)
Co-authored-by: Questa Wang <questaw@computelab-frontend-7.nvidia.com>
2025-12-23 15:29:48 +08:00
Linfeng Zheng
f6402fcd5e add pytest support for tutorial gemm (#2826)
* add pytest support for tutorial gemm

* add license
2025-12-05 08:45:01 -05:00