DSL Feature Examples

This directory demonstrates CuTe DSL capabilities beyond kernel authoring itself: exporting compiled kernels for deployment, integrating with ML frameworks, using foreign function interfaces, and accessing low-level DSL features like inline PTX and shared memory allocation.


Directory Structure

dsl/
  export/                        Exporting kernels to C shared libraries
    export_to_c.py                 Compile a kernel and export as .so/.dylib
    load_in_python.py              Load and call the exported library from Python
    run_with_dynamic_loading.cpp   C++ driver using dlopen
    run_with_dynamic_loading.sh    Build/run script for dynamic loading
    run_with_static_linking.cpp    C++ driver using static linking
    run_with_static_linking.sh     Build/run script for static linking
  ffi/                           Foreign function interface
    jit_argument.py                JIT compilation with argument passing
    tensor.cpp                     C++ tensor interop implementation
    CMakeLists.txt                 CMake build for FFI examples
  jax/                           JAX integration
    cutlass_call_basic.py          Basic CUTLASS kernel call from JAX
    cutlass_call_export.py         Export a CUTLASS kernel for JAX
    cutlass_call_sharding.py       Multi-device sharding with CUTLASS kernels
    elementwise_apply_example.py   Elementwise apply via JAX
  tvm_ffi/                       TVM FFI integration
    jit_and_use_in_torch.py        JIT compile and call from PyTorch
    jit_and_use_in_jax.py          JIT compile and call from JAX
    aot_export.py                  Ahead-of-time export
    aot_use_in_torch.py            Use AOT-exported kernel in PyTorch
    aot_use_in_jax.py              Use AOT-exported kernel in JAX
    aot_use_in_cpp_bundle.cpp      Use AOT-exported kernel in C++
    aot_use_in_cpp_bundle.sh       Build/run script for C++ AOT usage
    compile_with_fake_tensor.py    Compile using fake tensors
    compile_with_symint_arg.py     Compile with symbolic integer arguments
    ampere_gemm_with_fake_tensor.py  Ampere GEMM with fake tensor compilation
    error_reporting.py             Error reporting and diagnostics
  call_bypass_dlpack.py          Calling kernels bypassing DLPack
  call_from_jit.py               Calling conventions from JIT-compiled code
  cooperative_launch.py          Cooperative kernel launch (multi-CTA)
  dynamic_smem_size.py           Dynamic shared memory allocation
  inline_ptx.py                  Embedding inline PTX assembly
  launch_completion_and_programmatic_events.py
                                 Launch completion / programmatic events with cudaEvent_t and CUevent
  pointer.py                     Pointer manipulation in DSL
  print_latex.py                 LaTeX rendering of CuTe layouts
  programmatic_dependent_launch.py  Programmatic dependent launch (PDL)
  smem_allocator.py              Shared memory allocator usage
  torch_fake_tensor.py           PyTorch fake tensor integration
  torch_fp4.py                   PyTorch FP4 tensor support

Subdirectory Guides

export/ -- Kernel Export

Shows how to compile a CuTe DSL kernel into a standalone C shared library (.so or .dylib) that can be loaded and called from C++ or Python with no CuTe DSL dependency at runtime. Includes complete examples of both the dynamic loading (dlopen) and static linking workflows.
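The Python side of the dynamic-loading workflow comes down to opening a shared library and declaring each exported function's C signature before calling it. The exported library's real symbol names and argument lists come from export_to_c.py and aren't reproduced here. As a stand-in, this minimal sketch assumes a POSIX system, opens the symbols already loaded into the process (which include the C runtime), and calls labs(). The helper names open_library and bind are hypothetical.

```python
import ctypes

def open_library(path=None):
    """Open a shared library, much like dlopen() in the C++ driver.

    path=None opens the symbols already loaded into this process (on Linux
    and macOS this includes the C runtime). For a real export you would pass
    something like "./libmy_kernel.so" (a hypothetical file name).
    """
    return ctypes.CDLL(path)

def bind(lib, name, restype, argtypes):
    """Look up an exported symbol and declare its C signature.

    By default ctypes assumes a C int return value, which truncates 64-bit
    results such as pointers, so always set restype and argtypes.
    """
    fn = getattr(lib, name)
    fn.restype = restype
    fn.argtypes = argtypes
    return fn

# Stand-in symbol: labs() from the C runtime, not a CuTe DSL export.
lib = open_library()
c_labs = bind(lib, "labs", ctypes.c_long, [ctypes.c_long])
print(c_labs(-42))  # 42
```

Declaring the signature once in bind() mirrors what the C++ driver does when it casts the pointer returned by dlsym to a typed function pointer.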

ffi/ -- Foreign Function Interface

Demonstrates how to pass arguments between Python/CuTe DSL and C++ code using the FFI layer. Useful for integrating CuTe DSL kernels into existing C++ applications.
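Much of the work at an FFI boundary is passing tensor metadata (data pointer, device, dtype, shape, strides) in a layout both sides agree on. The tensor representation used in tensor.cpp isn't reproduced here. As a generic, hedged sketch, this block mirrors DLPack's DLTensor struct (the standard exchange format that DLPack-based calls use) with Python's ctypes. It then reads an element back using only the descriptor, which is what a C++ consumer would do. describe_host_matrix and read_element are hypothetical helpers, and the example covers host float32 data only.

```python
import ctypes

# Field order and types follow DLPack's dlpack.h.
class DLDevice(ctypes.Structure):
    _fields_ = [("device_type", ctypes.c_int32), ("device_id", ctypes.c_int32)]

class DLDataType(ctypes.Structure):
    _fields_ = [("code", ctypes.c_uint8), ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]

class DLTensor(ctypes.Structure):
    _fields_ = [
        ("data", ctypes.c_void_p),
        ("device", DLDevice),
        ("ndim", ctypes.c_int32),
        ("dtype", DLDataType),
        ("shape", ctypes.POINTER(ctypes.c_int64)),
        ("strides", ctypes.POINTER(ctypes.c_int64)),  # in elements, not bytes
        ("byte_offset", ctypes.c_uint64),
    ]

KDL_CPU = 1    # DLDeviceType kDLCPU
KDL_FLOAT = 2  # DLDataTypeCode kDLFloat

def describe_host_matrix(buf, rows, cols):
    """Producer side: describe a row-major float32 host buffer as a DLTensor.

    The descriptor only stores pointers, so the shape/stride arrays are
    returned too and must outlive it.
    """
    shape = (ctypes.c_int64 * 2)(rows, cols)
    strides = (ctypes.c_int64 * 2)(cols, 1)
    tensor = DLTensor(
        data=ctypes.addressof(buf),
        device=DLDevice(KDL_CPU, 0),
        ndim=2,
        dtype=DLDataType(KDL_FLOAT, 32, 1),
        shape=ctypes.cast(shape, ctypes.POINTER(ctypes.c_int64)),
        strides=ctypes.cast(strides, ctypes.POINTER(ctypes.c_int64)),
        byte_offset=0,
    )
    return tensor, shape, strides

def read_element(tensor, i, j):
    """Consumer side: locate element (i, j) using only the descriptor."""
    if tensor.dtype.code != KDL_FLOAT or tensor.dtype.bits != 32:
        raise TypeError("this sketch only reads float32 tensors")
    elem = i * tensor.strides[0] + j * tensor.strides[1]
    addr = tensor.data + tensor.byte_offset + elem * (tensor.dtype.bits // 8)
    return ctypes.c_float.from_address(addr).value

buf = (ctypes.c_float * 6)(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)  # 2x3, row-major
tensor, shape, strides = describe_host_matrix(buf, 2, 3)
print(read_element(tensor, 1, 2))  # 5.0
```

Strides are stored in elements, as DLPack specifies, so the consumer scales by the element size only when it forms the final byte address.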

jax/ -- JAX Integration

Shows how to call CuTe DSL kernels from JAX using cutlass_call, including basic invocation, kernel export for JAX, multi-device sharding, and elementwise application patterns.

tvm_ffi/ -- TVM FFI Integration

Comprehensive examples for using CuTe DSL kernels through TVM's foreign function interface. Covers both JIT and AOT (ahead-of-time) compilation workflows, with usage examples for PyTorch, JAX, and C++. Also demonstrates fake-tensor compilation (no GPU required at compile time) and symbolic integer arguments.
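This is a conceptual model only, not the DSL's actual API or caching rules. It illustrates why fake tensors and symbolic integer arguments help. Compilation needs only metadata (dtype and shape), not tensor data. If a dimension is marked symbolic, one compiled artifact can serve every size of that dimension instead of recompiling per size. FakeTensor, CompileCache, compile_key and fake_compile are all hypothetical names; see compile_with_fake_tensor.py and compile_with_symint_arg.py for the real usage.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class FakeTensor:
    """Metadata only: enough to specialize a kernel, with no device buffer."""
    dtype: str
    shape: Tuple[int, ...]

def compile_key(t, symbolic_dims=frozenset()):
    # A symbolic dimension is recorded as None, so all its sizes share a key.
    dims = tuple(None if i in symbolic_dims else d for i, d in enumerate(t.shape))
    return (t.dtype, dims)

class CompileCache:
    """Hypothetical cache: compiles once per distinct specialization key."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn
        self._entries = {}
        self.num_compiles = 0

    def lookup(self, t, symbolic_dims=frozenset()):
        key = compile_key(t, symbolic_dims)
        if key not in self._entries:
            self.num_compiles += 1
            self._entries[key] = self._compile_fn(key)
        return self._entries[key]

def fake_compile(key):
    """Stand-in for a real compiler: records what it was specialized on."""
    dtype, dims = key
    return f"kernel<{dtype}, {dims}>"

static = CompileCache(fake_compile)
symbolic = CompileCache(fake_compile)
for n in (128, 256, 512):
    t = FakeTensor("float16", (n, 64))
    static.lookup(t)                                  # every size is new
    symbolic.lookup(t, symbolic_dims=frozenset({0}))  # dim 0 is symbolic

print(static.num_compiles, symbolic.num_compiles)  # 3 1
```

The trade-off is that a symbolic dimension gives the compiler less to specialize on, so static shapes may still be preferable for sizes that never change.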


Top-Level Files

The top-level Python files demonstrate individual DSL features:

  • call_bypass_dlpack.py / call_from_jit.py -- Kernel calling conventions
  • inline_ptx.py -- Embedding inline PTX assembly in CuTe DSL kernels
  • launch_completion_and_programmatic_events.py -- Examples of the launch_completion_event and programmatic_event launch attributes. Events are created with torch.cuda.Event(enable_timing=False) and passed either as cudaEvent_t (cuda.bindings.runtime) or as CUevent (cuda.bindings.driver); the stream is passed as a cudaStream_t (cuda.bindings.runtime).
  • programmatic_dependent_launch.py -- Programmatic dependent launch for chaining kernels with data dependencies
  • cooperative_launch.py -- Cooperative launch for multi-CTA kernels
  • dynamic_smem_size.py / smem_allocator.py -- Shared memory allocation
  • torch_fake_tensor.py / torch_fp4.py -- PyTorch integration features
  • pointer.py -- Pointer manipulation within DSL kernels
  • print_latex.py -- Render CuTe layouts as LaTeX for visualization
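Dynamic shared memory sizing and a shared memory allocator are closely linked. Buffers are typically carved out of one dynamic allocation in order, each start rounded up to its alignment, and the final offset is the dynamic size requested at launch. The allocator API in smem_allocator.py isn't reproduced here. This is a hypothetical host-side model of that bump allocation (SmemPlan and align_up are illustrative names). The 48 KiB budget and the tile sizes are assumptions chosen for the example.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple

def align_up(offset, alignment):
    """Round offset up to a multiple of a power-of-two alignment."""
    if alignment <= 0 or alignment & (alignment - 1):
        raise ValueError("alignment must be a positive power of two")
    return (offset + alignment - 1) & ~(alignment - 1)

@dataclass
class SmemPlan:
    """Hypothetical bump allocator over one dynamic shared-memory buffer."""
    capacity: int
    offset: int = 0
    regions: Dict[str, Tuple[int, int]] = field(default_factory=dict)

    def allocate(self, name, num_bytes, alignment=16):
        start = align_up(self.offset, alignment)
        end = start + num_bytes
        if end > self.capacity:
            raise MemoryError(f"{name}: needs {end} bytes, capacity {self.capacity}")
        self.regions[name] = (start, num_bytes)  # recorded only on success
        self.offset = end
        return start

    @property
    def dynamic_smem_bytes(self):
        # Total bytes to request as the kernel's dynamic shared memory size.
        return self.offset

plan = SmemPlan(capacity=48 * 1024)                         # assumed budget
a = plan.allocate("tile_a", 128 * 32 * 2, alignment=128)    # fp16 128x32 tile
bar = plan.allocate("barriers", 2 * 8, alignment=8)         # two 8-byte slots
b = plan.allocate("tile_b", 64 * 32 * 2, alignment=128)     # padded to 8320
print(a, bar, b, plan.dynamic_smem_bytes)  # 0 8192 8320 12416
```

The 112 bytes between the barriers and tile_b are alignment padding. Ordering the largest, most strictly aligned buffers first is a simple way to keep that padding small.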