DSL Feature Examples
This directory demonstrates CuTe DSL capabilities beyond kernel authoring itself: exporting compiled kernels for deployment, integrating with ML frameworks, using foreign function interfaces, and accessing low-level DSL features like inline PTX and shared memory allocation.
Directory Structure
dsl/
  export/                              Exporting kernels to C shared libraries
    export_to_c.py                     Compile a kernel and export as .so/.dylib
    load_in_python.py                  Load and call the exported library from Python
    run_with_dynamic_loading.cpp       C++ driver using dlopen
    run_with_dynamic_loading.sh        Build/run script for dynamic loading
    run_with_static_linking.cpp        C++ driver using static linking
    run_with_static_linking.sh         Build/run script for static linking
  ffi/                                 Foreign function interface
    jit_argument.py                    JIT compilation with argument passing
    tensor.cpp                         C++ tensor interop implementation
    CMakeLists.txt                     CMake build for FFI examples
  jax/                                 JAX integration
    cutlass_call_basic.py              Basic CUTLASS kernel call from JAX
    cutlass_call_export.py             Export a CUTLASS kernel for JAX
    cutlass_call_sharding.py           Multi-device sharding with CUTLASS kernels
    elementwise_apply_example.py       Elementwise apply via JAX
  tvm_ffi/                             TVM FFI integration
    jit_and_use_in_torch.py            JIT compile and call from PyTorch
    jit_and_use_in_jax.py              JIT compile and call from JAX
    aot_export.py                      Ahead-of-time export
    aot_use_in_torch.py                Use AOT-exported kernel in PyTorch
    aot_use_in_jax.py                  Use AOT-exported kernel in JAX
    aot_use_in_cpp_bundle.cpp          Use AOT-exported kernel in C++
    aot_use_in_cpp_bundle.sh           Build/run script for C++ AOT usage
    compile_with_fake_tensor.py        Compile using fake tensors
    compile_with_symint_arg.py         Compile with symbolic integer arguments
    ampere_gemm_with_fake_tensor.py    Ampere GEMM with fake tensor compilation
    error_reporting.py                 Error reporting and diagnostics
  call_bypass_dlpack.py                Calling kernels bypassing DLPack
  call_from_jit.py                     Calling conventions from JIT-compiled code
  cooperative_launch.py                Cooperative kernel launch (multi-CTA)
  dynamic_smem_size.py                 Dynamic shared memory allocation
  inline_ptx.py                        Embedding inline PTX assembly
  launch_completion_and_programmatic_events.py
                                       Launch completion / programmatic events with cudaEvent_t and CUevent
  pointer.py                           Pointer manipulation in DSL
  print_latex.py                       LaTeX rendering of CuTe layouts
  programmatic_dependent_launch.py     Programmatic dependent launch (PDL)
  smem_allocator.py                    Shared memory allocator usage
  torch_fake_tensor.py                 PyTorch fake tensor integration
  torch_fp4.py                         PyTorch FP4 tensor support
Subdirectory Guides
export/ -- Kernel Export
Shows how to compile a CuTe DSL kernel into a standalone C shared library (.so)
that can be loaded and called from C++ or Python without any CuTe DSL dependency
at runtime. Includes complete examples for both dynamic loading (dlopen) and
static linking workflows.
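The dynamic-loading path reduces to dlopen plus a symbol lookup. As a rough sketch of what calling a shared library from Python can look like, the snippet below uses the standard ctypes module. Because this README does not describe the symbol names or C signatures that export_to_c.py produces, it binds cos from the C math library as a stand-in; the helper name load_function is hypothetical.

```python
import ctypes
import ctypes.util


def load_function(lib_path, symbol, restype, argtypes):
    """Open a shared library (dlopen on Linux/macOS) and bind one exported C symbol.

    Hypothetical helper for illustration; not part of the CuTe DSL export tooling.
    """
    lib = ctypes.CDLL(lib_path)     # load the .so/.dylib into the process
    fn = getattr(lib, symbol)       # look up the symbol (dlsym)
    fn.restype = restype            # declare the C return type
    fn.argtypes = argtypes          # declare the C parameter types
    return fn


if __name__ == "__main__":
    # Stand-in library: libm's cos(double), NOT an exported CuTe DSL kernel.
    libm_path = ctypes.util.find_library("m")
    cos = load_function(libm_path, "cos", ctypes.c_double, [ctypes.c_double])
    print(cos(0.0))  # 1.0
```

To call a real exported kernel, argtypes and restype would have to match the exported C signature (device pointers, for instance, would typically be passed as ctypes.c_void_p); load_in_python.py shows the actual interface.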
ffi/ -- Foreign Function Interface
Demonstrates how to pass arguments between Python/CuTe DSL and C++ code using the FFI layer. Useful for integrating CuTe DSL kernels into existing C++ applications.
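Passing a tensor across a C/C++ boundary generally means passing a data pointer together with its shape and strides. The sketch below models that with a ctypes structure. TensorView, MAX_RANK, make_view and load are hypothetical names, and the field layout is only an illustration, not the layout used by tensor.cpp or the DSL's FFI layer.

```python
import ctypes

MAX_RANK = 4  # hypothetical fixed capacity for shape/stride arrays


class TensorView(ctypes.Structure):
    """Hypothetical C-compatible tensor descriptor (illustrative layout only)."""
    _fields_ = [
        ("data", ctypes.POINTER(ctypes.c_float)),  # pointer to the first element
        ("rank", ctypes.c_int32),                  # number of valid modes
        ("shape", ctypes.c_int64 * MAX_RANK),      # extent of each mode
        ("stride", ctypes.c_int64 * MAX_RANK),     # stride of each mode, in elements
    ]


def make_view(buf, shape, stride):
    """Describe an existing float buffer as a strided tensor without copying it."""
    if len(shape) != len(stride) or len(shape) > MAX_RANK:
        raise ValueError("shape/stride must have equal length <= MAX_RANK")
    view = TensorView()
    view.data = ctypes.cast(buf, ctypes.POINTER(ctypes.c_float))
    view.rank = len(shape)
    for i, (extent, step) in enumerate(zip(shape, stride)):
        view.shape[i] = extent
        view.stride[i] = step
    return view


def load(view, coord):
    """Read one element the way C code receiving the descriptor would:
    offset = sum(coord[i] * stride[i])."""
    offset = sum(c * view.stride[i] for i, c in enumerate(coord))
    return view.data[offset]
```

The same six floats can be viewed as a row-major 2x3 tensor (strides (3, 1)) or a column-major 3x2 tensor (strides (1, 3)) just by changing the descriptor, which is why the strides travel with the pointer.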
jax/ -- JAX Integration
Shows how to call CuTe DSL kernels from JAX using cutlass_call, including
basic invocation, kernel export for JAX, multi-device sharding, and elementwise
application patterns.
tvm_ffi/ -- TVM FFI Integration
Comprehensive examples for using CuTe DSL kernels through TVM's foreign function interface. Covers both JIT and AOT (ahead-of-time) compilation workflows, with usage examples for PyTorch, JAX, and C++. Also demonstrates fake-tensor compilation (no GPU required at compile time) and symbolic integer arguments.
Top-Level Files
The top-level Python files demonstrate individual DSL features:
call_bypass_dlpack.py / call_from_jit.py -- Kernel calling conventions
inline_ptx.py -- Embedding inline PTX assembly in CuTe DSL kernels
launch_completion_and_programmatic_events.py -- Examples of the launch_completion_event and programmatic_event launch attributes, using events created via torch.cuda.Event(enable_timing=False) and presented as either cudaEvent_t (cuda.bindings.runtime) or CUevent (cuda.bindings.driver). The stream is passed as a cudaStream_t (cuda.bindings.runtime).
programmatic_dependent_launch.py -- Programmatic dependent launch for chaining kernels with data dependencies
cooperative_launch.py -- Cooperative launch for multi-CTA kernels
dynamic_smem_size.py / smem_allocator.py -- Shared memory allocation
torch_fake_tensor.py / torch_fp4.py -- PyTorch integration features
pointer.py -- Pointer manipulation within DSL kernels
print_latex.py -- Render CuTe layouts as LaTeX for visualization
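A cooperative launch allows grid-wide synchronization, which only works if every CTA in the grid is resident at the same time; otherwise a grid barrier would wait on CTAs that can never be scheduled. In CUDA that limit is the SM count times the per-SM occupancy reported by cudaOccupancyMaxActiveBlocksPerMultiprocessor. The sketch below models just that check plus a grid-stride split of work tiles across the resident CTAs; the function names and the SM/occupancy numbers used with it are hypothetical, not values from cooperative_launch.py.

```python
def max_cooperative_grid(num_sms, blocks_per_sm):
    """Largest grid (in CTAs) that can be fully co-resident."""
    return num_sms * blocks_per_sm


def check_cooperative_grid(grid_ctas, num_sms, blocks_per_sm):
    """Reject grids that a cooperative launch could not keep resident at once."""
    limit = max_cooperative_grid(num_sms, blocks_per_sm)
    if grid_ctas > limit:
        raise ValueError(f"grid of {grid_ctas} CTAs exceeds co-residency limit {limit}")
    return grid_ctas


def tiles_for_cta(cta, num_ctas, num_tiles):
    """Grid-stride assignment: CTA `cta` handles tiles cta, cta + num_ctas, ..."""
    return range(cta, num_tiles, num_ctas)
```

With a hypothetical 4 SMs and 2 resident CTAs per SM, at most 8 CTAs can be launched cooperatively, and 20 work tiles are then covered by CTA 0 taking tiles 0, 8, 16, CTA 1 taking 1, 9, 17, and so on. Sizing the grid to the co-residency limit and looping over tiles is what keeps the grid launchable regardless of problem size.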
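dynamic_smem_size.py and smem_allocator.py both revolve around carving a block's shared memory into aligned regions. The host-side model below sketches the offset arithmetic of a simple bump allocator, assuming power-of-two alignments and a 16-byte default. BumpAllocator and its methods are hypothetical and are not the DSL's allocator API; the total it reports is what a kernel in this model would request as its dynamic shared memory size.

```python
def align_up(offset, alignment):
    """Round offset up to the next multiple of a power-of-two alignment."""
    assert alignment > 0 and alignment & (alignment - 1) == 0, "alignment must be a power of two"
    return (offset + alignment - 1) & ~(alignment - 1)


class BumpAllocator:
    """Toy model of handing out aligned sub-buffers from one shared-memory block."""

    def __init__(self, capacity_bytes):
        self.capacity = capacity_bytes
        self.offset = 0  # first free byte

    def allocate(self, nbytes, alignment=16):
        """Return the byte offset of a new region, or raise if it does not fit."""
        start = align_up(self.offset, alignment)
        end = start + nbytes
        if end > self.capacity:
            raise MemoryError(f"need {end} bytes, only {self.capacity} available")
        self.offset = end
        return start

    @property
    def bytes_used(self):
        """Total footprint, including alignment padding."""
        return self.offset
```

For example, two 128x64 FP16 tiles (16 KiB each, 128-byte aligned) separated by an 8-byte barrier slot occupy 32896 bytes, because the padding pushes the second tile to offset 16512. On the CUDA side, a kernel that needs more than 48 KiB of dynamic shared memory per block has to opt in via cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize.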
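torch_fp4.py covers PyTorch FP4 support. Assuming FP4 here means the E2M1 format (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit, no infinities or NaNs), the plain-Python decoder below shows how a 4-bit code maps to a value and how two codes share one byte. The low-nibble-first packing order is an assumption, as are the helper names; this is not code from the example.

```python
def decode_e2m1(code):
    """Decode one 4-bit E2M1 code (0..15) to a Python float."""
    sign = -1.0 if code & 0x8 else 1.0
    exponent = (code >> 1) & 0x3
    mantissa = code & 0x1
    if exponent == 0:
        # Subnormal: 0.m * 2^(1 - bias) = m * 0.5
        magnitude = mantissa * 0.5
    else:
        # Normal: 1.m * 2^(exponent - bias), bias = 1
        magnitude = (1.0 + mantissa / 2) * 2.0 ** (exponent - 1)
    return sign * magnitude


def unpack_fp4x2(byte):
    """Split one byte into two E2M1 values (assumed: low nibble is the first element)."""
    return decode_e2m1(byte & 0xF), decode_e2m1((byte >> 4) & 0xF)
```

The eight non-negative codes decode to 0, 0.5, 1, 1.5, 2, 3, 4 and 6; setting the sign bit negates them, so code 0xF decodes to -6.0.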
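print_latex.py visualizes CuTe layouts, which map a logical coordinate to an offset through the inner product of the coordinate with the layout's stride. For flat (non-nested) layouts that arithmetic can be sketched in plain Python as below, assuming CuTe's colexicographic (leftmost-mode-fastest) ordering when turning a 1-D index into a coordinate. The helper names and the LaTeX tabular format are hypothetical and do not reproduce print_latex's actual output.

```python
def idx2crd(idx, shape):
    """Unflatten a 1-D index into a coordinate, leftmost mode fastest."""
    coord = []
    for extent in shape:
        coord.append(idx % extent)
        idx //= extent
    return tuple(coord)


def layout_apply(shape, stride, idx):
    """Offset of logical index idx under the flat layout shape:stride."""
    coord = idx2crd(idx, shape)
    return sum(c * d for c, d in zip(coord, stride))


def layout_table(shape, stride):
    """Offsets for every (row, col) of a rank-2 flat layout."""
    rows, cols = shape
    return [[r * stride[0] + c * stride[1] for c in range(cols)] for r in range(rows)]


def to_latex(table):
    """Minimal LaTeX tabular of a 2-D offset table (hypothetical formatting)."""
    ncols = len(table[0])
    body = " \\\\\n".join(" & ".join(str(v) for v in row) for row in table)
    return "\\begin{tabular}{" + "|c" * ncols + "|}\n" + body + "\n\\end{tabular}"
```

(4, 2):(1, 4) is the column-major 4x2 layout, so layout_apply((4, 2), (1, 4), i) returns i for every i in range(8); switching to (4, 2):(2, 1) gives the row-major table [[0, 1], [2, 3], [4, 5], [6, 7]].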