v4.4.2 update. (#3104)

This commit is contained in:
Junkai-Wu
2026-03-17 12:58:19 +08:00
committed by GitHub
parent 772fbb264e
commit 1b741cabaa
31 changed files with 996 additions and 355 deletions

View File

@@ -2,6 +2,20 @@
# CUTLASS 4.x
## [4.4.2](https://github.com/NVIDIA/cutlass/releases/tag/v4.4.2) (2026-03-13)
### CuTe DSL
* New features
- CuTe DSL now supports Python 3.14 for both x86_64 and aarch64
- Runtime Pointer/Tensor/FakeTensor now support `__cache_key__`, providing a stable, hashable representation that simplifies and improves compiled function caching.
* Bug fixing and improvements
- Fixed Hopper FMHA causal attention performance regression on CUDA toolkit 13.1 by
optimizing mbarrier synchronization to avoid unnecessary convergence barriers.
- Fixed a kernel loading race condition when multiple GPUs are present in the same process with JAX.
### CUTLASS C++
* Enable Blackwell SM120f compilation of examples and expose NVFP4/MX Grouped GEMM in the CUTLASS Profiler.
## [4.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v4.4.1) (2026-02-27)
### CuTe DSL
@@ -148,8 +162,6 @@
* Work around a driver TMA descriptor bug that can cause occasional errors on Blackwell when the tensor's backing memory allocation is less than 128KB and it is not a dense, non-overlapping tensor.
## [4.3.3](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.3) (2025-12-12)
### CuTe DSL
* New features
- Supported namedtuple and kwargs for JIT function arguments in tvm-ffi
- Supported variadic tuples for JIT function arguments in tvm-ffi
@@ -159,8 +171,6 @@
- Clearer error message for the runtime error cudaErrorInsufficientDriver
## [4.3.2](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.2) (2025-12-05)
### CuTe DSL
* New features
- New env var `CUTE_DSL_CACHE_DIR` to specify the path for dumping caches
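The `CUTE_DSL_CACHE_DIR` bullet above can be exercised with a short, hedged sketch (the directory name is arbitrary; that the variable must be set before the DSL initializes, i.e. before `import cutlass`, is the editor's assumption):

```python
import os
import tempfile

# Hedged sketch: redirect the DSL's dumped caches to a custom directory.
# Set the variable before importing cutlass in your program.
cache_dir = os.path.join(tempfile.gettempdir(), "my_cute_dsl_cache")
os.makedirs(cache_dir, exist_ok=True)
os.environ["CUTE_DSL_CACHE_DIR"] = cache_dir
```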

View File

@@ -1,9 +1,9 @@
![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# Overview
# CUTLASS 4.4.1
# CUTLASS 4.4.2
_CUTLASS 4.4.1 - Feb 2026_
_CUTLASS 4.4.2 - March 2026_
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
and related computations at all levels and scales within CUDA. It incorporates strategies for
@@ -72,6 +72,8 @@ To get started quickly - please refer :
- We allow grid carve-out without problem shapes being available on host.
- Tma+LdMatrix features for loading+unpacking narrow-width types (refer to mixed_input_fmha_decode.py for example usage).
- It is now possible to have customized epilogue fusion for persistent dense GEMM through a Python Epilogue Fusion Configuration (EFC) function, somewhat similar to CUTLASS C++ EVT. It also provides a PyTorch evaluator to compare the results.
- CuTe DSL now supports Python 3.14 for both x86_64 and aarch64
- Runtime Pointer/Tensor/FakeTensor now support `__cache_key__`, providing a stable, hashable representation that simplifies and improves compiled function caching.
* More examples of authoring peak-performance kernels
- [SM103 batched 3xFP4 blockscaled GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py)
@@ -85,6 +87,9 @@ To get started quickly - please refer :
- Fixed an indexing issue of scalar tensor
- Fixed small K reference check error for cta_tile_n = 256 case with overlapping accumulator optimization in [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py).
- Fixed a segfault issue with tvm-ffi on aarch64
- Fixed Hopper FMHA causal attention performance regression on CUDA toolkit 13.1 by
optimizing mbarrier synchronization to avoid unnecessary convergence barriers.
- Fixed a kernel loading race condition when multiple GPUs are present in the same process with JAX.
* API changes
- Deprecate get_num_tmem_alloc_cols from blackwell_helpers.py. Use the one from tmem_allocator.py instead.
@@ -127,6 +132,7 @@ To get started quickly - please refer :
* Add support for arbitrary application-provided strides for block-scale tensors.
- Users and applications now must pass valid block-scale strides in all cases, even when the tensor is packed.
* Support 4x blockscaled public ptx for CUDA 13.1.
* Enable Blackwell SM120f compilation of examples and expose NVFP4/MX Grouped GEMM in the CUTLASS Profiler.
* Allow non-static `TmaGbasis` in `AuxTmaParams`.
- Some cases in attention kernels may require a non-static `tma_gbasis`.
- Relax the restriction on the `TmaGbasis` parameter of `AuxTmaParams`; users are now allowed to manually construct a dynamic gbasis.

View File

@@ -1655,9 +1655,11 @@ def run(
.to(dtype=torch_dtype(c_dtype))
.to(dtype=torch.float32)
)
# Read back the result from CuTe tensor (c_storage was updated in-place)
torch.testing.assert_close(
c_storage.to(dtype=torch.float32), ref, atol=tolerance, rtol=1e-03
c_storage.view(torch_dtype(c_dtype)).to(dtype=torch.float32),
ref,
atol=tolerance,
rtol=1e-03,
)
if not benchmark:

View File

@@ -1725,9 +1725,11 @@ def run(
.to(dtype=torch_dtype(c_dtype))
.to(dtype=torch.float32)
)
# Read back the result from CuTe tensor (c_storage was updated in-place)
torch.testing.assert_close(
c_storage.to(dtype=torch.float32), ref, atol=tolerance, rtol=1e-03
c_storage.view(torch_dtype(c_dtype)).to(dtype=torch.float32),
ref,
atol=tolerance,
rtol=1e-03,
)
if not benchmark:

View File

@@ -2546,8 +2546,6 @@ def run(
slices = tuple(slice(s, e) for s, e in zip(padding, shape_))
torch_tensor = torch_tensor_full[slices].detach()
f32_torch_tensor = f32_torch_tensor_full[slices].detach()
torch_tensor._keep_alive = torch_tensor_full
f32_torch_tensor._keep_alive = f32_torch_tensor_full
# Create dtype cute tensor with offset (gpu)
cute_tensor = from_dlpack(torch_tensor, assumed_align=16)

View File

@@ -104,6 +104,67 @@ if __name__ == "__main__":
from helpers import fmha_helpers as fmha_utils
from cutlass.cutlass_dsl import (
Boolean, Int32, if_generate, while_generate, yield_out, not_, dsl_user_op,
)
from cutlass._mlir.dialects import nvvm
from cutlass._mlir._mlir_libs._cutlass_ir._mlir.ir import IntegerType
from contextlib import contextmanager
import inspect as _inspect
_timelimit_has_res = "res" in _inspect.signature(
nvvm.mbarrier_try_wait_parity_timelimit
).parameters
def _try_wait_timelimit(llvm_ptr, phase_val, timeout, *, loc=None, ip=None):
if _timelimit_has_res:
i1 = IntegerType.get_signless(1)
return nvvm.mbarrier_try_wait_parity_timelimit(
i1, llvm_ptr, phase_val, timeout, loc=loc, ip=ip,
)
return nvvm.mbarrier_try_wait_parity_timelimit(
llvm_ptr, phase_val, timeout, loc=loc, ip=ip,
)
@dsl_user_op
def _optimized_mbarrier_wait(mbar_ptr, phase, *, loc=None, ip=None):
llvm_ptr = mbar_ptr.llvm_ptr
phase_val = Int32(phase).ir_value(loc=loc, ip=ip)
_true = lambda: Boolean(True).ir_value(loc=loc, ip=ip)
timeout = Int32(10000000).ir_value(loc=loc, ip=ip)
d = Boolean(_try_wait_timelimit(llvm_ptr, phase_val, timeout, loc=loc, ip=ip))
d = if_generate(d, _true,
lambda: _try_wait_timelimit(llvm_ptr, phase_val, timeout, loc=loc, ip=ip),
None, [Boolean], loc=loc, ip=ip)
d = if_generate(d, _true,
lambda: _try_wait_timelimit(llvm_ptr, phase_val, timeout, loc=loc, ip=ip),
None, [Boolean], loc=loc, ip=ip)
def _fallback():
inner = Boolean(False).ir_value(loc=loc, ip=ip)
ctx = while_generate([inner], lambda x: not_(x, loc=loc, ip=ip), loc=loc, ip=ip)
with ctx as (_,):
r = Boolean(_try_wait_timelimit(
llvm_ptr, phase_val, timeout, loc=loc, ip=ip,
))
yield_out([r], loc=loc, ip=ip)
return Boolean(True).ir_value(loc=loc, ip=ip)
if_generate(d, _true, _fallback, None, [Boolean], loc=loc, ip=ip)
@contextmanager
def _use_optimized_mbarrier_wait():
import cutlass.cute.arch as arch_mod
orig_wait = arch_mod.mbarrier_wait
arch_mod.mbarrier_wait = _optimized_mbarrier_wait
try:
yield
finally:
arch_mod.mbarrier_wait = orig_wait
class HopperFusedMultiHeadAttentionForward:
def __init__(
@@ -439,36 +500,37 @@ class HopperFusedMultiHeadAttentionForward:
self.shared_storage = SharedStorage
# Launch the kernel synchronously
self.kernel(
qk_tiled_mma,
pv_tiled_mma,
tma_atom_q,
tma_tensor_q,
tma_atom_k,
tma_tensor_k,
tma_atom_v,
tma_tensor_v,
tma_atom_o,
tma_tensor_o,
lse,
scale_softmax_log2,
scale_softmax,
scale_output,
window_size_left,
window_size_right,
q_smem_layout_staged,
k_smem_layout_staged,
v_smem_layout_staged,
o_smem_layout_staged,
self.tile_sched_params,
).launch(
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=self.cluster_shape_mnk,
smem=self.shared_storage.size_in_bytes(),
stream=stream,
min_blocks_per_mp=1,
)
with _use_optimized_mbarrier_wait():
self.kernel(
qk_tiled_mma,
pv_tiled_mma,
tma_atom_q,
tma_tensor_q,
tma_atom_k,
tma_tensor_k,
tma_atom_v,
tma_tensor_v,
tma_atom_o,
tma_tensor_o,
lse,
scale_softmax_log2,
scale_softmax,
scale_output,
window_size_left,
window_size_right,
q_smem_layout_staged,
k_smem_layout_staged,
v_smem_layout_staged,
o_smem_layout_staged,
self.tile_sched_params,
).launch(
grid=grid,
block=[self.threads_per_cta, 1, 1],
cluster=self.cluster_shape_mnk,
smem=self.shared_storage.size_in_bytes(),
stream=stream,
min_blocks_per_mp=1,
)
# GPU device kernel
@cute.kernel

View File

@@ -97,7 +97,7 @@
" # Step 2: Map thread index to tensor coordinates\n",
" # -------------------------------------------\n",
" # Each thread will process one element of the input tensors\n",
" m, n = gA.shape # Get tensor dimensions (M rows × N columns)\n",
" m, n = gA.shape # Get tensor dimensions (M rows \u00d7 N columns)\n",
"\n",
" # Convert linear thread index to 2D coordinates:\n",
" # - ni: column index (0 to n-1)\n",
@@ -198,7 +198,7 @@
" num_threads_per_block = 256\n",
"\n",
" # Get input dimensions\n",
" m, n = mA.shape # Matrix dimensions (M rows × N columns)\n",
" m, n = mA.shape # Matrix dimensions (M rows \u00d7 N columns)\n",
"\n",
" # Create kernel instance\n",
" kernel = naive_elementwise_add_kernel(mA, mB, mC)\n",
@@ -298,7 +298,7 @@
" - For elementwise add:\n",
" * Read: 2 elements (A and B)\n",
" * Write: 1 element (C)\n",
" * Total bytes = (2 reads + 1 write) × elements × sizeof(dtype)\n",
" * Total bytes = (2 reads + 1 write) \u00d7 elements \u00d7 sizeof(dtype)\n",
"\n",
"Below is our benchmarking utility that measures these metrics:"
]
@@ -368,7 +368,7 @@
"\n",
"According to *Little's Law*, naive implementation has\n",
" - 1 element (4 bytes load + 2 bytes store) per thread\n",
" - 256 threads/block × N blocks\n",
" - 256 threads/block \u00d7 N blocks\n",
" - Limited in-flight operations\n",
"\n",
"In some GPUs, it's insufficient parallelism to saturate memory bandwidth.\n",
@@ -385,7 +385,35 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "## Vectorized Load and Store\n\nTo improve performance according to Little's Law, we need to increase the number\nof in-flight requests. We can do this by increasing the number of bytes handled\nin each load & store operation per thread through vectorized memory access.\n\nSince Ampere GPUs support up to 128-bit per load/store and each element is 16-bit,\nwe can load 8 elements per vectorized operation on contiguous rows.\nCuTe tiling operations make this vectorization straightforward.\n\nUsing ``tiled_tensor = cute.zipped_divide(tensor, tiler)``, we can partition the input\n``tensor`` into groups of ``tiler`` blocks. For vectorization, we specify ``tiler``\nas the block of data each thread accesses (8 contiguous elements in the same row, or ``(1,8)``).\nDifferent threads can then access different blocks by indexing into the 2nd mode of ``tiled_tensor``.\n\n```python\nmA : cute.Tensor # (2048,2048):(2048,1)\ngA = cute.zipped_divide(a, tiler=(1, 8)) # tiled/vectorized => ((1,8),(2048,256)):((0,1),(2048,8))\n```\n\n$\n \\begin{array}{ccccc}\n & ((1,8) & , & (2048,256)) & : ((0,1),(2048,8)) \\\\\n & \\underbrace{\\phantom{(1,8)}}_{tiler} & & \\underbrace{\\phantom{(2048,256)}}_{threads} & \\\\\n & \\text{\\scriptsize per-thread} & & \\text{\\scriptsize num of tiles}\n \\end{array}\n$"
"source": [
"## Vectorized Load and Store\n",
"\n",
"To improve performance according to Little's Law, we need to increase the number\n",
"of in-flight requests. We can do this by increasing the number of bytes handled\n",
"in each load & store operation per thread through vectorized memory access.\n",
"\n",
"Since Ampere GPUs support up to 128-bit per load/store and each element is 16-bit,\n",
"we can load 8 elements per vectorized operation on contiguous rows.\n",
"CuTe tiling operations make this vectorization straightforward.\n",
"\n",
"Using ``tiled_tensor = cute.zipped_divide(tensor, tiler)``, we can partition the input\n",
"``tensor`` into groups of ``tiler`` blocks. For vectorization, we specify ``tiler``\n",
"as the block of data each thread accesses (8 contiguous elements in the same row, or ``(1,8)``).\n",
"Different threads can then access different blocks by indexing into the 2nd mode of ``tiled_tensor``.\n",
"\n",
"```python\n",
"mA : cute.Tensor # (2048,2048):(2048,1)\n",
"gA = cute.zipped_divide(a, tiler=(1, 8)) # tiled/vectorized => ((1,8),(2048,256)):((0,1),(2048,8))\n",
"```\n",
"\n",
"$\n",
" \\begin{array}{ccccc}\n",
" & ((1,8) & , & (2048,256)) & : ((0,1),(2048,8)) \\\\\n",
" & \\underbrace{\\phantom{(1,8)}}_{tiler} & & \\underbrace{\\phantom{(2048,256)}}_{threads} & \\\\\n",
" & \\text{\\scriptsize per-thread} & & \\text{\\scriptsize num of tiles}\n",
" \\end{array}\n",
"$"
]
},
{
"cell_type": "code",
@@ -423,14 +451,91 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "This vectorized kernel follows a similar structure to its naive non-vectorized counterpart,\nwith one key difference: the tensor slicing pattern. By using `(None, (mi, ni))` as the slice indices,\nwe can extract a `(1,8)` sub-tensor from `gA`, `gB` and `gC` like \n\n$ gA[(None, (mi, ni))]: $\n\n$\n \\begin{array}{ccccc}\n Layout: & ( & (1,8) & , & (2048,256) & ) & : & ((0,1),(2048,8)) & \\xrightarrow{\\text{slice}} & ((1,8)):((0,1)) \\\\\n & & \\underbrace{\\phantom{(1,8)}} & & \\underbrace{\\phantom{(2048,256)}} & & \\\\\n Coord: & ( & None & , & (mi, ni) & ) & &\n \\end{array}\n$\n\nThen tensor data can be loaded into vector via the `gA[(None, (mi, ni))].load()` method. It is equivalent to\n\n```python\nv0 = gA[(0, (mi, ni))] # => mA[(mi, ni * 8 + 0)]\nv1 = gA[(1, (mi, ni))] # => mA[(mi, ni * 8 + 1)]\nv2 = gA[(2, (mi, ni))] # => mA[(mi, ni * 8 + 2)]\nv3 = gA[(3, (mi, ni))] # => mA[(mi, ni * 8 + 3)]\nv4 = gA[(4, (mi, ni))] # => mA[(mi, ni * 8 + 4)]\nv5 = gA[(5, (mi, ni))] # => mA[(mi, ni * 8 + 5)]\nv6 = gA[(6, (mi, ni))] # => mA[(mi, ni * 8 + 6)]\nv7 = gA[(7, (mi, ni))] # => mA[(mi, ni * 8 + 7)]\n```\n\n### Assumed Alignment\n\nIn order to guide compile to use vectorized load/store, we must tell compiler to assume alignment of incoming pointer. \nIt's on users side to guarantee actual pointer at runtime meet the alignment restriction.\n\n```python\na_ = from_dlpack(a, assumed_align=16)\nb_ = from_dlpack(b, assumed_align=16)\nc_ = from_dlpack(c, assumed_align=16)\n\n# Compile kernel with alignment assumption\ncompiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n```\n\nIt's worth to note that partitioned or tiled tensor could have different alignment of its base pointer because of offset\nduring sub-slice."
"source": [
"This vectorized kernel follows a similar structure to its naive non-vectorized counterpart,\n",
"with one key difference: the tensor slicing pattern. By using `(None, (mi, ni))` as the slice indices,\n",
"we can extract a `(1,8)` sub-tensor from `gA`, `gB` and `gC` like \n",
"\n",
"$ gA[(None, (mi, ni))]: $\n",
"\n",
"$\n",
" \\begin{array}{ccccc}\n",
" Layout: & ( & (1,8) & , & (2048,256) & ) & : & ((0,1),(2048,8)) & \\xrightarrow{\\text{slice}} & ((1,8)):((0,1)) \\\\\n",
" & & \\underbrace{\\phantom{(1,8)}} & & \\underbrace{\\phantom{(2048,256)}} & & \\\\\n",
" Coord: & ( & None & , & (mi, ni) & ) & &\n",
" \\end{array}\n",
"$\n",
"\n",
"The tensor data can then be loaded into a vector via the `gA[(None, (mi, ni))].load()` method. It is equivalent to\n",
"\n",
"```python\n",
"v0 = gA[(0, (mi, ni))] # => mA[(mi, ni * 8 + 0)]\n",
"v1 = gA[(1, (mi, ni))] # => mA[(mi, ni * 8 + 1)]\n",
"v2 = gA[(2, (mi, ni))] # => mA[(mi, ni * 8 + 2)]\n",
"v3 = gA[(3, (mi, ni))] # => mA[(mi, ni * 8 + 3)]\n",
"v4 = gA[(4, (mi, ni))] # => mA[(mi, ni * 8 + 4)]\n",
"v5 = gA[(5, (mi, ni))] # => mA[(mi, ni * 8 + 5)]\n",
"v6 = gA[(6, (mi, ni))] # => mA[(mi, ni * 8 + 6)]\n",
"v7 = gA[(7, (mi, ni))] # => mA[(mi, ni * 8 + 7)]\n",
"```\n",
"\n",
"### Assumed Alignment\n",
"\n",
"To guide the compiler to emit vectorized loads/stores, we must tell it to assume an alignment for the incoming pointer.\n",
"It is the user's responsibility to guarantee that the actual pointer at runtime meets this alignment restriction.\n",
"\n",
"```python\n",
"a_ = from_dlpack(a, assumed_align=16)\n",
"b_ = from_dlpack(b, assumed_align=16)\n",
"c_ = from_dlpack(c, assumed_align=16)\n",
"\n",
"# Compile kernel with alignment assumption\n",
"compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n",
"```\n",
"\n",
"It is worth noting that a partitioned or tiled tensor can have a different alignment of its base pointer because of the\n",
"offset introduced during sub-slicing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "@cute.jit\ndef vectorized_elementwise_add(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):\n threads_per_block = 256\n\n gA = cute.zipped_divide(mA, (1, 8))\n gB = cute.zipped_divide(mB, (1, 8))\n gC = cute.zipped_divide(mC, (1, 8))\n\n print(\"[DSL INFO] Tiled Tensors:\")\n print(f\"[DSL INFO] gA = {gA}\")\n print(f\"[DSL INFO] gB = {gB}\")\n print(f\"[DSL INFO] gC = {gC}\")\n\n vectorized_elementwise_add_kernel(gA, gB, gC).launch(\n grid=(cute.size(gC, mode=[1]) // threads_per_block, 1, 1),\n block=(threads_per_block, 1, 1),\n )\n\n\na = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\nb = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\nc = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n\na_ = from_dlpack(a, assumed_align=16)\nb_ = from_dlpack(b, assumed_align=16)\nc_ = from_dlpack(c, assumed_align=16)\n\ncompiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\ncompiled_func(a_, b_, c_)\n\n# verify correctness\ntorch.testing.assert_close(c, a + b)"
"source": [
"@cute.jit\n",
"def vectorized_elementwise_add(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):\n",
" threads_per_block = 256\n",
"\n",
" gA = cute.zipped_divide(mA, (1, 8))\n",
" gB = cute.zipped_divide(mB, (1, 8))\n",
" gC = cute.zipped_divide(mC, (1, 8))\n",
"\n",
" print(\"[DSL INFO] Tiled Tensors:\")\n",
" print(f\"[DSL INFO] gA = {gA}\")\n",
" print(f\"[DSL INFO] gB = {gB}\")\n",
" print(f\"[DSL INFO] gC = {gC}\")\n",
"\n",
" vectorized_elementwise_add_kernel(gA, gB, gC).launch(\n",
" grid=(cute.size(gC, mode=[1]) // threads_per_block, 1, 1),\n",
" block=(threads_per_block, 1, 1),\n",
" )\n",
"\n",
"\n",
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"a_ = from_dlpack(a, assumed_align=16)\n",
"b_ = from_dlpack(b, assumed_align=16)\n",
"c_ = from_dlpack(c, assumed_align=16)\n",
"\n",
"compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n",
"compiled_func(a_, b_, c_)\n",
"\n",
"# verify correctness\n",
"torch.testing.assert_close(c, a + b)"
]
},
{
"cell_type": "code",
@@ -444,7 +549,68 @@
{
"cell_type": "markdown",
"metadata": {},
"source": "## TV Layout\n\nBoth the naive and vectorized kernels follow a common pattern to map thread indices\nto physical addresses in two steps:\n\nStep 1: Map thread index to logical coordinates in `(M, N)`\n\n* `mi = thread_idx // n`\n* `ni = thread_idx % n`\n\nIn native version, each thread process 1 element, in this case, `mi` and `ni` is logical\ncoordinate into data tensor `mA`, `mB` and `mC`.\n\nInt vectorized version, each thread process multiple values of input and output tensor.\nlogical coordinate should be computed with both thread and value index.\n\n* `thread_idx // n`\n* `(thread_idx % n) * 8 + value_idx`\n\n\nStep 2: Map logical coordinates in `(M, N)` to physical addresses using the tensor layout\n\n* Vectorized Load\n\n```python\n frgA = gA[(None, (mi, ni))].load()\n```\n\n* Elementwise Load (less efficient)\n\n```python\n frgA0 = mA[(mi, ni * 8 + 0)]\n frgA1 = mA[(mi, ni * 8 + 1)]\n frgA2 = mA[(mi, ni * 8 + 2)]\n frgA3 = mA[(mi, ni * 8 + 3)]\n frgA4 = mA[(mi, ni * 8 + 4)]\n frgA5 = mA[(mi, ni * 8 + 5)]\n frgA6 = mA[(mi, ni * 8 + 6)]\n frgA7 = mA[(mi, ni * 8 + 7)]\n\n # Or use divided layout\n\n frgA0 = gA[(0, (mi, ni))]\n frgA1 = gA[(1, (mi, ni))]\n frgA2 = gA[(2, (mi, ni))]\n frgA3 = gA[(3, (mi, ni))]\n frgA4 = gA[(4, (mi, ni))]\n frgA5 = gA[(5, (mi, ni))]\n frgA6 = gA[(6, (mi, ni))]\n frgA7 = gA[(7, (mi, ni))]\n```\n\nCuTe introduces TV layout to represent this mapping from thread index and value index\n(i.e., the 8 elements loaded per thread) to the logical coordinate space of a tensor.\nBy configuring different TV layouts, we can experiment with different memory access\npatterns with minimal code changes.\n\n**Definition:** *TV Layout* is rank-2 layout which maps `(thread_index, value_index)` \nto logical coordinate of tensor. \n\nWe always have *TV Layout* with canonical form as `(thread_domain, value_domain):(..., ...)`.\n\nWith *TV Layout*, each thread can find logical coordinates or indices of data partitioned\nto current thread."
"source": [
"## TV Layout\n",
"\n",
"Both the naive and vectorized kernels follow a common pattern to map thread indices\n",
"to physical addresses in two steps:\n",
"\n",
"Step 1: Map thread index to logical coordinates in `(M, N)`\n",
"\n",
"* `mi = thread_idx // n`\n",
"* `ni = thread_idx % n`\n",
"\n",
"In the naive version, each thread processes one element; here `mi` and `ni` are the logical\n",
"coordinates into the data tensors `mA`, `mB` and `mC`.\n",
"\n",
"In the vectorized version, each thread processes multiple values of the input and output tensors.\n",
"The logical coordinate must be computed from both the thread index and the value index.\n",
"\n",
"* `thread_idx // n`\n",
"* `(thread_idx % n) * 8 + value_idx`\n",
"\n",
"\n",
"Step 2: Map logical coordinates in `(M, N)` to physical addresses using the tensor layout\n",
"\n",
"* Vectorized Load\n",
"\n",
"```python\n",
" frgA = gA[(None, (mi, ni))].load()\n",
"```\n",
"\n",
"* Elementwise Load (less efficient)\n",
"\n",
"```python\n",
" frgA0 = mA[(mi, ni * 8 + 0)]\n",
" frgA1 = mA[(mi, ni * 8 + 1)]\n",
" frgA2 = mA[(mi, ni * 8 + 2)]\n",
" frgA3 = mA[(mi, ni * 8 + 3)]\n",
" frgA4 = mA[(mi, ni * 8 + 4)]\n",
" frgA5 = mA[(mi, ni * 8 + 5)]\n",
" frgA6 = mA[(mi, ni * 8 + 6)]\n",
" frgA7 = mA[(mi, ni * 8 + 7)]\n",
"\n",
" # Or use divided layout\n",
"\n",
"    frgA0 = gA[(0, (mi, ni))]\n",
"    frgA1 = gA[(1, (mi, ni))]\n",
"    frgA2 = gA[(2, (mi, ni))]\n",
"    frgA3 = gA[(3, (mi, ni))]\n",
"    frgA4 = gA[(4, (mi, ni))]\n",
"    frgA5 = gA[(5, (mi, ni))]\n",
"    frgA6 = gA[(6, (mi, ni))]\n",
"    frgA7 = gA[(7, (mi, ni))]\n",
"```\n",
"\n",
"CuTe introduces TV layout to represent this mapping from thread index and value index\n",
"(i.e., the 8 elements loaded per thread) to the logical coordinate space of a tensor.\n",
"By configuring different TV layouts, we can experiment with different memory access\n",
"patterns with minimal code changes.\n",
"\n",
"**Definition:** A *TV Layout* is a rank-2 layout that maps `(thread_index, value_index)`\n",
"to a logical coordinate of the tensor.\n",
"\n",
"A *TV Layout* is always written in the canonical form `(thread_domain, value_domain):(..., ...)`.\n",
"\n",
"With a *TV Layout*, each thread can find the logical coordinates or indices of the data partitioned\n",
"to the current thread.\n"
]
},
{
"cell_type": "markdown",
@@ -1057,4 +1223,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View File

@@ -279,13 +279,13 @@ struct CollectiveBuilder<
// Basic storage block for new Scaling Factor Layouts
using mnBasicBlockShape = Shape<_32,_4>;
using mnBasicBlockStride = Stride<_16,_4>;
using kBasicBlockShape = Shape<Int<SFVectorSize>, Int<MMA_NSF>>;
using kBasicBlockShape = Shape<Int<(int)SFVectorSize>, Int<MMA_NSF>>;
using kBasicBlockStride = Stride<_0, _1>;
using sSFA_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{}));
using sSF_strideMN = decltype(prepend( Blk_Elems{}, mnBasicBlockStride{}));
using sSFA_strideM = sSF_strideMN;
using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int<MMA_NSF>{}, size<2>(TileShape_MNK{}) / Int<SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{}));
using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int<MMA_NSF>{}, size<2>(TileShape_MNK{}) / Int<(int)SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{}));
using sSFA_strideK = decltype(prepend(make_stride( Int<MMA_NSF>{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{}));
using sSFA_shape = decltype(make_shape( sSFA_shapeM{}, sSF_shapeK{}));

View File

@@ -36,7 +36,7 @@
#define CUTLASS_MAJOR 4
#define CUTLASS_MINOR 4
#define CUTLASS_PATCH 1
#define CUTLASS_PATCH 2
#ifdef CUTLASS_VERSIONS_GENERATED
#include "cutlass/version_extended.h"

View File

@@ -175,22 +175,21 @@ stride is set to 1 unless inconsistent with the layout of the DLPack tensor. For
The default value for ``leading_dim`` is ``None``. In such case, the system
automatically deduces it from the tensor's layout using the following logic:
1. If a dimension's stride is 1, that dimension is marked as the leading dimension.
2. If multiple dimensions satisfy condition 1, an error is thrown indicating deduction failure.
1. If exactly one dimension has stride 1, that dimension is the leading dimension.
2. If multiple dimensions have stride 1, deduction succeeds only when exactly one of them
has size > 1 (that dimension is used). If none or more than one has size > 1, an error is raised.
Note that after converting a **PyTorch** tensor to the DLPack format, the stride for dimensions
with size 1 are canonicalized to 1. This canonicalization can increase the likelihood of
deduction failures. This behavior is specific to PyTorch and does not occur with NumPy for
example.
3. If no dimension satisfies condition 1, all strides are marked as dynamic.
with size 1 are canonicalized to 1, which can produce multiple stride-1 dimensions.
3. If no dimension has stride 1, all strides remain dynamic.
For example:
- For a tensor with layout ``(2,2,3,4):(2,1,4,12)``, the leading dimension is 1.
The layout will be marked as ``(?,?,?,?):(?,1,?,?)``.
- For a tensor with layout ``(1,5,1):(1,1,1)``, if ``leading_dim`` is not specified,
a deduction failure error is raised.
- For a tensor with layout ``(2,2):(8,2)``, since no dimension has stride 1,
all dimensions are marked as dynamic: ``(?,?):(?,?)``.
- For a tensor with layout ``(1,5,1):(1,1,1)``, multiple dimensions have stride 1 but exactly one
has size > 1 (dim 1). The leading dimension is deduced to be 1: ``(?,?,?):(?,1,?)``.
- For a tensor with layout ``(2,2):(8,2)``, no dimension has stride 1, so all strides remain
dynamic: ``(?,?):(?,?)``.
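The deduction rules above can be sketched in plain Python (an editor's paraphrase of the documented rules, not the library's actual code; `deduce_leading_dim` is a hypothetical name):

```python
def deduce_leading_dim(shape, strides):
    """Sketch of the leading-dimension deduction rules (hypothetical helper)."""
    stride1 = [i for i, s in enumerate(strides) if s == 1]
    if not stride1:
        return None  # rule 3: no stride-1 dimension, all strides stay dynamic
    if len(stride1) == 1:
        return stride1[0]  # rule 1: exactly one stride-1 dimension
    # rule 2: multiple stride-1 dimensions; exactly one must have size > 1
    wide = [i for i in stride1 if shape[i] > 1]
    if len(wide) == 1:
        return wide[0]
    raise ValueError(
        "Can't deduce the leading dimension from layout, "
        "please specify the leading_dim explicitly."
    )
```

This reproduces the worked examples: `(2,2,3,4):(2,1,4,12)` deduces dim 1, `(1,5,1):(1,1,1)` deduces dim 1, and `(2,2):(8,2)` leaves all strides dynamic.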
The leading dimension accepts a negative index, which means the dimension is counted from the last dimension. For example,
@@ -206,8 +205,9 @@ The following example demonstrates how to use ``mark_layout_dynamic`` to specify
* ``t1`` & ``t2`` shows the usage of ``mark_layout_dynamic`` with specified ``leading_dim``.
* ``t3`` shows the usage of ``mark_layout_dynamic`` with no leading dimension.
* ``t4`` shows the usage of ``mark_layout_dynamic`` with broadcasted dimensions.
* ``t5`` demonstrates the deduction failure when more than one dimension has stride equal to 1.
* ``t6`` & ``t7`` demonstrates incorrect settings for ``leading_dim`` and expected errors.
* ``t5`` shows automatic deduction for tensor ``b`` (multiple stride-1, exactly one has size > 1 → dim 1).
* ``t5_fail`` demonstrates the deduction failure when multiple dimensions have stride 1 but none has size > 1.
* ``t6`` & ``t7`` demonstrate incorrect settings for ``leading_dim`` and expected errors.
.. code-block:: python
@@ -245,8 +245,14 @@ The following example demonstrates how to use ``mark_layout_dynamic`` to specify
print(t4)
# (?,?,?,?):(?,0,0,1)
# b has layout (1,4,1,32,1):(1,1,1,4,1); dim 1 has size > 1, so deduction succeeds to dim 1.
t5 = from_dlpack(b).mark_layout_dynamic()
# Can't deduce the leading dimension from layout, please specify the leading_dim explicitly.
print(t5)
# (?,?,?,?,?):(?{i64},1,?{i64},?{i64},?{i64})
# Rejected: multiple stride-1, none with size > 1 (e.g. torch.ones(1,1,1)).
t5_fail = from_dlpack(torch.ones(1, 1, 1)).mark_layout_dynamic()
# Can't deduce the leading dimension from layout (multiple dimensions have stride 1 but none has size > 1)...
t6 = from_dlpack(a).mark_layout_dynamic(leading_dim=1)
# Expected strides[leading_dim] == 1, but got 16

View File

@@ -68,7 +68,8 @@ class ExternalBinaryModule:
load_provider: LoadProvider = None
def __init__(self, file_path: str):
def __init__(self, file_path: str, enable_tvm_ffi: bool = False):
self.enable_tvm_ffi = enable_tvm_ffi
assert self.load_provider is not None, (
"Load provider is not set for ExternalBinaryModule."
)
@@ -82,13 +83,28 @@ class ExternalBinaryModule:
object_file_content = f.read()
except Exception as e:
raise DSLRuntimeError(f"Failed to read object file {file_path}: {e}")
useJitLink = not enable_tvm_ffi
# Lifetime of the engine is same as the ExternalBinaryModule.
self.engine = self.load_provider.execution_engine_constructor(
object_file_content, shared_libs
object_file_content, shared_libs, useJitLink
)
def __getattr__(self, function_prefix: str) -> "JitCompiledFunction":
"""Get the jit_function from the `function_prefix`. The `function_prefix` is specified when users dump the object file. When there is no function_prefix found in the module, the function will raise an error."""
if self.enable_tvm_ffi:
try:
import tvm_ffi
function_ptr = self.engine.lookup("__tvm_ffi_" + function_prefix)
return tvm_ffi.Function.__from_extern_c__(
function_ptr, keep_alive_object=self.engine
)
except Exception as e:
raise DSLRuntimeError(
f"Failed to load TVM FFI function {function_prefix}: {e}"
)
try:
args_spec, function_name, kernel_info, version_str = (
decode_metadata_from_execution_engine(
@@ -124,3 +140,7 @@ class ExternalBinaryModule:
load_from_binary=True,
)
return jit_function
def __getitem__(self, function_prefix: str) -> "JitCompiledFunction":
"""Get the jit_function from the `function_prefix`. The `function_prefix` is specified when users dump the object file. When there is no function_prefix found in the module, the function will raise an error."""
return self.__getattr__(function_prefix)
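As a hedged illustration of the `__getitem__` alias added above (the class below is a hypothetical stand-in, not `ExternalBinaryModule` itself): item access simply delegates to attribute lookup, which helps when a `function_prefix` is not a valid Python identifier.

```python
# Hypothetical stand-in showing the attribute/item delegation pattern.
class BinaryModuleLike:
    def __init__(self, functions):
        self._functions = functions

    def __getattr__(self, function_prefix):
        # Called only for names not found as real attributes.
        try:
            return self._functions[function_prefix]
        except KeyError:
            raise AttributeError(function_prefix)

    def __getitem__(self, function_prefix):
        # Same lookup path as attribute access, but works for any string,
        # e.g. prefixes containing "-" or "." that attribute syntax rejects.
        return self.__getattr__(function_prefix)
```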

View File

@@ -202,6 +202,8 @@ from .math import *
# Used as internal symbol
from .. import cutlass_dsl as _dsl
from .ffi import ffi
# Aliases
jit = _dsl.CuTeDSL.jit
kernel = _dsl.CuTeDSL.kernel
@@ -312,4 +314,5 @@ __all__ = [
"kernel",
"register_jit_arg_adapter",
"compile",
"ffi",
]

View File

@@ -96,7 +96,6 @@ __all__ = [
"fma_packed_f32x2",
"mul_packed_f32x2",
"add_packed_f32x2",
"sub_packed_f32x2",
"fmax",
"rcp_approx",
"exp2",

View File

@@ -23,7 +23,6 @@ from ..typing import Int32, Pointer, Int128
def issue_clc_query(
mbar_ptr: Pointer,
clc_response_ptr: Pointer,
multicast: bool = True,
loc=None,
ip=None,
) -> None:
@@ -40,20 +39,12 @@ def issue_clc_query(
"""
mbar_llvm_ptr = mbar_ptr.llvm_ptr
clc_response_llvm_ptr = clc_response_ptr.llvm_ptr
if multicast:
nvvm.clusterlaunchcontrol_try_cancel_multicast(
clc_response_llvm_ptr,
mbar_llvm_ptr,
loc=loc,
ip=ip,
)
else:
nvvm.clusterlaunchcontrol_try_cancel(
clc_response_llvm_ptr,
mbar_llvm_ptr,
loc=loc,
ip=ip,
)
nvvm.clusterlaunchcontrol_try_cancel_multicast(
clc_response_llvm_ptr,
mbar_llvm_ptr,
loc=loc,
ip=ip,
)
@dsl_user_op

View File

@@ -604,7 +604,6 @@ def fence_proxy(
],
*,
space: Optional[Literal["cta", "cluster"]] = None,
use_intrinsic=None,
loc=None,
ip=None,
) -> None:
@@ -623,7 +622,6 @@ def fence_proxy(
- "cta" : CTA (Cooperative Thread Array) scope
- "cluster" : Cluster scope
:type space: Optional[Literal["cta", "cluster"]]
:param use_intrinsic: Whether to use intrinsic version
"""
from cutlass._mlir.dialects.nvvm import (
SharedSpace,
@@ -640,7 +638,6 @@ def fence_proxy(
nvvm.fence_proxy(
kind=kind,
space=space,
use_intrinsic=use_intrinsic,
loc=loc,
ip=ip,
)
@@ -940,9 +937,6 @@ mul_packed_f32x2 = partial(
add_packed_f32x2 = partial(
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2
)
sub_packed_f32x2 = partial(
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.sub_packed_f32x2
)
@dsl_user_op
@@ -959,20 +953,6 @@ def fmax(
)
@dsl_user_op
def fmin(
a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None
) -> Float32:
return Float32(
nvvm.fmin(
Float32(a).ir_value(loc=loc, ip=ip),
Float32(b).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
)
@dsl_user_op
def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None):
return Float32(

View File

@@ -1587,7 +1587,7 @@ def pretty_str(arg) -> str:
@dsl_user_op
def printf(*args, loc=None, ip=None, end="\n") -> None:
def printf(*args, loc=None, ip=None) -> None:
"""
Print one or more values with optional formatting.
@@ -1607,8 +1607,6 @@ def printf(*args, loc=None, ip=None, end="\n") -> None:
:type loc: Optional[Location]
:param ip: Insertion point for code generation, defaults to None
:type ip: Optional[InsertionPoint]
:param end: Suffix for the printed value, defaults to newline
:type end: Optional[str]
:raises ValueError: If no arguments are provided
:raises TypeError: If an unsupported argument type is passed
@@ -1638,10 +1636,10 @@ def printf(*args, loc=None, ip=None, end="\n") -> None:
raise ValueError("expects at least one argument to print")
if isinstance(args[0], str):
fmt = args[0] + end
fmt = args[0] + "\n"
args = args[1:]
else:
fmt = "{}" + ", {}" * (len(args) - 1) + end
fmt = "{}" + ", {}" * (len(args) - 1) + "\n"
def process_arg(arg):
arg0 = arg.value if isinstance(arg, Numeric) else arg
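
After this change the terminator is always `"\n"`. The format-string construction can be sketched as a standalone helper (an illustrative mock, not the DSL's `printf`): an explicit leading string becomes the format, otherwise one `{}` placeholder is emitted per argument.

```python
def build_fmt(args):
    # mirrors the logic above with the now-fixed "\n" terminator
    if not args:
        raise ValueError("expects at least one argument to print")
    if isinstance(args[0], str):
        return args[0] + "\n", args[1:]
    return "{}" + ", {}" * (len(args) - 1) + "\n", tuple(args)

fmt, rest = build_fmt(("x = {}", 42))
assert fmt == "x = {}\n" and rest == (42,)
fmt, rest = build_fmt((1, 2, 3))
assert fmt == "{}, {}, {}\n" and rest == (1, 2, 3)
```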
@@ -3762,6 +3760,35 @@ def zipped_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor:
@dsl_user_op
def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None):
"""
``zipped_divide`` is ``logical_divide`` with Tiler modes and Rest modes gathered together: ``(Tiler,Rest)``
- When Tiler is Layout, this has no effect as ``logical_divide`` results in the same.
- When Tiler is ``Tile`` (nested tuple of ``Layout``) or ``Shape``, this zips modes into standard form
``((BLK_A,BLK_B),(a,b,x,y))``
For example, if ``target`` has shape ``(s, t, r)`` and ``tiler`` has shape ``(BLK_A, BLK_B)``,
then the result will have shape ``((BLK_A, BLK_B), (ceil_div(s, BLK_A), ceil_div(t, BLK_B), r))``.
:param target: The layout or tensor to partition.
:type target: Layout or Tensor
:param tiler: The tiling specification (can be a Layout, Shape, Tile).
:type tiler: Tiler
:param loc: Optional MLIR IR location information.
:type loc: optional
:param ip: Optional MLIR IR insertion point.
:type ip: optional
:return: A zipped (partitioned) version of the target.
:rtype: Layout or Tensor
**Example:**
.. code-block:: python
layout = cute.make_layout((128, 64), stride=(64, 1))
tiler = (8, 8)
result = cute.zipped_divide(layout, tiler) # result shape: ((8, 8), (16, 8))
"""
if isinstance(tiler, tuple):
tiler = _pack_tile(tiler, loc=loc, ip=ip) # type: ignore
return _op_wrapper(
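
The shape arithmetic described in the docstring can be checked with a small pure-Python sketch (shapes only, no layouts): the tiler modes are gathered first, and each tiled dimension of the rest mode is the ceiling division of the original extent by the tile extent.

```python
def ceil_div(a, b):
    return -(-a // b)

def zipped_divide_shape(shape, tiler):
    # ((BLK_A, BLK_B, ...), (ceil_div(s, BLK_A), ceil_div(t, BLK_B), rest...))
    tiled = tuple(ceil_div(s, t) for s, t in zip(shape, tiler))
    rest = tuple(shape[len(tiler):])
    return (tuple(tiler), tiled + rest)

# matches the docstring examples
assert zipped_divide_shape((128, 64), (8, 8)) == ((8, 8), (16, 8))
assert zipped_divide_shape((16, 24), (2, 4)) == ((2, 4), (8, 6))
```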
@@ -3904,6 +3931,73 @@ def local_tile(
loc=None,
ip=None,
) -> Tensor:
"""
Partition a tensor into tiles using a tiler and extract a single tile at the provided coordinate.
The ``local_tile`` operation applies a ``zipped_divide`` to split the ``input`` tensor by the ``tiler``
and then slices out a single tile using the provided `coord`. This is commonly used for extracting block-,
thread-, or CTA-level tiles for parallel operations.
.. math::
\\text{local_tile}(input, tiler, coord) = \\text{zipped_divide}(input, tiler)[coord]
This function corresponds to the CUTE/C++ `local_tile` utility:
https://docs.nvidia.com/cutlass/media/docs/cpp/cute/03_tensor.html#local-tile
:param input: The input tensor to partition into tiles.
:type input: Tensor
:param tiler: The tiling specification (can be a Layout, Shape, Tile).
:type tiler: Tiler
:param coord: The coordinate to select within the remainder ("rest") modes after tiling.
This selects which tile to extract.
:type coord: Coord
:param proj: (Optional) Projection onto tiling modes; specify to project out unused tiler modes,
e.g., when working with projections of tilers in multi-mode partitioning.
Default is None for no projection.
:type proj: XTuple, optional
:param loc: (Optional) MLIR location, for diagnostic/debugging.
:type loc: Any, optional
:param ip: (Optional) MLIR insertion point, used in IR building context.
:type ip: Any, optional
:return: A new tensor representing the local tile selected at the given coordinate.
:rtype: Tensor
**Examples**
1. Tiling a 2D tensor and extracting a tile:
.. code-block:: python
# input: (16, 24)
tensor : cute.Tensor
tiler = (2, 4)
coord = (1, 1)
# output: (8, 6)
# - zipped_divide(tensor, tiler) -> ((2, 4), (8, 6))
# - local_tile(tensor, tiler, coord) -> (8, 6)
result = cute.local_tile(tensor, tiler=tiler, coord=coord)
2. Using a stride projection for specialized tiling:
.. code-block:: python
# input: (16, 24)
tensor : cute.Tensor
tiler = (2, 2, 4)
coord = (0, 1, 1)
proj = (1, None, 1)
# output: (8, 6)
# projected_tiler: (2, 4)
# projected_coord: (0, 1)
# - zipped_divide(tensor, projected_tiler) -> ((2, 4), (8, 6))
# - local_tile(tensor, projected_tiler, projected_coord) -> (8, 6)
result = cute.local_tile(tensor, tiler=tiler, coord=coord, proj=proj)
"""
tiler_val = _pack_tile(tiler, loc=loc, ip=ip)
coord_val = _pack_coord(coord, loc=loc, ip=ip)
if proj is not None:
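
The `proj` behavior from the second docstring example can be sketched in plain Python (a hypothetical helper, not the DSL implementation): modes whose `proj` entry is `None` are dropped from both the tiler and the coordinate, leaving aligned projected tuples.

```python
def project(tiler, coord, proj):
    # keep only the modes whose proj entry is not None,
    # projecting tiler and coord consistently
    kept = [i for i, p in enumerate(proj) if p is not None]
    return tuple(tiler[i] for i in kept), tuple(coord[i] for i in kept)

# matches the docstring example: tiler (2, 2, 4), coord (0, 1, 1), proj (1, None, 1)
assert project((2, 2, 4), (0, 1, 1), (1, None, 1)) == ((2, 4), (0, 1))
```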

View File

@@ -0,0 +1,206 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from cutlass._mlir import ir
from cutlass._mlir.dialects import func
from cutlass.base_dsl.typing import get_mlir_types, NumericMeta, Numeric, as_numeric
from cutlass.base_dsl.dsl import extract_mlir_values
from cutlass import DSLRuntimeError
class ffi:
"""
Foreign Function Interface (FFI) wrapper for external function invocation in the CUTLASS Python DSL.
This class enables calling external MLIR function prototypes from Python code, handling type conversion,
prototype registration, and dynamic insertion of function symbols into MLIR modules as needed.
Parameters
----------
name : str
Name of the external function. This will be used as the symbol name when calling or registering a prototype in the MLIR module.
params_types : list, optional
List of argument types for the external function. These can be CUTLASS numeric types, numeric meta types, or types convertible via `get_mlir_types`.
return_type : optional
The return type of the external function. If not specified, the function is assumed to have no return value.
Methods
-------
__call__(*args)
Calls the external function with the given arguments, ensuring argument and result types match the prototype.
"""
def __init__(self, *, name: str, params_types: list = [], return_type=None):
self.name = name
self.params_types = params_types
self.return_type = [return_type] if return_type else []
def _get_prototype_region(self, current_op):
"""
Helper method to determine the appropriate MLIR module and region for inserting a function prototype.
This method recursively traverses the current operation's parent hierarchy to find the correct module
and region where the function prototype should be inserted. It supports both builtin.module and gpu.module.
:param current_op: The current operation to check.
:type current_op: Operation
:returns:
A tuple containing the module operation and the insertion region.
:rtype: tuple
"""
if current_op is None:
raise DSLRuntimeError("current operation is unknown")
op_name = current_op.name
if op_name in ["builtin.module", "gpu.module"]:
return current_op, current_op.regions[0].blocks[0]
else:
return self._get_prototype_region(current_op.parent)
@staticmethod
def _to_mlir_types(args):
"""
Helper method to convert a list of arguments to their corresponding MLIR types.
This method converts CUTLASS numeric types, numeric meta types, and types convertible via `get_mlir_types`
to their corresponding MLIR types.
:param args: The list of arguments to convert to MLIR types.
:type args: list
:returns:
A list of MLIR types.
:rtype: list
"""
types = []
for param in args:
if isinstance(param, NumericMeta):
types.append(param.mlir_type)
elif isinstance(param, Numeric):
types.append(param.mlir_type)
else:
types.extend(get_mlir_types(param))
return types
@staticmethod
def _type_check(callee, exec_types, returns_types):
"""
Helper method to check if the function prototype types match the expected types.
This method compares the input and output types of the function prototype with the provided expected types.
:param callee: The function prototype operation to check.
:type callee: func.FuncOp
:param exec_types: The expected input types.
:type exec_types: list
:param returns_types: The expected output types.
:type returns_types: list
"""
if callee.type.inputs != exec_types or callee.type.results != returns_types:
raise DSLRuntimeError(
f"External prototype types mismatch, trying to call with ({exec_types}) -> ({returns_types}), got {callee.type}"
)
def _create_prototype_in_region(self, op, region, exec_args):
"""
Helper method to create or retrieve a function prototype in the current module.
This method checks if a function prototype with the given name already exists in the symbol table of the current module.
If it does, it checks if the prototype's types match the expected types. If it does not, it raises an error.
If it does not exist, it creates a new function prototype and inserts it into the current region.
:param op: The module operation to check.
:type op: Operation
:param region: The region to insert the function prototype into.
:type region: Region
:param exec_args: The arguments to pass to the function prototype.
:type exec_args: list
"""
symbol_table = ir.SymbolTable(op.operation)
if self.name in symbol_table:
callee = symbol_table[self.name]
else:
with ir.InsertionPoint(region):
callee = func.FuncOp(
self.name,
(
ffi._to_mlir_types(self.params_types),
ffi._to_mlir_types(self.return_type),
),
)
callee.sym_visibility = ir.StringAttr.get("private")
# Sanity check the function prototype types match the expected types
self._type_check(
callee,
ffi._to_mlir_types(exec_args),
ffi._to_mlir_types(self.return_type),
)
return callee
def __call__(self, *args, **kwargs):
"""
Calls the FFI function prototype with the provided arguments.
This method ensures that an IR-level function prototype (external declaration)
with the given name and type signature exists in the current module. If it does not
exist, it will be created and inserted into the module. A call operation to this
function is then emitted using the arguments supplied by the caller.
:param args:
The runtime arguments to pass to the FFI function. These will be converted to
their corresponding numeric types and lowered to MLIR values before being used as arguments.
:type args: tuple
:returns:
The MLIR call operation created for this invocation.
:rtype: func.CallOp
:raises DSLRuntimeError:
If there is no active MLIR insertion point or if the current operation
context cannot be determined.
"""
if kwargs:
raise DSLRuntimeError(
"Keyword arguments are not supported for FFI calls",
suggestion="Use positional arguments only",
)
# Get the current insertion point and operation
try:
current_ip = ir.InsertionPoint.current
except Exception:
raise DSLRuntimeError(
"Failed to determine current insertion point",
suggestion="Make sure this is called under a jit context",
)
current_op = current_ip.block.owner
module_op, insertion_region = self._get_prototype_region(current_op)
# Extract the arguments to MLIR values
exec_args = []
for arg in args:
exec_arg = extract_mlir_values(arg)
if not exec_arg:
exec_arg = [as_numeric(arg).ir_value()]
exec_args.extend(exec_arg)
        # Create the function prototype in the enclosing module: under a kernel
        # function the prototype is inserted into gpu.module; otherwise it is
        # inserted into builtin.module
callee = self._create_prototype_in_region(
module_op, insertion_region, exec_args
)
# Emit the call operation
result = func.call(callee.type.results, self.name, exec_args)
if self.return_type:
return result
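
The declare-once, verify-always pattern that `__call__` applies to the module's symbol table can be sketched without MLIR (a plain-Python mock, not the real implementation): the first call inserts the prototype, and every later call with the same name must match its signature.

```python
class ModuleSketch:
    def __init__(self):
        self.symbols = {}  # name -> (param types, return type)

class FFISketch:
    def __init__(self, name, params_types, return_type=None):
        self.name = name
        self.signature = (tuple(params_types), return_type)

    def declare(self, module):
        existing = module.symbols.get(self.name)
        if existing is None:
            # first use: insert the external prototype into the module
            module.symbols[self.name] = self.signature
        elif existing != self.signature:
            # sanity check mirroring _type_check above
            raise RuntimeError(f"prototype mismatch for {self.name}")
        return self.name

mod = ModuleSketch()
f = FFISketch("my_func", [int, int], int)
f.declare(mod)
f.declare(mod)  # second call reuses the cached prototype
assert mod.symbols == {"my_func": ((int, int), int)}
```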

View File

@@ -333,7 +333,7 @@ class MmaF16BF16Trait(MmaTraits):
@dataclass(frozen=True)
class MmaF8Op(MmaOp):
"""
FP8 warpgroup MMA Operation.
F8 warpgroup MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-mma>`__.
This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands.

View File

@@ -17,6 +17,7 @@ import itertools
import operator
from typing import Union, Optional, Type, List
# MLIR modules imports
from cutlass._mlir import ir
from cutlass.base_dsl.env_manager import get_prefix_dsl_libs
@@ -128,6 +129,10 @@ class _Pointer(Pointer):
def __repr__(self):
return self.__str__()
@property
def __cache_key__(self) -> tuple:
return (self.dtype, self._addr_space, self._assumed_align)
class _Tensor(Tensor):
def __init__(
@@ -144,7 +149,7 @@ class _Tensor(Tensor):
elif enable_tvm_ffi:
import tvm_ffi
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor, stream=-1)
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor)
self._dlpack_data = self._tvm_ffi_tensor.__dlpack__()
else:
try:
@@ -185,9 +190,17 @@ class _Tensor(Tensor):
:param leading_dim: The leading dimension of the layout, defaults to None
:type leading_dim: int, optional
When ``leading_dim`` is None, automatically deduces the leading dimension from the tensor layout.
The layout can be deduced only when exactly one dimension has a stride of 1. Raises an error
if the layout cannot be automatically deduced.
When ``leading_dim`` is None, the leading dimension is deduced as follows.
(1) If exactly one dimension has stride 1, that dimension is used.
(2) If multiple dimensions have stride 1 but exactly one of them has size > 1,
that dimension is used.
(3) If multiple dimensions have stride 1 but none or more than one has size > 1,
an error is raised.
(4) If no dimension has stride 1, all strides remain dynamic.
When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the
stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent
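
The deduction rules (1)-(4) above can be expressed as a short standalone function (an illustrative sketch under the stated rules, not the DSL's internal code):

```python
def deduce_leading_dim(shape, stride):
    unit = [i for i, s in enumerate(stride) if s == 1]
    if len(unit) == 1:
        return unit[0]                                   # rule (1)
    if len(unit) > 1:
        wide = [i for i in unit if shape[i] > 1]
        if len(wide) == 1:
            return wide[0]                               # rule (2)
        raise ValueError("ambiguous leading dimension")  # rule (3)
    return None                          # rule (4): all strides stay dynamic

assert deduce_leading_dim((64, 128), (128, 1)) == 1      # one stride-1 dim
assert deduce_leading_dim((4, 1, 8), (1, 1, 4)) == 0     # one stride-1 dim with size > 1
assert deduce_leading_dim((8,), (2,)) is None            # no stride-1 dim
```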
@@ -304,6 +317,13 @@ class _Tensor(Tensor):
def __repr__(self):
return self.__str__()
@property
def __cache_key__(self) -> tuple:
self.load_dltensor()
if self._dtype is None:
self._dtype = self._dltensor_wrapper.dtype
return (self._dtype, self._assumed_align, self._dltensor_wrapper.cache_key())
def __setitem__(self, crd, value):
raise TypeError("runtime._Tensor is not indexable")
@@ -417,42 +437,76 @@ def _get_cute_type_str(inp):
return "(" + ",".join(elems) + ")"
class _FakeCompactTensor(Tensor):
class _FakeTensor(Tensor):
"""Fake Tensor implementation as a placeholder.
It mimics the interface of Tensor, but does not hold real data or allow indexing.
Used for compilation or testing situations where only shape/type/layout information is needed.
All attempts to access or mutate data will raise errors.
"""
"""
Create a fake tensor with the given shape, type, and layout.
:param dtype: Data type of the tensor elements
:type dtype: Type[Numeric]
:param shape: Shape of the tensor, consists of int (static) or SymInt (dynamic)
:type shape: tuple[Union[int, SymInt], ...]
:param stride: Stride of the tensor, defaults to None, consists of int (static) or SymInt (dynamic)
:type stride: tuple[Union[int, SymInt], ...], optional
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
:type memspace: AddressSpace, optional
    :param assumed_align: Assumed alignment of the tensor (bytes), defaults to None. If None, uses the element size in bytes as the assumed alignment.
:type assumed_align: int, optional
:param use_32bit_stride: Whether to use 32-bit stride. Defaults to False. When True, the dynamic stride bitwidth
will be set to 32 for small problem sizes (cosize(layout) <= Int32_max) for better performance. This is only applied
when the dimension is dynamic.
:type use_32bit_stride: bool, optional
"""
def __init__(
self,
dtype,
shape,
stride_order,
memspace=None,
assumed_align=None,
use_32bit_stride=False,
dtype: Type[Numeric],
shape: tuple[Union[int, SymInt], ...],
*,
stride: tuple[Union[int, SymInt], ...],
memspace: AddressSpace = AddressSpace.gmem,
assumed_align: int | None = None,
use_32bit_stride: bool = False,
compact: bool = False,
):
self._dtype = dtype
self._shape = shape
self._stride_order = stride_order or tuple(range(len(shape)))
# cannot use memspace or AddressSpace.gmem because AddressSpace.generic is 0
self._memspace = memspace if memspace is not None else AddressSpace.gmem
self._assumed_align = assumed_align or -(-dtype.width // 8)
self._stride = stride
self._use_32bit_stride = use_32bit_stride
self._compact = compact
def __str__(self) -> str:
return f"FakeTensorOrdered<{self._dtype}, {self._shape}, {self._stride_order}>"
if not isinstance(shape, (tuple, list)):
raise ValueError(f"Expected tuple or list but got {type(shape)}")
def __repr__(self):
return self.__str__()
if not all(isinstance(s, (int, SymInt)) for s in self._shape):
raise ValueError("All shape elements must be int or SymInt")
if stride is not None and not all(
isinstance(s, (int, SymInt)) for s in self._stride
):
raise ValueError("All stride elements must be int or SymInt")
self._memspace = memspace
self._assumed_align = assumed_align
if assumed_align is None:
            # use the byte width of the element dtype; the alignment is at least one byte
self._assumed_align = (self._dtype.width + 7) // 8
@property
def mlir_type(self) -> ir.Type:
shape_ty = ir.Type.parse(
'!cute.shape<"' + _get_cute_type_str(self._shape) + '">'
)
layout_ty = _cute_ir.LayoutType.get_ordered(
shape_ty, self._stride_order, self._use_32bit_stride
)
self._stride = layout_ty.stride
ptr_ty = _cute_ir.PtrType.get(
self._dtype.mlir_type, self._memspace, self._assumed_align
)
shape_str = _get_cute_type_str(self._shape)
stride_str = _get_cute_type_str(self._stride)
layout_ty = ir.Type.parse(f'!cute.layout<"{shape_str}:{stride_str}">')
# Boolean types are stored as i8 in memory
elem_type = T.i8() if self._dtype.width == 1 else self._dtype.mlir_type
ptr_ty = _cute_ir.PtrType.get(elem_type, self._memspace, self._assumed_align)
return _cute_ir.MemRefType.get(ptr_ty, layout_ty)
def __get_mlir_types__(self):
@@ -463,11 +517,53 @@ class _FakeCompactTensor(Tensor):
assert isinstance(values[0], CoreTensor)
return CoreTensor(values[0].value, self._dtype)
def __str__(self) -> str:
return f"FakeTensor<{self._dtype}, {self._shape}, {self._stride}>"
@property
def __cache_key__(self) -> tuple:
# Check if any shape or stride element is a SymInt without a symbol
import warnings
has_unnamed_symint = False
for dim in self._shape:
if isinstance(dim, SymInt) and dim.symbol is None:
has_unnamed_symint = True
break
if not self._compact:
if not has_unnamed_symint:
for stride in self._stride:
if isinstance(stride, SymInt) and stride.symbol is None:
has_unnamed_symint = True
break
if has_unnamed_symint:
warnings.warn(
"FakeTensor cache_key contains unnamed symbolic dimensions. "
"Different variables with the same shape/stride pattern will have "
"identical cache keys, which may cause incorrect cache hits. "
"Consider using 'symbol' parameter to distinguish variables: "
"cute.sym_int32(symbol='M'), cute.sym_int32(symbol='N')",
UserWarning,
stacklevel=2,
)
return (
self._dtype,
self._memspace,
self._assumed_align,
self._shape,
self._stride,
)
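
The collision the warning above guards against can be demonstrated with a simplified stand-in for `SymInt` (a sketch, not the real class): unnamed symbolic dimensions are indistinguishable in the cache key, while named symbols keep distinct variables apart.

```python
class SymIntSketch:
    def __init__(self, symbol=None):
        self.symbol = symbol

    def key(self):
        # unnamed symbols all reduce to the same key component
        return ("sym", self.symbol)

def cache_key(shape):
    return tuple(d.key() if isinstance(d, SymIntSketch) else d for d in shape)

M, N = SymIntSketch(symbol="M"), SymIntSketch(symbol="N")
a, b = SymIntSketch(), SymIntSketch()  # two distinct unnamed variables

assert cache_key((M, 8)) != cache_key((N, 8))  # named symbols stay distinct
assert cache_key((a, 8)) == cache_key((b, 8))  # unnamed symbols collide
```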
def __repr__(self):
return self.__str__()
def __setitem__(self, crd, value):
raise DSLRuntimeError("runtime._FakeCompactTensor is not indexable")
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
def __getitem__(self, crd):
raise DSLRuntimeError("runtime._FakeCompactTensor is not indexable")
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
@property
def element_type(self) -> Type[Numeric]:
@@ -491,118 +587,7 @@ class _FakeCompactTensor(Tensor):
@property
def leading_dim(self):
for dim, order in enumerate(self._stride_order):
if order == 0:
return dim
@property
def dynamic_shapes_mask(self):
return tuple(1 if isinstance(e, SymInt) else 0 for e in self._shape)
@property
def dynamic_strides_mask(self):
return tuple(1 if isinstance(e, SymInt) else 0 for e in self._stride)
def fill(self, value: Numeric):
raise DSLRuntimeError("runtime._FakeCompactTensor is not writable")
class _FakeTensor(Tensor):
"""Fake Tensor implementation as a placeholder.
It mimics the interface of Tensor, but does not hold real data or allow indexing.
Used for compilation or testing situations where only shape/type/layout information is needed.
All attempts to access or mutate data will raise errors.
"""
"""
Create a fake tensor with the given shape, type, and layout.
:param dtype: Data type of the tensor elements
:type dtype: Type[Numeric]
:param shape: Shape of the tensor, consists of int (static) or SymInt (dynamic)
:type shape: tuple[int, ...]
:param stride: Stride of the tensor, defaults to None, consists of int (static) or SymInt (dynamic)
:type stride: tuple[int, ...], optional
:param assumed_align: Assumed alignment of the tensor (bytes), defaults to None. If None, uses the element size bytes as the assumed alignment.
:type assumed_align: int, optional
:param use_32bit_stride: Whether to use 32-bit stride. Defaults to False. When True, the dynamic stride bitwidth
will be set to 32 for small problem sizes (cosize(layout) <= Int32_max) for better performance. This is only applied
when the dimension is dynamic.
:type use_32bit_stride: bool, optional
"""
def __init__(self, dtype, shape, *, stride, memspace=None, assumed_align=None):
self._dtype = dtype
self._shape = shape
self._stride = stride
# cannot use memspace or AddressSpace.generic because AddressSpace.generic is 0
self._memspace = memspace if memspace is not None else AddressSpace.gmem
self._assumed_align = assumed_align
if assumed_align is None:
# use the bytes width of the element dtype. The alignment is at least one byte align.
self._assumed_align = (self._dtype.width + 7) // 8
if not isinstance(shape, (tuple, list)):
raise ValueError(f"Expected tuple or list but got {type(shape)}")
if not all(isinstance(s, (int, SymInt)) for s in self._shape):
raise ValueError("All shape elements must be int or SymInt")
if stride is not None and not all(
isinstance(s, (int, SymInt)) for s in self._stride
):
raise ValueError("All stride elements must be int or SymInt")
@property
def mlir_type(self) -> ir.Type:
shape_str = _get_cute_type_str(self._shape)
stride_str = _get_cute_type_str(self._stride)
layout_ty = ir.Type.parse(f'!cute.layout<"{shape_str}:{stride_str}">')
ptr_ty = _cute_ir.PtrType.get(
self._dtype.mlir_type, self._memspace, self._assumed_align
)
return _cute_ir.MemRefType.get(ptr_ty, layout_ty)
def __get_mlir_types__(self):
return [self.mlir_type]
def __new_from_mlir_values__(self, values):
assert len(values) == 1
assert isinstance(values[0], CoreTensor)
return CoreTensor(values[0].value, self._dtype)
def __str__(self) -> str:
return f"FakeTensor<{self._dtype}, {self._shape}, {self._stride}>"
def __repr__(self):
return self.__str__()
def __setitem__(self, crd, value):
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
def __getitem__(self, crd):
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
@property
def element_type(self) -> Type[Numeric]:
return self._dtype
@property
def memspace(self):
return self._memspace
@property
def iterator(self):
raise DSLRuntimeError("runtime._FakeTensor has dummy iterator")
@property
def shape(self):
return self._shape
@property
def stride(self):
return self._stride
return core.leading_dim(self._shape, self._stride)
@property
def dynamic_shapes_mask(self):
@@ -617,35 +602,36 @@ class _FakeTensor(Tensor):
def make_fake_compact_tensor(
dtype,
shape,
dtype: Type[Numeric],
shape: tuple[Union[int, SymInt], ...],
*,
stride_order=None,
memspace=None,
assumed_align=None,
use_32bit_stride=False,
stride_order: Optional[tuple[int, ...]] = None,
memspace: AddressSpace = AddressSpace.gmem,
assumed_align: Optional[int] = None,
use_32bit_stride: bool = False,
):
"""
Create a fake tensor with the specified shape, element type, and a compact memory layout.
:param dtype: Data type of the tensor elements.
:type dtype: Type[Numeric]
:param shape: Shape of the tensor.
:type shape: tuple[int, ...]
:param shape: Shape of the tensor, consisting of static (int) or dynamic (SymInt) dimensions.
:type shape: tuple[Union[int, SymInt], ...]
:param stride_order: Order in which strides (memory layout) are assigned to the tensor dimensions.
        If None, the default layout is left-to-right order (known as column-major order for a flattened layout).
Otherwise, it should be a permutation order of the dimension indices.
The mode with stride_order 0 is the fastest changing (leading) dimension, and N-1 is the slowest changing.
:type stride_order: tuple[int, ...], optional
:param memspace: Memory space where the fake tensor resides. Optional.
:type memspace: str, optional
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is used.
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
:type memspace: AddressSpace, optional
    :param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is the dtype width in bytes, and at least 1 byte.
:type assumed_align: int, optional
:param use_32bit_stride: Whether to use 32-bit stride for dynamic dimensions. If True and the total size of the
layout (cosize(layout)) fits within int32, then dynamic strides will use 32-bit integers for improved performance.
Only applies when dimensions are dynamic. Defaults to False.
:type use_32bit_stride: bool, optional
:return: An instance of a fake tensor with the given properties and compact layout.
:rtype: _FakeCompactTensor
:rtype: _FakeTensor
**Examples:**
@@ -663,31 +649,68 @@ def make_fake_compact_tensor(
# tensor<ptr<f32, generic> o (100,?{div=8}):(?{i32 div=8},1)>
compiled_foo = cute.compile(foo, x)
# Default stride order is left-to-right order: (1, 8)
y = make_fake_compact_tensor(cutlass.Float32, (8, 3))
# Default stride order is left-to-right order (0, 1, ..., n-1)
y = make_fake_compact_tensor(cutlass.Float32, (8, 3, 2)) # y.stride == (1, 8, 24)
"""
return _FakeCompactTensor(
if stride_order is not None:
if len(stride_order) != len(shape):
raise ValueError(
f"stride_order ({stride_order}) must be empty or have same length as shape ({shape})."
)
else:
# Default stride order is left-to-right
stride_order = stride_order or tuple(range(len(shape)))
# Make compact strides (possibly symbolic) from shape & stride_order
stride = [None] * len(stride_order)
stride_product = 1
for order in range(len(stride_order)):
idx = stride_order.index(order)
stride[idx] = stride_product
stride_product *= shape[idx]
stride_width = 32 if use_32bit_stride else 64
stride = tuple(
(
SymInt(width=stride_width, divisibility=s.divisibility)
if isinstance(s, SymInt)
else s
)
for s in stride
)
return _FakeTensor(
dtype,
shape,
stride_order=stride_order,
stride=stride,
memspace=memspace,
assumed_align=assumed_align,
use_32bit_stride=use_32bit_stride,
compact=True,
)
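
The compact-stride construction above can be isolated into a small pure-Python helper (integer shapes only; the symbolic `SymInt` handling is omitted): walk the orders from fastest to slowest, assigning the running product of sizes as each dimension's stride.

```python
def compact_strides(shape, stride_order):
    # mirrors the loop above: order 0 is the fastest-changing (leading) mode
    stride = [None] * len(stride_order)
    product = 1
    for order in range(len(stride_order)):
        idx = stride_order.index(order)
        stride[idx] = product
        product *= shape[idx]
    return tuple(stride)

# matches the docstring example: default left-to-right order
assert compact_strides((8, 3, 2), (0, 1, 2)) == (1, 8, 24)
# row-major order for a 2D shape
assert compact_strides((8, 3), (1, 0)) == (3, 1)
```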
def make_fake_tensor(dtype, shape, stride, *, memspace=None, assumed_align=None):
def make_fake_tensor(
dtype: Type[Numeric],
shape: tuple[Union[int, SymInt], ...],
stride: tuple[Union[int, SymInt], ...],
*,
memspace: AddressSpace = AddressSpace.gmem,
assumed_align: Optional[int] = None,
):
"""
Create a fake tensor with the specified element type, shape, and stride.
:param dtype: Data type of the tensor elements.
:type dtype: Type[Numeric]
:param shape: Shape of the tensor.
:type shape: tuple[int, ...]
:param stride: Stride of the tensor.
:type stride: tuple[int, ...]
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is used. Defaults to None.
:param shape: Shape of the tensor, consisting of static (int) or dynamic (SymInt) dimensions.
:type shape: tuple[Union[int, SymInt], ...]
:param stride: Stride of the tensor, consisting of static (int) or dynamic (SymInt) values.
:type stride: tuple[Union[int, SymInt], ...]
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
:type memspace: AddressSpace, optional
    :param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is the dtype width in bytes, and at least 1 byte.
:type assumed_align: int, optional
:return: An instance of a fake tensor with the given properties.
:rtype: _FakeTensor
@@ -953,22 +976,7 @@ def load_module(file_path: str, *, enable_tvm_ffi: bool = False):
if Path(path).exists():
_LOAD_MODULE_LIBS_CACHE.append(ctypes.CDLL(path))
if enable_tvm_ffi:
import tvm_ffi
try:
            # keep_module_alive=False means the module will be unloaded
            # once the returned module goes out of scope, which is useful
            # for frequent loading and unloading of modules. The only requirement
            # is that the module does not return objects whose deleter lives in
            # the module while the returned object outlives the module itself.
            # DSL functions do not have such issues, so it is desirable to set this to False.
return tvm_ffi.load_module(file_path, keep_module_alive=False)
except TypeError:
# compatible with tvm-ffi < 0.1.6
return tvm_ffi.load_module(file_path)
else:
return ExternalBinaryModule(file_path)
return ExternalBinaryModule(file_path, enable_tvm_ffi=enable_tvm_ffi)
# -------------------------------------------------------------------------
# Try to register_jit_arg_adapter for TensorAdapter

View File

@@ -139,6 +139,19 @@ class _Tensor(Tensor):
def __init__(
self, value, dtype: Optional[Type[Numeric]] = None, *, loc=None, ip=None
):
"""Initialize a Tensor from an MLIR value.
:param value: The MLIR operation result value or another Tensor to initialize from
:type value: Union[ir.Value, _Tensor]
:param dtype: The user specified data type of the tensor elements, defaults to None
:type dtype: Optional[Type[Numeric]]
:param loc: The source location for the operation, defaults to None
:type loc: Optional[Location]
:param ip: The insertion point for the operation, defaults to None
:type ip: Optional[InsertionPoint]
:raises TypeError: If value is not ir.Value or _Tensor
:raises TypeError: If iterator type is not supported
"""
self._dtype = dtype
if isinstance(value, ir.Value):
self.value = value
@@ -952,6 +965,37 @@ def make_fragment_like(src, dtype=None, *, loc=None, ip=None):
def recast_tensor(
src: Tensor, dtype: Type[Numeric], swizzle_=None, *, loc=None, ip=None
):
"""Recast a tensor to a different data type by changing the element interpretation.
This function reinterprets the memory of a tensor with a different element type,
adjusting both the iterator pointer type and the layout to maintain consistency.
:param src: The source tensor to recast
:type src: Tensor
:param dtype: The target data type for tensor elements
:type dtype: Type[Numeric]
:param swizzle_: Optional swizzle parameter (reserved for future use), defaults to None
:type swizzle_: Optional, unused
:param loc: Source location for MLIR operation tracking, defaults to None
:type loc: Optional[Location]
:param ip: Insertion point for MLIR operation, defaults to None
:type ip: Optional[InsertionPoint]
:return: A new tensor with the same memory but reinterpreted as dtype
:rtype: Tensor
:raises TypeError: If dtype is not a subclass of Numeric
**Examples:**
.. code-block:: python
# Create a Float32 tensor
tensor_f32 = make_rmem_tensor((4, 8), Float32)
# Recast to Int32 to manipulate bits
tensor_i32 = recast_tensor(tensor_f32, Int32)
# Both tensors share the same memory, but interpret it differently
"""
if not isclass(dtype) or not issubclass(dtype, Numeric):
raise TypeError(f"dtype must be a type of Numeric, but got {dtype}")
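The per-element effect described in the `recast_tensor` docstring, reinterpreting the same memory under a different element type, can be sketched outside of CuTe with plain Python; the helper name below is illustrative, not part of the CuTe API:

```python
import struct

def bitcast_f32_to_i32(x: float) -> int:
    # Pack as a little-endian float32, then unpack the same 4 bytes as an int32;
    # no numeric conversion happens, only the interpretation of the bits changes.
    return struct.unpack("<i", struct.pack("<f", x))[0]

# IEEE-754 binary32: 1.0f has the bit pattern 0x3F800000
assert bitcast_f32_to_i32(1.0) == 0x3F800000
```

`recast_tensor` applies this idea tensor-wide: the pointer type and layout change together so the bytes are shared but read as the new dtype.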
@@ -972,6 +1016,36 @@ def recast_tensor(
@dsl_user_op
def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor:
"""Offset the tensor domain by the given coordinate.
This function creates a new tensor by offsetting the iterator/pointer of the input tensor
by the amount corresponding to the given coordinate in its layout.
:param coord: The coordinate offset to apply
:type coord: Coord
:param tensor: The source tensor to offset
:type tensor: Tensor
:param loc: Source location for MLIR operation tracking, defaults to None
:type loc: Optional[Location]
:param ip: Insertion point for MLIR operation, defaults to None
:type ip: Optional[InsertionPoint]
:return: A new tensor with the offset iterator
:rtype: Tensor
:raises ValueError: If the tensor type doesn't support domain offsetting
**Examples:**
.. code-block:: python
# Create a tensor with a row-major layout
ptr = make_ptr(Float32, base_ptr, AddressSpace.gmem)
layout = make_layout((64, 128), stride=(128, 1))
tensor = make_tensor(ptr, layout)
# Offset by coordinate (3, 5)
offset_tensor = domain_offset((3, 5), tensor)
# offset_tensor now points to element at (3, 5)
"""
offset = crd2idx(coord, tensor.layout, loc=loc, ip=ip)
if isinstance(tensor.iterator, Pointer):
return make_tensor(


@@ -9,6 +9,8 @@
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from __future__ import annotations
from abc import ABC, abstractmethod
import ctypes
from typing import ForwardRef, Tuple, Union, Any, Type, List, Optional, Literal
@@ -24,15 +26,18 @@ Int = Union[int, Integer]
class SymInt:
def __init__(self, width: Literal[32, 64] = 32, *, divisibility=1):
def __init__(
self, width: Literal[32, 64] = 32, *, divisibility=1, symbol: str | None = None
):
if width not in [32, 64]:
raise ValueError(f"Unsupported width: {width}")
self._width = width
self._divisibility = divisibility
self._symbol = symbol
def __hash__(self):
return hash((self._width, self._divisibility))
return hash((self._width, self._divisibility, self._symbol))
@property
def width(self):
@@ -42,8 +47,16 @@ class SymInt:
def divisibility(self):
return self._divisibility
@property
def symbol(self):
return self._symbol
def __str__(self) -> str:
return f"?{{i{self._width} div={self._divisibility}}}"
prefix = "" if self._symbol is None else self._symbol + " "
if self._width == 32:
return f"{prefix}?{{div={self._divisibility}}}"
else:
return f"{prefix}?{{i{self._width} div={self._divisibility}}}"
def __repr__(self) -> str:
return self.__str__()
@@ -51,18 +64,52 @@ class SymInt:
def __eq__(self, other) -> bool:
if not isinstance(other, SymInt):
return False
return all(
[self._width == other._width, self._divisibility == other._divisibility]
[
self._width == other._width,
self._divisibility == other._divisibility,
self._symbol == other._symbol,
]
)
def __mod__(self, other: int) -> Union["SymInt", int]:
if self._divisibility % other != 0:
def __mod__(self, other: int | SymInt) -> SymInt | int:
if isinstance(other, int):
other_div, result_width = other, self._width
elif isinstance(other, SymInt):
other_div, result_width = (
other._divisibility,
max(self._width, other._width),
)
else:
return NotImplemented
if self._divisibility % other_div == 0:
return 0
else:
from math import gcd
div = gcd(self._divisibility, other)
return SymInt(self._width, divisibility=div)
return SymInt(result_width, divisibility=gcd(self._divisibility, other_div))
def __rmod__(self, other: int) -> int:
"""int % SymInt: check if the int conforms to this SymInt's divisibility"""
if isinstance(other, int):
return other % self._divisibility
return NotImplemented
def __mul__(self, other: int | SymInt) -> SymInt:
if isinstance(other, int):
return SymInt(self._width, divisibility=self._divisibility * other)
elif isinstance(other, SymInt):
return SymInt(
width=max(self._width, other._width),
divisibility=self._divisibility * other._divisibility,
)
else:
return 0
return NotImplemented
def __rmul__(self, other: int | SymInt) -> SymInt:
return self.__mul__(other)
def __c_pointers__(self):
return [ctypes.c_void_p(0).value]
@@ -73,7 +120,7 @@ class SymInt:
)
return [res_ty]
def __new_from_mlir_values__(self, values) -> "SymInt":
def __new_from_mlir_values__(self, values) -> SymInt:
from .core import IntValue
if self.width == 32:
@@ -84,16 +131,18 @@ class SymInt:
assert False, f"Unsupported width: {self.width}"
return self
def sym_int(width: Literal[32, 64] = 32, *, divisibility=1) -> SymInt:
return SymInt(width, divisibility=divisibility)
def sym_int(
width: Literal[32, 64] = 32, *, divisibility=1, symbol: str | None = None
) -> SymInt:
return SymInt(width, divisibility=divisibility, symbol=symbol)
def sym_int32(divisibility=1) -> SymInt:
return sym_int(32, divisibility=divisibility)
def sym_int32(divisibility=1, symbol: str | None = None) -> SymInt:
return sym_int(32, divisibility=divisibility, symbol=symbol)
def sym_int64(divisibility=1) -> SymInt:
return sym_int(64, divisibility=divisibility)
def sym_int64(divisibility=1, symbol: str | None = None) -> SymInt:
return sym_int(64, divisibility=divisibility, symbol=symbol)
ScaledBasis = ForwardRef("ScaledBasis")
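The divisibility algebra in the `SymInt.__mod__`/`__mul__` hunk above can be exercised standalone. A minimal sketch, assuming the same semantics as the diff (gcd under modulo, product under multiply); the `Div` class is illustrative only:

```python
from math import gcd

class Div:
    """Minimal stand-in for SymInt's divisibility bookkeeping (illustrative only)."""
    def __init__(self, divisibility: int):
        self.div = divisibility

    def __mod__(self, other: int):
        # A multiple of `div` is also a multiple of any divisor of `div`,
        # so the remainder is statically 0; otherwise only gcd divisibility
        # is guaranteed for the result.
        if self.div % other == 0:
            return 0
        return Div(gcd(self.div, other))

    def __mul__(self, other: int):
        # Multiplying a multiple of `div` by k yields a multiple of div * k
        return Div(self.div * other)

x = Div(8)                # an unknown value known to be divisible by 8
assert x % 4 == 0         # statically zero: 4 divides 8
assert (x % 6).div == 2   # only gcd(8, 6) == 2 divisibility survives
assert (x * 3).div == 24
```

This is why `__mod__` in the diff returns the int `0` when divisibility is statically provable, and a fresh `SymInt` carrying the gcd otherwise.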


@@ -1199,7 +1199,6 @@ class KernelLauncher:
return self.dsl._get_smem_usage()
def launch(self, *args, **kwargs):
self.dsl.frame = inspect.currentframe().f_back
self.dsl._preprocess_launch_config_args(args, kwargs)
config = self.dsl.LaunchConfig(*args, **kwargs)
kernel_attrs = _build_kernel_attrs(config)
@@ -1216,7 +1215,6 @@ class KernelLauncher:
ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config)
self.dsl.kernel_info[name] = kernel_attrs
self.dsl.frame = None
return ret.launch_op_ret
def __call__(self, *args, **kwargs):


@@ -35,7 +35,6 @@ if is_available():
)
from .compile import (
release_compile_cache,
initialize_cutlass_dsl,
)
from .ffi import (
get_export_disabled_safety_checks,
@@ -48,10 +47,6 @@ if is_available():
# This is a legacy name for TensorSpec. It will be removed eventually.
TensorMode = TensorSpec
# This explicit init method ensures that we avoid initialization at
# unexpected times in jax tracing.
initialize_cutlass_dsl()
__all__ = [
"cutlass_call",
"jax_to_cutlass_dtype",


@@ -267,36 +267,4 @@ def release_compile_cache():
_CUTLASS_COMPILE_CACHE.clear()
dsl = CuTeDSL._get_dsl()
dsl.jit_cache.clear()
# TODO: This is needed to release frames being held in the DSL
# We should avoid holding such references as they unexpectedly
# extend object lifetime.
dsl.frame = None
gc.collect()
class _DummyInitKernel:
@cute.kernel
def kernel(self):
pass
@cute.jit
def init(self):
pass
_CUTLASS_DSL_INITIALIZED = False
def initialize_cutlass_dsl():
"""Initializes cutlass DSL."""
global _CUTLASS_DSL_INITIALIZED
if _CUTLASS_DSL_INITIALIZED:
return
# Call compiler to ensure we've pre-processed any kernels inside cutedsl.
kernel = _DummyInitKernel()
with _compile_lock:
logger.debug("Initializing cutlass dsl...")
_ = cutlass.cute.compile(kernel.init)
_CUTLASS_DSL_INITIALIZED = True


@@ -28,9 +28,13 @@ logger = logging.getLogger(__name__)
_CUTE_DSL_RUNTIME_LIBRARY_NAME = "cute_dsl_runtime"
_CUTLASS_CALL_TARGETS = {
"CuteDSLRT_NvJaxCutlassCall": {"execute": "CuteDSLRT_NvJaxCutlassCallExecute"},
"CuteDSLRT_NvJaxCutlassCall": {
"execute": "CuteDSLRT_NvJaxCutlassCallExecute",
"prepare": "CuteDSLRT_NvJaxCutlassCallPrepare",
},
"CuteDSLRT_NvJaxCutlassCallNoCudaGraph": {
"execute": "CuteDSLRT_NvJaxCutlassCallExecuteNoCudaGraph"
"execute": "CuteDSLRT_NvJaxCutlassCallExecuteNoCudaGraph",
"prepare": "CuteDSLRT_NvJaxCutlassCallPrepare",
},
}


@@ -1,3 +1,3 @@
# Use `pip install -r requirements-cu13.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl[cu13]==4.4.1
nvidia-cutlass-dsl[cu13]==4.4.2


@@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.4.1
nvidia-cutlass-dsl==4.4.2


@@ -133,7 +133,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '4.4.1'
this.__version__ = '4.4.2'
from cutlass_cppgen.backend import create_memory_pool
from cutlass_cppgen.emit.pytorch import pytorch


@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
setup(
name='cutlass_cppgen',
version='4.4.1',
version='4.4.2',
description='CUTLASS Pythonic Interface',
package_dir={'': '.'},
packages=[


@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='4.4.1',
version='4.4.2',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)


@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='4.4.1',
version='4.4.2',
description='Python implementation of CuTe',
packages=['pycute'],
)