# DSL Feature Examples

This directory demonstrates **CuTe DSL capabilities** beyond kernel authoring itself:
exporting compiled kernels for deployment, integrating with ML frameworks, using
foreign function interfaces, and accessing low-level DSL features like inline PTX
and shared memory allocation.

---

## Directory Structure

```
dsl/
  export/                             Exporting kernels to C shared libraries
    export_to_c.py                    Compile a kernel and export as .so/.dylib
    load_in_python.py                 Load and call the exported library from Python
    run_with_dynamic_loading.cpp      C++ driver using dlopen
    run_with_dynamic_loading.sh       Build/run script for dynamic loading
    run_with_static_linking.cpp       C++ driver using static linking
    run_with_static_linking.sh        Build/run script for static linking
  ffi/                                Foreign function interface
    jit_argument.py                   JIT compilation with argument passing
    tensor.cpp                        C++ tensor interop implementation
    CMakeLists.txt                    CMake build for FFI examples
  jax/                                JAX integration
    cutlass_call_basic.py             Basic CUTLASS kernel call from JAX
    cutlass_call_export.py            Export a CUTLASS kernel for JAX
    cutlass_call_sharding.py          Multi-device sharding with CUTLASS kernels
    elementwise_apply_example.py      Elementwise apply via JAX
  tvm_ffi/                            TVM FFI integration
    jit_and_use_in_torch.py           JIT compile and call from PyTorch
    jit_and_use_in_jax.py             JIT compile and call from JAX
    aot_export.py                     Ahead-of-time export
    aot_use_in_torch.py               Use AOT-exported kernel in PyTorch
    aot_use_in_jax.py                 Use AOT-exported kernel in JAX
    aot_use_in_cpp_bundle.cpp         Use AOT-exported kernel in C++
    aot_use_in_cpp_bundle.sh          Build/run script for C++ AOT usage
    compile_with_fake_tensor.py       Compile using fake tensors
    compile_with_symint_arg.py        Compile with symbolic integer arguments
    ampere_gemm_with_fake_tensor.py   Ampere GEMM with fake tensor compilation
    error_reporting.py                Error reporting and diagnostics
  call_bypass_dlpack.py               Calling kernels bypassing DLPack
  call_from_jit.py                    Calling conventions from JIT-compiled code
  cooperative_launch.py               Cooperative kernel launch (multi-CTA)
  dynamic_smem_size.py                Dynamic shared memory allocation
  inline_ptx.py                       Embedding inline PTX assembly
  launch_completion_and_programmatic_events.py
                                      Launch completion / programmatic events with cudaEvent_t and CUevent
  pointer.py                          Pointer manipulation in DSL
  print_latex.py                      LaTeX rendering of CuTe layouts
  programmatic_dependent_launch.py    Programmatic dependent launch (PDL)
  smem_allocator.py                   Shared memory allocator usage
  torch_fake_tensor.py                PyTorch fake tensor integration
  torch_fp4.py                        PyTorch FP4 tensor support
```

---

## Subdirectory Guides

### `export/` -- Kernel Export

Shows how to compile a CuTe DSL kernel into a standalone C shared library
(`.so`/`.dylib`) that can be loaded and called from C++ or Python without any
CuTe DSL dependency at runtime. Includes complete examples for both dynamic
loading (`dlopen`) and static linking workflows.
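
For orientation, the generic "open a shared library, resolve a C symbol, call it" step can be sketched with nothing but Python's standard `ctypes` module. This is a minimal sketch under stated assumptions, not necessarily how `load_in_python.py` does it: the library path and the exported symbol name and signature are caller-supplied placeholders, since they depend on how the kernel was exported.

```python
import ctypes
from typing import Callable, Optional, Sequence


def load_c_function(
    lib_path: Optional[str],
    symbol: str,
    restype=ctypes.c_int,
    argtypes: Sequence = (),
) -> Callable:
    """Open a shared library and return one of its C functions, ready to call.

    lib_path and symbol are supplied by the caller; the file name of an exported
    kernel library and the wrapper symbol it exposes are assumptions here, not
    names taken from the examples. On POSIX, lib_path=None opens the running
    process itself (dlopen(NULL)).
    """
    # ctypes.CDLL uses dlopen() on Linux/macOS; it raises OSError if the
    # library or one of its dependencies (e.g. the CUDA driver) cannot be found.
    lib = ctypes.CDLL(lib_path)
    # Attribute access resolves the symbol (dlsym); AttributeError if it is absent.
    fn = getattr(lib, symbol)
    # Declare the C signature so ctypes converts arguments and the result correctly.
    fn.restype = restype  # None means the C function returns void
    fn.argtypes = list(argtypes)
    return fn
```

For example, `load_c_function("./libmy_kernel.so", "my_kernel_launch", None, [ctypes.c_void_p, ctypes.c_void_p])` would bind a hypothetical `void my_kernel_launch(void*, void*)`; both names are placeholders, and the real argument list depends on the exported interface.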

### `ffi/` -- Foreign Function Interface

Demonstrates how to pass arguments between Python/CuTe DSL and C++ code using
the FFI layer. Useful for integrating CuTe DSL kernels into existing C++
applications.
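
A common way to hand a tensor across a C boundary is a small descriptor: a data pointer plus shape and stride arrays. The sketch below models that idea with `ctypes.Structure` on a host buffer. The struct name, its field layout, and the fixed rank of 2 are hypothetical and do not describe the actual ABI used by `tensor.cpp` or the CuTe DSL FFI layer; DLPack's `DLTensor` standardizes the same idea with extra fields (device, dtype, `byte_offset`).

```python
import ctypes
from array import array

RANK = 2  # hypothetical: this illustrative descriptor only handles 2-D tensors


class TensorDesc2D(ctypes.Structure):
    """Hypothetical C layout: struct { void *data; int64_t shape[2]; int64_t stride[2]; }"""

    _fields_ = [
        ("data", ctypes.c_void_p),
        ("shape", ctypes.c_int64 * RANK),
        ("stride", ctypes.c_int64 * RANK),  # strides counted in elements, not bytes
    ]


def describe_row_major(buf: array, rows: int, cols: int) -> TensorDesc2D:
    """Describe a row-major rows x cols view of a float32 buffer (no copy).

    The caller must keep `buf` alive while the descriptor is in use.
    """
    if buf.typecode != "f" or len(buf) < rows * cols:
        raise ValueError("expected a float32 array with at least rows * cols elements")
    address, _ = buf.buffer_info()  # address of the buffer's first element
    desc = TensorDesc2D()
    desc.data = address
    desc.shape[0], desc.shape[1] = rows, cols
    desc.stride[0], desc.stride[1] = cols, 1  # moving one row skips `cols` elements
    return desc


def read_element(desc: TensorDesc2D, i: int, j: int) -> float:
    """Mimic what the C side would do: compute the element offset from the strides."""
    base = ctypes.cast(desc.data, ctypes.POINTER(ctypes.c_float))
    return base[i * desc.stride[0] + j * desc.stride[1]]
```

A C function declared to take a `TensorDesc2D*` would receive the descriptor via `ctypes.byref(desc)`; only the descriptor is passed, never a copy of the data.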

### `jax/` -- JAX Integration

Shows how to call CuTe DSL kernels from JAX using `cutlass_call`, including
basic invocation, kernel export for JAX, multi-device sharding, and elementwise
application patterns.

### `tvm_ffi/` -- TVM FFI Integration

Comprehensive examples for using CuTe DSL kernels through TVM's foreign function
interface. Covers both JIT and ahead-of-time (AOT) compilation workflows, with
usage examples for PyTorch, JAX, and C++. Also demonstrates fake-tensor
compilation (no GPU required at compile time) and symbolic integer arguments.

---

## Top-Level Files

The top-level Python files demonstrate individual DSL features:

- **`call_bypass_dlpack.py`** / **`call_from_jit.py`** -- Kernel calling conventions:
  calling kernels while bypassing DLPack, and calling from JIT-compiled code
- **`inline_ptx.py`** -- Embedding inline PTX assembly in CuTe DSL kernels
- **`launch_completion_and_programmatic_events.py`** -- Examples of the
  `launch_completion_event` and `programmatic_event` launch attributes, using
  events created via `torch.cuda.Event(enable_timing=False)` and represented
  either as a `cudaEvent_t` (`cuda.bindings.runtime`) or as a `CUevent`
  (`cuda.bindings.driver`). The stream is passed as a `cudaStream_t`
  (`cuda.bindings.runtime`).
- **`programmatic_dependent_launch.py`** -- Programmatic dependent launch (PDL) for
  chaining kernels with data dependencies
- **`cooperative_launch.py`** -- Cooperative launch for multi-CTA kernels
- **`dynamic_smem_size.py`** / **`smem_allocator.py`** -- Dynamic shared memory
  sizing and shared memory allocator usage
- **`torch_fake_tensor.py`** / **`torch_fp4.py`** -- PyTorch fake-tensor
  integration and FP4 tensor support
- **`pointer.py`** -- Pointer manipulation within DSL kernels
- **`print_latex.py`** -- Render CuTe layouts as LaTeX for visualization
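
The shared-memory examples revolve around one piece of bookkeeping: carving aligned sub-buffers out of a single dynamic shared memory region and knowing the total size to request at launch. The plain-Python sketch below models that bookkeeping on the host as a bump allocator. `SmemPlan`, `align_up`, and the buffer sizes and alignments are hypothetical illustrations, not the CuTe DSL allocator API; offsets are relative to the start of the region.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple


def align_up(offset: int, alignment: int) -> int:
    """Round offset up to the next multiple of alignment (a power of two)."""
    if alignment <= 0 or alignment & (alignment - 1):
        raise ValueError("alignment must be a positive power of two")
    return (offset + alignment - 1) & ~(alignment - 1)


@dataclass
class SmemPlan:
    """Hypothetical host-side model of a bump allocator over one smem region."""

    capacity: int  # caller-supplied upper bound in bytes
    offset: int = 0  # next free byte, relative to the region start
    buffers: Dict[str, Tuple[int, int]] = field(default_factory=dict)  # name -> (start, size)

    def allocate(self, name: str, num_bytes: int, alignment: int = 16) -> int:
        start = align_up(self.offset, alignment)  # padding is inserted here if needed
        end = start + num_bytes
        if end > self.capacity:
            raise MemoryError(f"{name}: needs {end} bytes, capacity is {self.capacity}")
        self.buffers[name] = (start, num_bytes)
        self.offset = end
        return start

    @property
    def dynamic_smem_bytes(self) -> int:
        """Total bytes (including alignment padding) to request at launch."""
        return self.offset


if __name__ == "__main__":
    plan = SmemPlan(capacity=48 * 1024)
    plan.allocate("tile_a", 128 * 32 * 2, alignment=128)  # 128x32 16-bit tile, offset 0
    plan.allocate("flag", 4, alignment=8)  # offset 8192
    plan.allocate("tile_b", 64 * 32 * 2, alignment=128)  # offset 8320 after 124 bytes of padding
    print(plan.buffers, plan.dynamic_smem_bytes)  # total 12416 bytes
```

Because padding depends on what precedes each buffer, allocating the large, strictly aligned buffers first usually keeps the total smaller; in the run above, the 4-byte `flag` costs 124 bytes of padding before `tile_b`.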