mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
v4.4.2 update. (#3104)
This commit is contained in:
18
CHANGELOG.md
18
CHANGELOG.md
@@ -2,6 +2,20 @@
|
||||
|
||||
# CUTLASS 4.x
|
||||
|
||||
## [4.4.2](https://github.com/NVIDIA/cutlass/releases/tag/v4.4.2) (2026-03-13)
|
||||
|
||||
### CuTe DSL
|
||||
* New features
|
||||
- CuTe DSL now supports Python 3.14 for both x86_64 and aarch64
|
||||
- Runtime Pointer/Tensor/FakeTensor now supports `__cache_key__`, providing a stable, hashable representation that simplifies and improves compiled function caching.
|
||||
* Bug fixing and improvements
|
||||
- Fixed Hopper FMHA causal attention performance regression on CUDA toolkit 13.1 by
|
||||
optimizing mbarrier synchronization to avoid unnecessary convergence barriers.
|
||||
- Fix kernel loading race condition when multiple GPUs are present in the same process in JAX.
|
||||
|
||||
### CUTLASS C++
|
||||
* Enable Blackwell SM120f compilation of examples and exposes NVFP4/MX Grouped GEMM in the CUTLASS Profiler.
|
||||
|
||||
## [4.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v4.4.1) (2026-02-27)
|
||||
|
||||
### CuTe DSL
|
||||
@@ -148,8 +162,6 @@
|
||||
* Work around a driver TMA descriptor related bug which will cause occasional errors on Blackwell when the tensor's backing memory allocation is less than 128KB and it is not a dense non-overlapping tensor.
|
||||
|
||||
## [4.3.3](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.3) (2025-12-12)
|
||||
|
||||
### CuTe DSL
|
||||
* New features
|
||||
- Supported namedtuple and kwargs for JIT function arguments in tvm-ffi
|
||||
- Supported variadic tuples for JIT function argument in tvm-ffi
|
||||
@@ -159,8 +171,6 @@
|
||||
- Clearer error message for the case of runtime error cudaErrorInsufficientDriver
|
||||
|
||||
## [4.3.2](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.2) (2025-12-05)
|
||||
|
||||
### CuTe DSL
|
||||
* New features
|
||||
- New env var `CUTE_DSL_CACHE_DIR` to specify the path for dumping caches
|
||||
|
||||
|
||||
10
README.md
10
README.md
@@ -1,9 +1,9 @@
|
||||

|
||||
# Overview
|
||||
|
||||
# CUTLASS 4.4.1
|
||||
# CUTLASS 4.4.2
|
||||
|
||||
_CUTLASS 4.4.1 - Feb 2026_
|
||||
_CUTLASS 4.4.2 - March 2026_
|
||||
|
||||
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
|
||||
and related computations at all levels and scales within CUDA. It incorporates strategies for
|
||||
@@ -72,6 +72,8 @@ To get started quickly - please refer :
|
||||
- We allow grid carve-out without problem shapes being available on host.
|
||||
- Tma+LdMatrix features for loading+unpacking narrow-width types (refer to mixed_input_fmha_decode.py for example usage).
|
||||
- It is possible now to have customized epilogue fusion for persistent dense GEMM through a Python Epilogue Fusion Configuration (EFC) function, somewhat similar to CUTLASS C++ EVT. It also provides a PyTorch evaluator to compare the results.
|
||||
- CuTe DSL now supports Python 3.14 for both x86_64 and aarch64
|
||||
- Runtime Pointer/Tensor/FakeTensor now supports `__cache_key__`, providing a stable, hashable representation that simplifies and improves compiled function caching.
|
||||
|
||||
* More examples of authorizing peak-performance kernels
|
||||
- [SM103 batched 3xFP4 blockscaled GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py)
|
||||
@@ -85,6 +87,9 @@ To get started quickly - please refer :
|
||||
- Fixed an indexing issue of scalar tensor
|
||||
- Fixed small K reference check error for cta_tile_n = 256 case with overlapping accumulator optimization in [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py).
|
||||
- Fixed a segfault issue with tvm-ffi on aarch64
|
||||
- Fixed Hopper FMHA causal attention performance regression on CUDA toolkit 13.1 by
|
||||
optimizing mbarrier synchronization to avoid unnecessary convergence barriers.
|
||||
- Fix kernel loading race condition when multiple GPUs are present in the same process in JAX.
|
||||
|
||||
* API changes
|
||||
- Deprecate get_num_tmem_alloc_cols from blackwell_helpers.py. Use the one from tmem_allocator.py instead.
|
||||
@@ -127,6 +132,7 @@ To get started quickly - please refer :
|
||||
* Add support for arbitrary application-provided strides for block-scale tensors.
|
||||
- Users and applications now must pass valid block-scale strides in all cases, even when the tensor is packed.
|
||||
* Support 4x blockscaled public ptx for CUDA 13.1.
|
||||
* Enable Blackwell SM120f compilation of examples and exposes NVFP4/MX Grouped GEMM in the CUTLASS Profiler.
|
||||
* Allow non-static `TmaGbasis` in `AuxTmaParams`.
|
||||
- Some cases in attention kernel may require non-static `tma_gbasis`.
|
||||
- Relax the restriction on `TmaGbasis` parameter of `AuxTmaParams` and users are allowed to manually construct a dynamic gbasis.
|
||||
|
||||
@@ -1655,9 +1655,11 @@ def run(
|
||||
.to(dtype=torch_dtype(c_dtype))
|
||||
.to(dtype=torch.float32)
|
||||
)
|
||||
# Read back the result from CuTe tensor (c_storage was updated in-place)
|
||||
torch.testing.assert_close(
|
||||
c_storage.to(dtype=torch.float32), ref, atol=tolerance, rtol=1e-03
|
||||
c_storage.view(torch_dtype(c_dtype)).to(dtype=torch.float32),
|
||||
ref,
|
||||
atol=tolerance,
|
||||
rtol=1e-03,
|
||||
)
|
||||
|
||||
if not benchmark:
|
||||
|
||||
@@ -1725,9 +1725,11 @@ def run(
|
||||
.to(dtype=torch_dtype(c_dtype))
|
||||
.to(dtype=torch.float32)
|
||||
)
|
||||
# Read back the result from CuTe tensor (c_storage was updated in-place)
|
||||
torch.testing.assert_close(
|
||||
c_storage.to(dtype=torch.float32), ref, atol=tolerance, rtol=1e-03
|
||||
c_storage.view(torch_dtype(c_dtype)).to(dtype=torch.float32),
|
||||
ref,
|
||||
atol=tolerance,
|
||||
rtol=1e-03,
|
||||
)
|
||||
|
||||
if not benchmark:
|
||||
|
||||
@@ -2546,8 +2546,6 @@ def run(
|
||||
slices = tuple(slice(s, e) for s, e in zip(padding, shape_))
|
||||
torch_tensor = torch_tensor_full[slices].detach()
|
||||
f32_torch_tensor = f32_torch_tensor_full[slices].detach()
|
||||
torch_tensor._keep_alive = torch_tensor_full
|
||||
f32_torch_tensor._keep_alive = f32_torch_tensor_full
|
||||
|
||||
# Create dtype cute tensor with offset (gpu)
|
||||
cute_tensor = from_dlpack(torch_tensor, assumed_align=16)
|
||||
|
||||
@@ -104,6 +104,67 @@ if __name__ == "__main__":
|
||||
|
||||
from helpers import fmha_helpers as fmha_utils
|
||||
|
||||
from cutlass.cutlass_dsl import (
|
||||
Boolean, Int32, if_generate, while_generate, yield_out, not_, dsl_user_op,
|
||||
)
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass._mlir._mlir_libs._cutlass_ir._mlir.ir import IntegerType
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
import inspect as _inspect

# Feature-detect the installed nvvm binding's signature: newer bindings take
# an explicit result type as a leading "res" parameter on
# mbarrier_try_wait_parity_timelimit, older ones do not.  Probed once at
# import time so _try_wait_timelimit can dispatch without re-inspecting.
_timelimit_has_res = "res" in _inspect.signature(
    nvvm.mbarrier_try_wait_parity_timelimit
).parameters
||||
|
||||
def _try_wait_timelimit(llvm_ptr, phase_val, timeout, *, loc=None, ip=None):
    """Invoke nvvm.mbarrier_try_wait_parity_timelimit across binding versions.

    Newer nvvm bindings require the i1 result type as the first positional
    argument; older bindings infer it.  Dispatch on the signature probed at
    import time (``_timelimit_has_res``) and forward the operands unchanged.
    """
    operands = [llvm_ptr, phase_val, timeout]
    if _timelimit_has_res:
        # New-style binding: prepend the explicit 1-bit result type.
        operands.insert(0, IntegerType.get_signless(1))
    return nvvm.mbarrier_try_wait_parity_timelimit(*operands, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
def _optimized_mbarrier_wait(mbar_ptr, phase, *, loc=None, ip=None):
    """Wait on an mbarrier phase using bounded try-waits instead of a single
    blocking wait.

    Issues up to three timed try-wait attempts and only emits a spin loop of
    timed try-waits if all three miss.  NOTE(review): per the changelog entry
    in this same commit, this is intended to avoid unnecessary convergence
    barriers that caused a Hopper FMHA performance regression on CUDA
    toolkit 13.1 -- confirm with the kernel authors.
    """
    # Raw LLVM pointer to the mbarrier object, consumed by the nvvm op.
    llvm_ptr = mbar_ptr.llvm_ptr
    phase_val = Int32(phase).ir_value(loc=loc, ip=ip)
    # Constant-true thunk used as one branch of each if_generate below.
    _true = lambda: Boolean(True).ir_value(loc=loc, ip=ip)
    # Per-attempt time limit; units are whatever mbarrier.try_wait.parity's
    # timelimit operand expects (presumably nanoseconds) -- TODO confirm.
    timeout = Int32(10000000).ir_value(loc=loc, ip=ip)
    # First bounded attempt; presumably the result is true when the wait
    # completed within the time limit -- verify against the nvvm op's
    # semantics.
    d = Boolean(_try_wait_timelimit(llvm_ptr, phase_val, timeout, loc=loc, ip=ip))
    # Second attempt, guarded on the first one's outcome via if_generate.
    d = if_generate(d, _true,
        lambda: _try_wait_timelimit(llvm_ptr, phase_val, timeout, loc=loc, ip=ip),
        None, [Boolean], loc=loc, ip=ip)
    # Third attempt, same pattern.
    d = if_generate(d, _true,
        lambda: _try_wait_timelimit(llvm_ptr, phase_val, timeout, loc=loc, ip=ip),
        None, [Boolean], loc=loc, ip=ip)
    def _fallback():
        # Spin on timed try-waits: the while loop carries a Boolean and
        # keeps iterating while the carried value is false (not_).
        inner = Boolean(False).ir_value(loc=loc, ip=ip)
        ctx = while_generate([inner], lambda x: not_(x, loc=loc, ip=ip), loc=loc, ip=ip)
        with ctx as (_,):
            r = Boolean(_try_wait_timelimit(
                llvm_ptr, phase_val, timeout, loc=loc, ip=ip,
            ))
            yield_out([r], loc=loc, ip=ip)
        return Boolean(True).ir_value(loc=loc, ip=ip)
    # Enter the fallback spin loop only when the bounded attempts all missed.
    if_generate(d, _true, _fallback, None, [Boolean], loc=loc, ip=ip)
|
||||
|
||||
|
||||
@contextmanager
def _use_optimized_mbarrier_wait():
    """Temporarily route cutlass.cute.arch.mbarrier_wait through the
    timed try-wait implementation, restoring the original function on
    exit even if the managed body raises."""
    import cutlass.cute.arch as arch_mod

    saved = getattr(arch_mod, "mbarrier_wait")
    setattr(arch_mod, "mbarrier_wait", _optimized_mbarrier_wait)
    try:
        yield
    finally:
        # Always undo the monkeypatch so later launches see the stock wait.
        setattr(arch_mod, "mbarrier_wait", saved)
|
||||
|
||||
|
||||
class HopperFusedMultiHeadAttentionForward:
|
||||
def __init__(
|
||||
@@ -439,36 +500,37 @@ class HopperFusedMultiHeadAttentionForward:
|
||||
self.shared_storage = SharedStorage
|
||||
|
||||
# Launch the kernel synchronously
|
||||
self.kernel(
|
||||
qk_tiled_mma,
|
||||
pv_tiled_mma,
|
||||
tma_atom_q,
|
||||
tma_tensor_q,
|
||||
tma_atom_k,
|
||||
tma_tensor_k,
|
||||
tma_atom_v,
|
||||
tma_tensor_v,
|
||||
tma_atom_o,
|
||||
tma_tensor_o,
|
||||
lse,
|
||||
scale_softmax_log2,
|
||||
scale_softmax,
|
||||
scale_output,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
q_smem_layout_staged,
|
||||
k_smem_layout_staged,
|
||||
v_smem_layout_staged,
|
||||
o_smem_layout_staged,
|
||||
self.tile_sched_params,
|
||||
).launch(
|
||||
grid=grid,
|
||||
block=[self.threads_per_cta, 1, 1],
|
||||
cluster=self.cluster_shape_mnk,
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
stream=stream,
|
||||
min_blocks_per_mp=1,
|
||||
)
|
||||
with _use_optimized_mbarrier_wait():
|
||||
self.kernel(
|
||||
qk_tiled_mma,
|
||||
pv_tiled_mma,
|
||||
tma_atom_q,
|
||||
tma_tensor_q,
|
||||
tma_atom_k,
|
||||
tma_tensor_k,
|
||||
tma_atom_v,
|
||||
tma_tensor_v,
|
||||
tma_atom_o,
|
||||
tma_tensor_o,
|
||||
lse,
|
||||
scale_softmax_log2,
|
||||
scale_softmax,
|
||||
scale_output,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
q_smem_layout_staged,
|
||||
k_smem_layout_staged,
|
||||
v_smem_layout_staged,
|
||||
o_smem_layout_staged,
|
||||
self.tile_sched_params,
|
||||
).launch(
|
||||
grid=grid,
|
||||
block=[self.threads_per_cta, 1, 1],
|
||||
cluster=self.cluster_shape_mnk,
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
stream=stream,
|
||||
min_blocks_per_mp=1,
|
||||
)
|
||||
|
||||
# GPU device kernel
|
||||
@cute.kernel
|
||||
|
||||
@@ -97,7 +97,7 @@
|
||||
" # Step 2: Map thread index to tensor coordinates\n",
|
||||
" # -------------------------------------------\n",
|
||||
" # Each thread will process one element of the input tensors\n",
|
||||
" m, n = gA.shape # Get tensor dimensions (M rows × N columns)\n",
|
||||
" m, n = gA.shape # Get tensor dimensions (M rows \u00d7 N columns)\n",
|
||||
"\n",
|
||||
" # Convert linear thread index to 2D coordinates:\n",
|
||||
" # - ni: column index (0 to n-1)\n",
|
||||
@@ -198,7 +198,7 @@
|
||||
" num_threads_per_block = 256\n",
|
||||
"\n",
|
||||
" # Get input dimensions\n",
|
||||
" m, n = mA.shape # Matrix dimensions (M rows × N columns)\n",
|
||||
" m, n = mA.shape # Matrix dimensions (M rows \u00d7 N columns)\n",
|
||||
"\n",
|
||||
" # Create kernel instance\n",
|
||||
" kernel = naive_elementwise_add_kernel(mA, mB, mC)\n",
|
||||
@@ -298,7 +298,7 @@
|
||||
" - For elementwise add:\n",
|
||||
" * Read: 2 elements (A and B)\n",
|
||||
" * Write: 1 element (C)\n",
|
||||
" * Total bytes = (2 reads + 1 write) × elements × sizeof(dtype)\n",
|
||||
" * Total bytes = (2 reads + 1 write) \u00d7 elements \u00d7 sizeof(dtype)\n",
|
||||
"\n",
|
||||
"Below is our benchmarking utility that measures these metrics:"
|
||||
]
|
||||
@@ -368,7 +368,7 @@
|
||||
"\n",
|
||||
"According to *Little's Law*, naive implementation has\n",
|
||||
" - 1 element (4 bytes load + 2 bytes store) per thread\n",
|
||||
" - 256 threads/block × N blocks\n",
|
||||
" - 256 threads/block \u00d7 N blocks\n",
|
||||
" - Limited in-flight operations\n",
|
||||
"\n",
|
||||
"In some GPUs, it's insufficient parallelism to saturate memory bandwidth.\n",
|
||||
@@ -385,7 +385,35 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "## Vectorized Load and Store\n\nTo improve performance according to Little's Law, we need to increase the number\nof in-flight requests. We can do this by increasing the number of bytes handled\nin each load & store operation per thread through vectorized memory access.\n\nSince Ampere GPUs support up to 128-bit per load/store and each element is 16-bit,\nwe can load 8 elements per vectorized operation on contiguous rows.\nCuTe tiling operations make this vectorization straightforward.\n\nUsing ``tiled_tensor = cute.zipped_divide(tensor, tiler)``, we can partition the input\n``tensor`` into groups of ``tiler`` blocks. For vectorization, we specify ``tiler``\nas the block of data each thread accesses (8 contiguous elements in the same row, or ``(1,8)``).\nDifferent threads can then access different blocks by indexing into the 2nd mode of ``tiled_tensor``.\n\n```python\nmA : cute.Tensor # (2048,2048):(2048,1)\ngA = cute.zipped_divide(a, tiler=(1, 8)) # tiled/vectorized => ((1,8),(2048,256)):((0,1),(2048,8))\n```\n\n$\n \\begin{array}{ccccc}\n & ((1,8) & , & (2048,256)) & : ((0,1),(2048,8)) \\\\\n & \\underbrace{\\phantom{(1,8)}}_{tiler} & & \\underbrace{\\phantom{(2048,256)}}_{threads} & \\\\\n & \\text{\\scriptsize per-thread} & & \\text{\\scriptsize num of tiles}\n \\end{array}\n$"
|
||||
"source": [
|
||||
"## Vectorized Load and Store\n",
|
||||
"\n",
|
||||
"To improve performance according to Little's Law, we need to increase the number\n",
|
||||
"of in-flight requests. We can do this by increasing the number of bytes handled\n",
|
||||
"in each load & store operation per thread through vectorized memory access.\n",
|
||||
"\n",
|
||||
"Since Ampere GPUs support up to 128-bit per load/store and each element is 16-bit,\n",
|
||||
"we can load 8 elements per vectorized operation on contiguous rows.\n",
|
||||
"CuTe tiling operations make this vectorization straightforward.\n",
|
||||
"\n",
|
||||
"Using ``tiled_tensor = cute.zipped_divide(tensor, tiler)``, we can partition the input\n",
|
||||
"``tensor`` into groups of ``tiler`` blocks. For vectorization, we specify ``tiler``\n",
|
||||
"as the block of data each thread accesses (8 contiguous elements in the same row, or ``(1,8)``).\n",
|
||||
"Different threads can then access different blocks by indexing into the 2nd mode of ``tiled_tensor``.\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"mA : cute.Tensor # (2048,2048):(2048,1)\n",
|
||||
"gA = cute.zipped_divide(a, tiler=(1, 8)) # tiled/vectorized => ((1,8),(2048,256)):((0,1),(2048,8))\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"$\n",
|
||||
" \\begin{array}{ccccc}\n",
|
||||
" & ((1,8) & , & (2048,256)) & : ((0,1),(2048,8)) \\\\\n",
|
||||
" & \\underbrace{\\phantom{(1,8)}}_{tiler} & & \\underbrace{\\phantom{(2048,256)}}_{threads} & \\\\\n",
|
||||
" & \\text{\\scriptsize per-thread} & & \\text{\\scriptsize num of tiles}\n",
|
||||
" \\end{array}\n",
|
||||
"$"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -423,14 +451,91 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "This vectorized kernel follows a similar structure to its naive non-vectorized counterpart,\nwith one key difference: the tensor slicing pattern. By using `(None, (mi, ni))` as the slice indices,\nwe can extract a `(1,8)` sub-tensor from `gA`, `gB` and `gC` like \n\n$ gA[(None, (mi, ni))]: $\n\n$\n \\begin{array}{ccccc}\n Layout: & ( & (1,8) & , & (2048,256) & ) & : & ((0,1),(2048,8)) & \\xrightarrow{\\text{slice}} & ((1,8)):((0,1)) \\\\\n & & \\underbrace{\\phantom{(1,8)}} & & \\underbrace{\\phantom{(2048,256)}} & & \\\\\n Coord: & ( & None & , & (mi, ni) & ) & &\n \\end{array}\n$\n\nThen tensor data can be loaded into vector via the `gA[(None, (mi, ni))].load()` method. It is equivalent to\n\n```python\nv0 = gA[(0, (mi, ni))] # => mA[(mi, ni * 8 + 0)]\nv1 = gA[(1, (mi, ni))] # => mA[(mi, ni * 8 + 1)]\nv2 = gA[(2, (mi, ni))] # => mA[(mi, ni * 8 + 2)]\nv3 = gA[(3, (mi, ni))] # => mA[(mi, ni * 8 + 3)]\nv4 = gA[(4, (mi, ni))] # => mA[(mi, ni * 8 + 4)]\nv5 = gA[(5, (mi, ni))] # => mA[(mi, ni * 8 + 5)]\nv6 = gA[(6, (mi, ni))] # => mA[(mi, ni * 8 + 6)]\nv7 = gA[(7, (mi, ni))] # => mA[(mi, ni * 8 + 7)]\n```\n\n### Assumed Alignment\n\nIn order to guide compile to use vectorized load/store, we must tell compiler to assume alignment of incoming pointer. \nIt's on users side to guarantee actual pointer at runtime meet the alignment restriction.\n\n```python\na_ = from_dlpack(a, assumed_align=16)\nb_ = from_dlpack(b, assumed_align=16)\nc_ = from_dlpack(c, assumed_align=16)\n\n# Compile kernel with alignment assumption\ncompiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n```\n\nIt's worth to note that partitioned or tiled tensor could have different alignment of its base pointer because of offset\nduring sub-slice."
|
||||
"source": [
|
||||
"This vectorized kernel follows a similar structure to its naive non-vectorized counterpart,\n",
|
||||
"with one key difference: the tensor slicing pattern. By using `(None, (mi, ni))` as the slice indices,\n",
|
||||
"we can extract a `(1,8)` sub-tensor from `gA`, `gB` and `gC` like \n",
|
||||
"\n",
|
||||
"$ gA[(None, (mi, ni))]: $\n",
|
||||
"\n",
|
||||
"$\n",
|
||||
" \\begin{array}{ccccc}\n",
|
||||
" Layout: & ( & (1,8) & , & (2048,256) & ) & : & ((0,1),(2048,8)) & \\xrightarrow{\\text{slice}} & ((1,8)):((0,1)) \\\\\n",
|
||||
" & & \\underbrace{\\phantom{(1,8)}} & & \\underbrace{\\phantom{(2048,256)}} & & \\\\\n",
|
||||
" Coord: & ( & None & , & (mi, ni) & ) & &\n",
|
||||
" \\end{array}\n",
|
||||
"$\n",
|
||||
"\n",
|
||||
"Then tensor data can be loaded into vector via the `gA[(None, (mi, ni))].load()` method. It is equivalent to\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"v0 = gA[(0, (mi, ni))] # => mA[(mi, ni * 8 + 0)]\n",
|
||||
"v1 = gA[(1, (mi, ni))] # => mA[(mi, ni * 8 + 1)]\n",
|
||||
"v2 = gA[(2, (mi, ni))] # => mA[(mi, ni * 8 + 2)]\n",
|
||||
"v3 = gA[(3, (mi, ni))] # => mA[(mi, ni * 8 + 3)]\n",
|
||||
"v4 = gA[(4, (mi, ni))] # => mA[(mi, ni * 8 + 4)]\n",
|
||||
"v5 = gA[(5, (mi, ni))] # => mA[(mi, ni * 8 + 5)]\n",
|
||||
"v6 = gA[(6, (mi, ni))] # => mA[(mi, ni * 8 + 6)]\n",
|
||||
"v7 = gA[(7, (mi, ni))] # => mA[(mi, ni * 8 + 7)]\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### Assumed Alignment\n",
|
||||
"\n",
|
||||
"In order to guide the compiler to use vectorized load/store, we must tell the compiler to assume alignment of the incoming pointer. \n",
|
||||
"It's the user's responsibility to guarantee that the actual pointer at runtime meets the alignment restriction.\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||||
"\n",
|
||||
"# Compile kernel with alignment assumption\n",
|
||||
"compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"It's worth noting that a partitioned or tiled tensor could have a different alignment of its base pointer because of the offset\n",
|
||||
"during sub-slice."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": "@cute.jit\ndef vectorized_elementwise_add(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):\n threads_per_block = 256\n\n gA = cute.zipped_divide(mA, (1, 8))\n gB = cute.zipped_divide(mB, (1, 8))\n gC = cute.zipped_divide(mC, (1, 8))\n\n print(\"[DSL INFO] Tiled Tensors:\")\n print(f\"[DSL INFO] gA = {gA}\")\n print(f\"[DSL INFO] gB = {gB}\")\n print(f\"[DSL INFO] gC = {gC}\")\n\n vectorized_elementwise_add_kernel(gA, gB, gC).launch(\n grid=(cute.size(gC, mode=[1]) // threads_per_block, 1, 1),\n block=(threads_per_block, 1, 1),\n )\n\n\na = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\nb = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\nc = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n\na_ = from_dlpack(a, assumed_align=16)\nb_ = from_dlpack(b, assumed_align=16)\nc_ = from_dlpack(c, assumed_align=16)\n\ncompiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\ncompiled_func(a_, b_, c_)\n\n# verify correctness\ntorch.testing.assert_close(c, a + b)"
|
||||
"source": [
|
||||
"@cute.jit\n",
|
||||
"def vectorized_elementwise_add(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):\n",
|
||||
" threads_per_block = 256\n",
|
||||
"\n",
|
||||
" gA = cute.zipped_divide(mA, (1, 8))\n",
|
||||
" gB = cute.zipped_divide(mB, (1, 8))\n",
|
||||
" gC = cute.zipped_divide(mC, (1, 8))\n",
|
||||
"\n",
|
||||
" print(\"[DSL INFO] Tiled Tensors:\")\n",
|
||||
" print(f\"[DSL INFO] gA = {gA}\")\n",
|
||||
" print(f\"[DSL INFO] gB = {gB}\")\n",
|
||||
" print(f\"[DSL INFO] gC = {gC}\")\n",
|
||||
"\n",
|
||||
" vectorized_elementwise_add_kernel(gA, gB, gC).launch(\n",
|
||||
" grid=(cute.size(gC, mode=[1]) // threads_per_block, 1, 1),\n",
|
||||
" block=(threads_per_block, 1, 1),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"a = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"b = torch.randn(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"c = torch.zeros(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"a_ = from_dlpack(a, assumed_align=16)\n",
|
||||
"b_ = from_dlpack(b, assumed_align=16)\n",
|
||||
"c_ = from_dlpack(c, assumed_align=16)\n",
|
||||
"\n",
|
||||
"compiled_func = cute.compile(vectorized_elementwise_add, a_, b_, c_)\n",
|
||||
"compiled_func(a_, b_, c_)\n",
|
||||
"\n",
|
||||
"# verify correctness\n",
|
||||
"torch.testing.assert_close(c, a + b)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
@@ -444,7 +549,68 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": "## TV Layout\n\nBoth the naive and vectorized kernels follow a common pattern to map thread indices\nto physical addresses in two steps:\n\nStep 1: Map thread index to logical coordinates in `(M, N)`\n\n* `mi = thread_idx // n`\n* `ni = thread_idx % n`\n\nIn native version, each thread process 1 element, in this case, `mi` and `ni` is logical\ncoordinate into data tensor `mA`, `mB` and `mC`.\n\nInt vectorized version, each thread process multiple values of input and output tensor.\nlogical coordinate should be computed with both thread and value index.\n\n* `thread_idx // n`\n* `(thread_idx % n) * 8 + value_idx`\n\n\nStep 2: Map logical coordinates in `(M, N)` to physical addresses using the tensor layout\n\n* Vectorized Load\n\n```python\n frgA = gA[(None, (mi, ni))].load()\n```\n\n* Elementwise Load (less efficient)\n\n```python\n frgA0 = mA[(mi, ni * 8 + 0)]\n frgA1 = mA[(mi, ni * 8 + 1)]\n frgA2 = mA[(mi, ni * 8 + 2)]\n frgA3 = mA[(mi, ni * 8 + 3)]\n frgA4 = mA[(mi, ni * 8 + 4)]\n frgA5 = mA[(mi, ni * 8 + 5)]\n frgA6 = mA[(mi, ni * 8 + 6)]\n frgA7 = mA[(mi, ni * 8 + 7)]\n\n # Or use divided layout\n\n frgA0 = gA[(0, (mi, ni))]\n frgA1 = gA[(1, (mi, ni))]\n frgA2 = gA[(2, (mi, ni))]\n frgA3 = gA[(3, (mi, ni))]\n frgA4 = gA[(4, (mi, ni))]\n frgA5 = gA[(5, (mi, ni))]\n frgA6 = gA[(6, (mi, ni))]\n frgA7 = gA[(7, (mi, ni))]\n```\n\nCuTe introduces TV layout to represent this mapping from thread index and value index\n(i.e., the 8 elements loaded per thread) to the logical coordinate space of a tensor.\nBy configuring different TV layouts, we can experiment with different memory access\npatterns with minimal code changes.\n\n**Definition:** *TV Layout* is rank-2 layout which maps `(thread_index, value_index)` \nto logical coordinate of tensor. \n\nWe always have *TV Layout* with canonical form as `(thread_domain, value_domain):(..., ...)`.\n\nWith *TV Layout*, each thread can find logical coordinates or indices of data partitioned\nto current thread."
|
||||
"source": [
|
||||
"## TV Layout\n",
|
||||
"\n",
|
||||
"Both the naive and vectorized kernels follow a common pattern to map thread indices\n",
|
||||
"to physical addresses in two steps:\n",
|
||||
"\n",
|
||||
"Step 1: Map thread index to logical coordinates in `(M, N)`\n",
|
||||
"\n",
|
||||
"* `mi = thread_idx // n`\n",
|
||||
"* `ni = thread_idx % n`\n",
|
||||
"\n",
|
||||
"In the naive version, each thread processes 1 element; in this case, `mi` and `ni` form the logical\n",
|
||||
"coordinate into data tensor `mA`, `mB` and `mC`.\n",
|
||||
"\n",
|
||||
"In the vectorized version, each thread processes multiple values of the input and output tensors.\n",
|
||||
"logical coordinate should be computed with both thread and value index.\n",
|
||||
"\n",
|
||||
"* `thread_idx // n`\n",
|
||||
"* `(thread_idx % n) * 8 + value_idx`\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Step 2: Map logical coordinates in `(M, N)` to physical addresses using the tensor layout\n",
|
||||
"\n",
|
||||
"* Vectorized Load\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
" frgA = gA[(None, (mi, ni))].load()\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"* Elementwise Load (less efficient)\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
" frgA0 = mA[(mi, ni * 8 + 0)]\n",
|
||||
" frgA1 = mA[(mi, ni * 8 + 1)]\n",
|
||||
" frgA2 = mA[(mi, ni * 8 + 2)]\n",
|
||||
" frgA3 = mA[(mi, ni * 8 + 3)]\n",
|
||||
" frgA4 = mA[(mi, ni * 8 + 4)]\n",
|
||||
" frgA5 = mA[(mi, ni * 8 + 5)]\n",
|
||||
" frgA6 = mA[(mi, ni * 8 + 6)]\n",
|
||||
" frgA7 = mA[(mi, ni * 8 + 7)]\n",
|
||||
"\n",
|
||||
" # Or use divided layout\n",
|
||||
"\n",
|
||||
" frgA0 = gA[(0, (mi, ni))]\n",
|
||||
" frgA1 = gA[(1, (mi, ni))]\n",
|
||||
" frgA2 = gA[(2, (mi, ni))]\n",
|
||||
" frgA3 = gA[(3, (mi, ni))]\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"CuTe introduces TV layout to represent this mapping from thread index and value index\n",
|
||||
"(i.e., the 4 elements loaded per thread) to the logical coordinate space of a tensor.\n",
|
||||
"By configuring different TV layouts, we can experiment with different memory access\n",
|
||||
"patterns with minimal code changes.\n",
|
||||
"\n",
|
||||
"**Definition:** *TV Layout* is rank-2 layout which maps `(thread_index, value_index)` \n",
|
||||
"to logical coordinate of tensor. \n",
|
||||
"\n",
|
||||
"We always have *TV Layout* with canonical form as `(thread_domain, value_domain):(..., ...)`.\n",
|
||||
"\n",
|
||||
"With *TV Layout*, each thread can find logical coordinates or indices of data partitioned\n",
|
||||
"to current thread.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@@ -1057,4 +1223,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,13 +279,13 @@ struct CollectiveBuilder<
|
||||
// Basic storage block for new Scaling Factor Layouts
|
||||
using mnBasicBlockShape = Shape<_32,_4>;
|
||||
using mnBasicBlockStride = Stride<_16,_4>;
|
||||
using kBasicBlockShape = Shape<Int<SFVectorSize>, Int<MMA_NSF>>;
|
||||
using kBasicBlockShape = Shape<Int<(int)SFVectorSize>, Int<MMA_NSF>>;
|
||||
using kBasicBlockStride = Stride<_0, _1>;
|
||||
|
||||
using sSFA_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{}));
|
||||
using sSF_strideMN = decltype(prepend( Blk_Elems{}, mnBasicBlockStride{}));
|
||||
using sSFA_strideM = sSF_strideMN;
|
||||
using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int<MMA_NSF>{}, size<2>(TileShape_MNK{}) / Int<SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{}));
|
||||
using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int<MMA_NSF>{}, size<2>(TileShape_MNK{}) / Int<(int)SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{}));
|
||||
|
||||
using sSFA_strideK = decltype(prepend(make_stride( Int<MMA_NSF>{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{}));
|
||||
using sSFA_shape = decltype(make_shape( sSFA_shapeM{}, sSF_shapeK{}));
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
|
||||
#define CUTLASS_MAJOR 4
|
||||
#define CUTLASS_MINOR 4
|
||||
#define CUTLASS_PATCH 1
|
||||
#define CUTLASS_PATCH 2
|
||||
|
||||
#ifdef CUTLASS_VERSIONS_GENERATED
|
||||
#include "cutlass/version_extended.h"
|
||||
|
||||
@@ -175,22 +175,21 @@ stride is set to 1 unless inconsistent with the layout of the DLPack tensor. For
|
||||
The default value for ``leading_dim`` is ``None``. In such case, the system
|
||||
automatically deduces it from the tensor's layout using the following logic:
|
||||
|
||||
1. If a dimension's stride is 1, that dimension is marked as the leading dimension.
|
||||
2. If multiple dimensions satisfy condition 1, an error is thrown indicating deduction failure.
|
||||
1. If exactly one dimension has stride 1, that dimension is the leading dimension.
|
||||
2. If multiple dimensions have stride 1, deduction succeeds only when exactly one of them
|
||||
has size > 1 (that dimension is used). If none or more than one has size > 1, an error is raised.
|
||||
Note that after converting a **PyTorch** tensor to the DLPack format, the stride for dimensions
|
||||
with size 1 are canonicalized to 1. This canonicalization can increase the likelihood of
|
||||
deduction failures. This behavior is specific to PyTorch and does not occur with NumPy for
|
||||
example.
|
||||
3. If no dimension satisfies condition 1, all strides are marked as dynamic.
|
||||
with size 1 are canonicalized to 1, which can produce multiple stride-1 dimensions.
|
||||
3. If no dimension has stride 1, all strides remain dynamic.
|
||||
|
||||
For example:
|
||||
|
||||
- For a tensor with layout ``(2,2,3,4):(2,1,4,12)``, the leading dimension is 1.
|
||||
The layout will be marked as ``(?,?,?,?):(?,1,?,?)``.
|
||||
- For a tensor with layout ``(1,5,1):(1,1,1)``, if ``leading_dim`` is not specified,
|
||||
a deduction failure error is raised.
|
||||
- For a tensor with layout ``(2,2):(8,2)``, since no dimension has stride 1,
|
||||
all dimensions are marked as dynamic: ``(?,?):(?,?)``.
|
||||
- For a tensor with layout ``(1,5,1):(1,1,1)``, multiple dimensions have stride 1 but exactly one
|
||||
has size > 1 (dim 1). The leading dimension is deduced to be 1: ``(?,?,?):(?,1,?)``.
|
||||
- For a tensor with layout ``(2,2):(8,2)``, no dimension has stride 1, so all strides remain
|
||||
dynamic: ``(?,?):(?,?)``.
|
||||
|
||||
The leading dimension accepts negative index which means the dimension is counted from the last dimension. For example,
|
||||
|
||||
@@ -206,8 +205,9 @@ The following example demonstrates how to use ``mark_layout_dynamic`` to specify
|
||||
* ``t1`` & ``t2`` shows the usage of ``mark_layout_dynamic`` with specified ``leading_dim``.
|
||||
* ``t3`` shows the usage of ``mark_layout_dynamic`` with no leading dimension.
|
||||
* ``t4`` shows the usage of ``mark_layout_dynamic`` with broadcasted dimensions.
|
||||
* ``t5`` demonstrates the deduction failure when the there're more than one dimensions with stride equals to 1.
|
||||
* ``t6`` & ``t7`` demonstrates incorrect settings for ``leading_dim`` and expected errors.
|
||||
* ``t5`` shows automatic deduction for tensor ``b`` (multiple stride-1, exactly one has size > 1 → dim 1).
|
||||
* ``t5_fail`` demonstrates the deduction failure when multiple dimensions have stride 1 but none has size > 1.
|
||||
* ``t6`` & ``t7`` demonstrate incorrect settings for ``leading_dim`` and expected errors.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -245,8 +245,14 @@ The following example demonstrates how to use ``mark_layout_dynamic`` to specify
|
||||
print(t4)
|
||||
# (?,?,?,?):(?,0,0,1)
|
||||
|
||||
# b has layout (1,4,1,32,1):(1,1,1,4,1); dim 1 has size > 1, so deduction succeeds to dim 1.
|
||||
t5 = from_dlpack(b).mark_layout_dynamic()
|
||||
# Can't decude the leading dimension from layout, please specify the leading_dim explicitly.
|
||||
print(t5)
|
||||
# (?,?,?,?,?):(?{i64},1,?{i64},?{i64},?{i64})
|
||||
|
||||
# Rejected: multiple stride-1, none with size > 1 (e.g. torch.ones(1,1,1)).
|
||||
t5_fail = from_dlpack(torch.ones(1, 1, 1)).mark_layout_dynamic()
|
||||
# Can't deduce the leading dimension from layout (multiple dimensions have stride 1 but none has size > 1)...
|
||||
|
||||
t6 = from_dlpack(a).mark_layout_dynamic(leading_dim=1)
|
||||
# Expected strides[leading_dim] == 1, but got 16
|
||||
|
||||
@@ -68,7 +68,8 @@ class ExternalBinaryModule:
|
||||
|
||||
load_provider: LoadProvider = None
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
def __init__(self, file_path: str, enable_tvm_ffi: bool = False):
|
||||
self.enable_tvm_ffi = enable_tvm_ffi
|
||||
assert self.load_provider is not None, (
|
||||
"Load provider is not set for ExternalBinaryModule."
|
||||
)
|
||||
@@ -82,13 +83,28 @@ class ExternalBinaryModule:
|
||||
object_file_content = f.read()
|
||||
except Exception as e:
|
||||
raise DSLRuntimeError(f"Failed to read object file {file_path}: {e}")
|
||||
|
||||
useJitLink = not enable_tvm_ffi
|
||||
# Lifetime of the engine is same as the ExternalBinaryModule.
|
||||
self.engine = self.load_provider.execution_engine_constructor(
|
||||
object_file_content, shared_libs
|
||||
object_file_content, shared_libs, useJitLink
|
||||
)
|
||||
|
||||
def __getattr__(self, function_prefix: str) -> "JitCompiledFunction":
|
||||
"""Get the jit_function from the `function_prefix`. The `function_prefix` is specified when users dump the object file. When there is no function_prefix found in the module, the function will raise an error."""
|
||||
if self.enable_tvm_ffi:
|
||||
try:
|
||||
import tvm_ffi
|
||||
|
||||
function_ptr = self.engine.lookup("__tvm_ffi_" + function_prefix)
|
||||
return tvm_ffi.Function.__from_extern_c__(
|
||||
function_ptr, keep_alive_object=self.engine
|
||||
)
|
||||
except Exception as e:
|
||||
raise DSLRuntimeError(
|
||||
f"Failed to load TVM FFI function {function_prefix}: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
args_spec, function_name, kernel_info, version_str = (
|
||||
decode_metadata_from_execution_engine(
|
||||
@@ -124,3 +140,7 @@ class ExternalBinaryModule:
|
||||
load_from_binary=True,
|
||||
)
|
||||
return jit_function
|
||||
|
||||
def __getitem__(self, function_prefix: str) -> "JitCompiledFunction":
|
||||
"""Get the jit_function from the `function_prefix`. The `function_prefix` is specified when users dump the object file. When there is no function_prefix found in the module, the function will raise an error."""
|
||||
return self.__getattr__(function_prefix)
|
||||
|
||||
@@ -202,6 +202,8 @@ from .math import *
|
||||
# Used as internal symbol
|
||||
from .. import cutlass_dsl as _dsl
|
||||
|
||||
from .ffi import ffi
|
||||
|
||||
# Aliases
|
||||
jit = _dsl.CuTeDSL.jit
|
||||
kernel = _dsl.CuTeDSL.kernel
|
||||
@@ -312,4 +314,5 @@ __all__ = [
|
||||
"kernel",
|
||||
"register_jit_arg_adapter",
|
||||
"compile",
|
||||
"ffi",
|
||||
]
|
||||
|
||||
@@ -96,7 +96,6 @@ __all__ = [
|
||||
"fma_packed_f32x2",
|
||||
"mul_packed_f32x2",
|
||||
"add_packed_f32x2",
|
||||
"sub_packed_f32x2",
|
||||
"fmax",
|
||||
"rcp_approx",
|
||||
"exp2",
|
||||
|
||||
@@ -23,7 +23,6 @@ from ..typing import Int32, Pointer, Int128
|
||||
def issue_clc_query(
|
||||
mbar_ptr: Pointer,
|
||||
clc_response_ptr: Pointer,
|
||||
multicast: bool = True,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
@@ -40,20 +39,12 @@ def issue_clc_query(
|
||||
"""
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
clc_response_llvm_ptr = clc_response_ptr.llvm_ptr
|
||||
if multicast:
|
||||
nvvm.clusterlaunchcontrol_try_cancel_multicast(
|
||||
clc_response_llvm_ptr,
|
||||
mbar_llvm_ptr,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
else:
|
||||
nvvm.clusterlaunchcontrol_try_cancel(
|
||||
clc_response_llvm_ptr,
|
||||
mbar_llvm_ptr,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
nvvm.clusterlaunchcontrol_try_cancel_multicast(
|
||||
clc_response_llvm_ptr,
|
||||
mbar_llvm_ptr,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
|
||||
@@ -604,7 +604,6 @@ def fence_proxy(
|
||||
],
|
||||
*,
|
||||
space: Optional[Literal["cta", "cluster"]] = None,
|
||||
use_intrinsic=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
@@ -623,7 +622,6 @@ def fence_proxy(
|
||||
- "cta" : CTA (Cooperative Thread Array) scope
|
||||
- "cluster" : Cluster scope
|
||||
:type space: Optional[Literal["cta", "cluster"]]
|
||||
:param use_intrinsic: Whether to use intrinsic version
|
||||
"""
|
||||
from cutlass._mlir.dialects.nvvm import (
|
||||
SharedSpace,
|
||||
@@ -640,7 +638,6 @@ def fence_proxy(
|
||||
nvvm.fence_proxy(
|
||||
kind=kind,
|
||||
space=space,
|
||||
use_intrinsic=use_intrinsic,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
@@ -940,9 +937,6 @@ mul_packed_f32x2 = partial(
|
||||
add_packed_f32x2 = partial(
|
||||
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2
|
||||
)
|
||||
sub_packed_f32x2 = partial(
|
||||
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.sub_packed_f32x2
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
@@ -959,20 +953,6 @@ def fmax(
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def fmin(
|
||||
a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None
|
||||
) -> Float32:
|
||||
return Float32(
|
||||
nvvm.fmin(
|
||||
Float32(a).ir_value(loc=loc, ip=ip),
|
||||
Float32(b).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None):
|
||||
return Float32(
|
||||
|
||||
@@ -1587,7 +1587,7 @@ def pretty_str(arg) -> str:
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def printf(*args, loc=None, ip=None, end="\n") -> None:
|
||||
def printf(*args, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Print one or more values with optional formatting.
|
||||
|
||||
@@ -1607,8 +1607,6 @@ def printf(*args, loc=None, ip=None, end="\n") -> None:
|
||||
:type loc: Optional[Location]
|
||||
:param ip: Insertion point for code generation, defaults to None
|
||||
:type ip: Optional[InsertionPoint]
|
||||
:param end: Suffix for the printed value, defaults to newline
|
||||
:type end: Optional[str]
|
||||
:raises ValueError: If no arguments are provided
|
||||
:raises TypeError: If an unsupported argument type is passed
|
||||
|
||||
@@ -1638,10 +1636,10 @@ def printf(*args, loc=None, ip=None, end="\n") -> None:
|
||||
raise ValueError("expects at least one argument to print")
|
||||
|
||||
if isinstance(args[0], str):
|
||||
fmt = args[0] + end
|
||||
fmt = args[0] + "\n"
|
||||
args = args[1:]
|
||||
else:
|
||||
fmt = "{}" + ", {}" * (len(args) - 1) + end
|
||||
fmt = "{}" + ", {}" * (len(args) - 1) + "\n"
|
||||
|
||||
def process_arg(arg):
|
||||
arg0 = arg.value if isinstance(arg, Numeric) else arg
|
||||
@@ -3762,6 +3760,35 @@ def zipped_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor:
|
||||
|
||||
@dsl_user_op
|
||||
def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None):
|
||||
"""
|
||||
``zipped_divide`` is ``logical_divide`` with Tiler modes and Rest modes gathered together: ``(Tiler,Rest)``
|
||||
|
||||
- When Tiler is Layout, this has no effect as ``logical_divide`` results in the same.
|
||||
- When Tiler is ``Tile`` (nested tuple of ``Layout``) or ``Shape``, this zips modes into standard form
|
||||
``((BLK_A,BLK_B),(a,b,x,y))``
|
||||
|
||||
For example, if ``target`` has shape ``(s, t, r)`` and ``tiler`` has shape ``(BLK_A, BLK_B)``,
|
||||
then the result will have shape ``((BLK_A, BLK_B), (ceil_div(s, BLK_A), ceil_div(t, BLK_B), r))``.
|
||||
|
||||
:param target: The layout or tensor to partition.
|
||||
:type target: Layout or Tensor
|
||||
:param tiler: The tiling specification (can be a Layout, Shape, Tile).
|
||||
:type tiler: Tiler
|
||||
:param loc: Optional MLIR IR location information.
|
||||
:type loc: optional
|
||||
:param ip: Optional MLIR IR insertion point.
|
||||
:type ip: optional
|
||||
:return: A zipped (partitioned) version of the target.
|
||||
:rtype: Layout or Tensor
|
||||
|
||||
**Example:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
layout = cute.make_layout((128, 64), stride=(64, 1))
|
||||
tiler = (8, 8)
|
||||
result = cute.zipped_divide(layout, tiler) # result shape: ((8, 8), (16, 8))
|
||||
"""
|
||||
if isinstance(tiler, tuple):
|
||||
tiler = _pack_tile(tiler, loc=loc, ip=ip) # type: ignore
|
||||
return _op_wrapper(
|
||||
@@ -3904,6 +3931,73 @@ def local_tile(
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Partition a tensor into tiles using a tiler and extract a single tile at the provided coordinate.
|
||||
|
||||
The ``local_tile`` operation applies a ``zipped_divide`` to split the ``input`` tensor by the ``tiler``
|
||||
and then slices out a single tile using the provided `coord`. This is commonly used for extracting block-,
|
||||
thread-, or CTA-level tiles for parallel operations.
|
||||
|
||||
.. math::
|
||||
|
||||
\\text{local_tile}(input, tiler, coord) = \\text{zipped_divide}(input, tiler)[coord]
|
||||
|
||||
This function corresponds to the CUTE/C++ `local_tile` utility:
|
||||
https://docs.nvidia.com/cutlass/media/docs/cpp/cute/03_tensor.html#local-tile
|
||||
|
||||
:param input: The input tensor to partition into tiles.
|
||||
:type input: Tensor
|
||||
:param tiler: The tiling specification (can be a Layout, Shape, Tile).
|
||||
:type tiler: Tiler
|
||||
:param coord: The coordinate to select within the remainder ("rest") modes after tiling.
|
||||
This selects which tile to extract.
|
||||
:type coord: Coord
|
||||
:param proj: (Optional) Projection onto tiling modes; specify to project out unused tiler modes,
|
||||
e.g., when working with projections of tilers in multi-mode partitioning.
|
||||
Default is None for no projection.
|
||||
:type proj: XTuple, optional
|
||||
:param loc: (Optional) MLIR location, for diagnostic/debugging.
|
||||
:type loc: Any, optional
|
||||
:param ip: (Optional) MLIR insertion point, used in IR building context.
|
||||
:type ip: Any, optional
|
||||
|
||||
:return: A new tensor representing the local tile selected at the given coordinate.
|
||||
:rtype: Tensor
|
||||
|
||||
**Examples**
|
||||
|
||||
1. Tiling a 2D tensor and extracting a tile:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# input: (16, 24)
|
||||
tensor : cute.Tensor
|
||||
tiler = (2, 4)
|
||||
coord = (1, 1)
|
||||
|
||||
# output: (8, 6)
|
||||
# - zipped_divide(tensor, tiler) -> ((2, 4), (8, 6))
|
||||
# - local_tile(tensor, tiler, coord) -> (8, 6)
|
||||
result = cute.local_tile(tensor, tiler=tiler, coord=coord)
|
||||
|
||||
2. Using a stride projection for specialized tiling:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# input: (16, 24)
|
||||
tensor : cute.Tensor
|
||||
tiler = (2, 2, 4)
|
||||
coord = (0, 1, 1)
|
||||
proj = (1, None, 1)
|
||||
|
||||
# output: (8, 6)
|
||||
# projected_tiler: (2, 4)
|
||||
# projected_coord: (0, 1)
|
||||
# - zipped_divide(tensor, projected_tiler) -> ((2, 4), (8, 6))
|
||||
# - local_tile(tensor, projected_tiler, projected_coord) -> (8, 6)
|
||||
result = cute.local_tile(tensor, tiler=tiler, coord=coord, proj=proj)
|
||||
"""
|
||||
|
||||
tiler_val = _pack_tile(tiler, loc=loc, ip=ip)
|
||||
coord_val = _pack_coord(coord, loc=loc, ip=ip)
|
||||
if proj is not None:
|
||||
|
||||
206
python/CuTeDSL/cutlass/cute/ffi.py
Normal file
206
python/CuTeDSL/cutlass/cute/ffi.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import func
|
||||
from cutlass.base_dsl.typing import get_mlir_types, NumericMeta, Numeric, as_numeric
|
||||
from cutlass.base_dsl.dsl import extract_mlir_values
|
||||
|
||||
from cutlass import DSLRuntimeError
|
||||
|
||||
|
||||
class ffi:
|
||||
"""
|
||||
Foreign Function Interface (FFI) wrapper for external function invocation in the CUTLASS Python DSL.
|
||||
|
||||
This class enables calling external MLIR function prototypes from Python code, handling type conversion,
|
||||
prototype registration, and dynamic insertion of function symbols into MLIR modules as needed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
Name of the external function. This will be used as the symbol name when calling or registering a prototype in the MLIR module.
|
||||
params_types : list, optional
|
||||
List of argument types for the external function. These can be CUTLASS numeric types, numeric meta types, or types convertible via `get_mlir_types`.
|
||||
return_type : optional
|
||||
The return type of the external function. If not specified, the function is assumed to have no return value.
|
||||
|
||||
Methods
|
||||
-------
|
||||
__call__(*args)
|
||||
Calls the external function with the given arguments, ensuring argument and result types match the prototype.
|
||||
"""
|
||||
|
||||
def __init__(self, *, name: str, params_types: list = [], return_type=None):
|
||||
self.name = name
|
||||
self.params_types = params_types
|
||||
self.return_type = [return_type] if return_type else []
|
||||
|
||||
def _get_prototype_region(self, current_op):
|
||||
"""
|
||||
Helper method to determine the appropriate MLIR module and region for inserting a function prototype.
|
||||
|
||||
This method recursively traverses the current operation's parent hierarchy to find the correct module
|
||||
and region where the function prototype should be inserted. It supports both builtin.module and gpu.module.
|
||||
:param current_op: The current operation to check.
|
||||
:type current_op: Operation
|
||||
|
||||
:returns:
|
||||
A tuple containing the module operation and the insertion region.
|
||||
:rtype: tuple
|
||||
"""
|
||||
if current_op is None:
|
||||
raise DSLRuntimeError("current operation is unknown")
|
||||
op_name = current_op.name
|
||||
if op_name in ["builtin.module", "gpu.module"]:
|
||||
return current_op, current_op.regions[0].blocks[0]
|
||||
else:
|
||||
return self._get_prototype_region(current_op.parent)
|
||||
|
||||
@staticmethod
|
||||
def _to_mlir_types(args):
|
||||
"""
|
||||
Helper method to convert a list of arguments to their corresponding MLIR types.
|
||||
|
||||
This method converts CUTLASS numeric types, numeric meta types, and types convertible via `get_mlir_types`
|
||||
to their corresponding MLIR types.
|
||||
:param args: The list of arguments to convert to MLIR types.
|
||||
:type args: list
|
||||
|
||||
:returns:
|
||||
A list of MLIR types.
|
||||
:rtype: list
|
||||
"""
|
||||
types = []
|
||||
for param in args:
|
||||
if isinstance(param, NumericMeta):
|
||||
types.append(param.mlir_type)
|
||||
elif isinstance(param, Numeric):
|
||||
types.append(param.mlir_type)
|
||||
else:
|
||||
types.extend(get_mlir_types(param))
|
||||
return types
|
||||
|
||||
@staticmethod
|
||||
def _type_check(callee, exec_types, returns_types):
|
||||
"""
|
||||
Helper method to check if the function prototype types match the expected types.
|
||||
|
||||
This method compares the input and output types of the function prototype with the provided expected types.
|
||||
:param callee: The function prototype operation to check.
|
||||
:type callee: func.FuncOp
|
||||
:param exec_types: The expected input types.
|
||||
:type exec_types: list
|
||||
:param returns_types: The expected output types.
|
||||
:type returns_types: list
|
||||
"""
|
||||
if callee.type.inputs != exec_types or callee.type.results != returns_types:
|
||||
raise DSLRuntimeError(
|
||||
f"External prototype types mismatch, trying to call with ({exec_types}) -> ({returns_types}), got {callee.type}"
|
||||
)
|
||||
|
||||
def _create_prototype_in_region(self, op, region, exec_args):
|
||||
"""
|
||||
Helper method to create or retrieve a function prototype in the current module.
|
||||
|
||||
This method checks if a function prototype with the given name already exists in the symbol table of the current module.
|
||||
If it does, it checks if the prototype's types match the expected types. If it does not, it raises an error.
|
||||
If it does not exist, it creates a new function prototype and inserts it into the current region.
|
||||
:param op: The module operation to check.
|
||||
:type op: Operation
|
||||
:param region: The region to insert the function prototype into.
|
||||
:type region: Region
|
||||
:param exec_args: The arguments to pass to the function prototype.
|
||||
:type exec_args: list
|
||||
"""
|
||||
symbol_table = ir.SymbolTable(op.operation)
|
||||
|
||||
if self.name in symbol_table:
|
||||
callee = symbol_table[self.name]
|
||||
else:
|
||||
with ir.InsertionPoint(region):
|
||||
callee = func.FuncOp(
|
||||
self.name,
|
||||
(
|
||||
ffi._to_mlir_types(self.params_types),
|
||||
ffi._to_mlir_types(self.return_type),
|
||||
),
|
||||
)
|
||||
callee.sym_visibility = ir.StringAttr.get("private")
|
||||
|
||||
# Sanity check the function prototype types match the expected types
|
||||
self._type_check(
|
||||
callee,
|
||||
ffi._to_mlir_types(exec_args),
|
||||
ffi._to_mlir_types(self.return_type),
|
||||
)
|
||||
|
||||
return callee
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
Calls the FFI function prototype with the provided arguments.
|
||||
|
||||
This method ensures that an IR-level function prototype (external declaration)
|
||||
with the given name and type signature exists in the current module. If it does not
|
||||
exist, it will be created and inserted into the module. A call operation to this
|
||||
function is then emitted using the arguments supplied by the caller.
|
||||
|
||||
:param args:
|
||||
The runtime arguments to pass to the FFI function. These will be converted to
|
||||
their corresponding numeric types and lowered to MLIR values before being used as arguments.
|
||||
:type args: tuple
|
||||
|
||||
:returns:
|
||||
The MLIR call operation created for this invocation.
|
||||
:rtype: func.CallOp
|
||||
|
||||
:raises DSLRuntimeError:
|
||||
If there is no active MLIR insertion point or if the current operation
|
||||
context cannot be determined.
|
||||
"""
|
||||
|
||||
if kwargs:
|
||||
raise DSLRuntimeError(
|
||||
"Keyword arguments are not supported for FFI calls",
|
||||
suggestion="Use positional arguments only",
|
||||
)
|
||||
|
||||
# Get the current insertion point and operation
|
||||
try:
|
||||
current_ip = ir.InsertionPoint.current
|
||||
except Exception:
|
||||
raise DSLRuntimeError(
|
||||
"Failed to determine current insertion point",
|
||||
suggestion="Make sure this is called under a jit context",
|
||||
)
|
||||
current_op = current_ip.block.owner
|
||||
module_op, insertion_region = self._get_prototype_region(current_op)
|
||||
|
||||
# Extract the arguments to MLIR values
|
||||
exec_args = []
|
||||
for arg in args:
|
||||
exec_arg = extract_mlir_values(arg)
|
||||
if not exec_arg:
|
||||
exec_arg = [as_numeric(arg).ir_value()]
|
||||
exec_args.extend(exec_arg)
|
||||
|
||||
# Create the function prototype in module, so if it's under kernel function, prototype will be inserted into gpu.module
|
||||
# If it's under gpu.module, prototype will be inserted into builtin.module
|
||||
callee = self._create_prototype_in_region(
|
||||
module_op, insertion_region, exec_args
|
||||
)
|
||||
|
||||
# Emit the call operation
|
||||
result = func.call(callee.type.results, self.name, exec_args)
|
||||
|
||||
if self.return_type:
|
||||
return result
|
||||
@@ -333,7 +333,7 @@ class MmaF16BF16Trait(MmaTraits):
|
||||
@dataclass(frozen=True)
|
||||
class MmaF8Op(MmaOp):
|
||||
"""
|
||||
FP8 warpgroup MMA Operation.
|
||||
F8 warpgroup MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-mma>`__.
|
||||
This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands.
|
||||
|
||||
@@ -17,6 +17,7 @@ import itertools
|
||||
import operator
|
||||
from typing import Union, Optional, Type, List
|
||||
|
||||
|
||||
# MLIR modules imports
|
||||
from cutlass._mlir import ir
|
||||
from cutlass.base_dsl.env_manager import get_prefix_dsl_libs
|
||||
@@ -128,6 +129,10 @@ class _Pointer(Pointer):
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
@property
|
||||
def __cache_key__(self) -> tuple:
|
||||
return (self.dtype, self._addr_space, self._assumed_align)
|
||||
|
||||
|
||||
class _Tensor(Tensor):
|
||||
def __init__(
|
||||
@@ -144,7 +149,7 @@ class _Tensor(Tensor):
|
||||
elif enable_tvm_ffi:
|
||||
import tvm_ffi
|
||||
|
||||
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor, stream=-1)
|
||||
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor)
|
||||
self._dlpack_data = self._tvm_ffi_tensor.__dlpack__()
|
||||
else:
|
||||
try:
|
||||
@@ -185,9 +190,17 @@ class _Tensor(Tensor):
|
||||
:param leading_dim: The leading dimension of the layout, defaults to None
|
||||
:type leading_dim: int, optional
|
||||
|
||||
When ``leading_dim`` is None, automatically deduces the leading dimension from the tensor layout.
|
||||
The layout can be deduced only when exactly one dimension has a stride of 1. Raises an error
|
||||
if the layout cannot be automatically deduced.
|
||||
When ``leading_dim`` is None, the leading dimension is deduced as follows.
|
||||
|
||||
(1) If exactly one dimension has stride 1, that dimension is used.
|
||||
|
||||
(2) If multiple dimensions have stride 1 but exactly one of them has size > 1,
|
||||
that dimension is used.
|
||||
|
||||
(3) If multiple dimensions have stride 1 but none or more than one has size > 1,
|
||||
an error is raised.
|
||||
|
||||
(4) If no dimension has stride 1, all strides remain dynamic.
|
||||
|
||||
When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the
|
||||
stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent
|
||||
@@ -304,6 +317,13 @@ class _Tensor(Tensor):
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
@property
|
||||
def __cache_key__(self) -> tuple:
|
||||
self.load_dltensor()
|
||||
if self._dtype is None:
|
||||
self._dtype = self._dltensor_wrapper.dtype
|
||||
return (self._dtype, self._assumed_align, self._dltensor_wrapper.cache_key())
|
||||
|
||||
def __setitem__(self, crd, value):
|
||||
raise TypeError("runtime._Tensor is not indexable")
|
||||
|
||||
@@ -417,42 +437,76 @@ def _get_cute_type_str(inp):
|
||||
return "(" + ",".join(elems) + ")"
|
||||
|
||||
|
||||
class _FakeCompactTensor(Tensor):
|
||||
class _FakeTensor(Tensor):
|
||||
"""Fake Tensor implementation as a placeholder.
|
||||
It mimics the interface of Tensor, but does not hold real data or allow indexing.
|
||||
Used for compilation or testing situations where only shape/type/layout information is needed.
|
||||
All attempts to access or mutate data will raise errors.
|
||||
"""
|
||||
|
||||
"""
|
||||
Create a fake tensor with the given shape, type, and layout.
|
||||
|
||||
:param dtype: Data type of the tensor elements
|
||||
:type dtype: Type[Numeric]
|
||||
:param shape: Shape of the tensor, consists of int (static) or SymInt (dynamic)
|
||||
:type shape: tuple[Union[int, SymInt], ...]
|
||||
:param stride: Stride of the tensor, defaults to None, consists of int (static) or SymInt (dynamic)
|
||||
:type stride: tuple[Union[int, SymInt], ...], optional
|
||||
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
|
||||
:type memspace: AddressSpace, optional
|
||||
:param assumed_align: Assumed alignment of the tensor (bytes), defaults to None. If None, uses the element size bytes as the assumed alignment.
|
||||
:type assumed_align: int, optional
|
||||
:param use_32bit_stride: Whether to use 32-bit stride. Defaults to False. When True, the dynamic stride bitwidth
|
||||
will be set to 32 for small problem sizes (cosize(layout) <= Int32_max) for better performance. This is only applied
|
||||
when the dimension is dynamic.
|
||||
:type use_32bit_stride: bool, optional
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dtype,
|
||||
shape,
|
||||
stride_order,
|
||||
memspace=None,
|
||||
assumed_align=None,
|
||||
use_32bit_stride=False,
|
||||
dtype: Type[Numeric],
|
||||
shape: tuple[Union[int, SymInt], ...],
|
||||
*,
|
||||
stride: tuple[Union[int, SymInt], ...],
|
||||
memspace: AddressSpace = AddressSpace.gmem,
|
||||
assumed_align: int | None = None,
|
||||
use_32bit_stride: bool = False,
|
||||
compact: bool = False,
|
||||
):
|
||||
self._dtype = dtype
|
||||
self._shape = shape
|
||||
self._stride_order = stride_order or tuple(range(len(shape)))
|
||||
# cannot use memspace or AddressSpace.gmem because AddressSpace.generic is 0
|
||||
self._memspace = memspace if memspace is not None else AddressSpace.gmem
|
||||
self._assumed_align = assumed_align or -(-dtype.width // 8)
|
||||
self._stride = stride
|
||||
self._use_32bit_stride = use_32bit_stride
|
||||
self._compact = compact
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FakeTensorOrdered<{self._dtype}, {self._shape}, {self._stride_order}>"
|
||||
if not isinstance(shape, (tuple, list)):
|
||||
raise ValueError(f"Expected tuple or list but got {type(shape)}")
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
if not all(isinstance(s, (int, SymInt)) for s in self._shape):
|
||||
raise ValueError("All shape elements must be int or SymInt")
|
||||
|
||||
if stride is not None and not all(
|
||||
isinstance(s, (int, SymInt)) for s in self._stride
|
||||
):
|
||||
raise ValueError("All stride elements must be int or SymInt")
|
||||
self._memspace = memspace
|
||||
self._assumed_align = assumed_align
|
||||
if assumed_align is None:
|
||||
# use the bytes width of the element dtype. The alignment is at least one byte align.
|
||||
self._assumed_align = (self._dtype.width + 7) // 8
|
||||
|
||||
@property
|
||||
def mlir_type(self) -> ir.Type:
|
||||
shape_ty = ir.Type.parse(
|
||||
'!cute.shape<"' + _get_cute_type_str(self._shape) + '">'
|
||||
)
|
||||
layout_ty = _cute_ir.LayoutType.get_ordered(
|
||||
shape_ty, self._stride_order, self._use_32bit_stride
|
||||
)
|
||||
self._stride = layout_ty.stride
|
||||
ptr_ty = _cute_ir.PtrType.get(
|
||||
self._dtype.mlir_type, self._memspace, self._assumed_align
|
||||
)
|
||||
shape_str = _get_cute_type_str(self._shape)
|
||||
stride_str = _get_cute_type_str(self._stride)
|
||||
layout_ty = ir.Type.parse(f'!cute.layout<"{shape_str}:{stride_str}">')
|
||||
|
||||
# Boolean types are stored as i8 in memory
|
||||
elem_type = T.i8() if self._dtype.width == 1 else self._dtype.mlir_type
|
||||
ptr_ty = _cute_ir.PtrType.get(elem_type, self._memspace, self._assumed_align)
|
||||
return _cute_ir.MemRefType.get(ptr_ty, layout_ty)
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
@@ -463,11 +517,53 @@ class _FakeCompactTensor(Tensor):
|
||||
assert isinstance(values[0], CoreTensor)
|
||||
return CoreTensor(values[0].value, self._dtype)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FakeTensor<{self._dtype}, {self._shape}, {self._stride}>"
|
||||
|
||||
@property
|
||||
def __cache_key__(self) -> tuple:
|
||||
# Check if any shape or stride element is a SymInt without a symbol
|
||||
import warnings
|
||||
|
||||
has_unnamed_symint = False
|
||||
for dim in self._shape:
|
||||
if isinstance(dim, SymInt) and dim.symbol is None:
|
||||
has_unnamed_symint = True
|
||||
break
|
||||
if not self._compact:
|
||||
if not has_unnamed_symint:
|
||||
for stride in self._stride:
|
||||
if isinstance(stride, SymInt) and stride.symbol is None:
|
||||
has_unnamed_symint = True
|
||||
break
|
||||
|
||||
if has_unnamed_symint:
|
||||
warnings.warn(
|
||||
"FakeTensor cache_key contains unnamed symbolic dimensions. "
|
||||
"Different variables with the same shape/stride pattern will have "
|
||||
"identical cache keys, which may cause incorrect cache hits. "
|
||||
"Consider using 'symbol' parameter to distinguish variables: "
|
||||
"cute.sym_int32(symbol='M'), cute.sym_int32(symbol='N')",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return (
|
||||
self._dtype,
|
||||
self._memspace,
|
||||
self._assumed_align,
|
||||
self._shape,
|
||||
self._stride,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def __setitem__(self, crd, value):
|
||||
raise DSLRuntimeError("runtime._FakeCompactTensor is not indexable")
|
||||
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
|
||||
|
||||
def __getitem__(self, crd):
|
||||
raise DSLRuntimeError("runtime._FakeCompactTensor is not indexable")
|
||||
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
|
||||
|
||||
@property
|
||||
def element_type(self) -> Type[Numeric]:
|
||||
@@ -491,118 +587,7 @@ class _FakeCompactTensor(Tensor):
|
||||
|
||||
@property
|
||||
def leading_dim(self):
|
||||
for dim, order in enumerate(self._stride_order):
|
||||
if order == 0:
|
||||
return dim
|
||||
|
||||
@property
|
||||
def dynamic_shapes_mask(self):
|
||||
return tuple(1 if isinstance(e, SymInt) else 0 for e in self._shape)
|
||||
|
||||
@property
|
||||
def dynamic_strides_mask(self):
|
||||
return tuple(1 if isinstance(e, SymInt) else 0 for e in self._stride)
|
||||
|
||||
def fill(self, value: Numeric):
|
||||
raise DSLRuntimeError("runtime._FakeCompactTensor is not writable")
|
||||
|
||||
|
||||
class _FakeTensor(Tensor):
|
||||
"""Fake Tensor implementation as a placeholder.
|
||||
It mimics the interface of Tensor, but does not hold real data or allow indexing.
|
||||
Used for compilation or testing situations where only shape/type/layout information is needed.
|
||||
All attempts to access or mutate data will raise errors.
|
||||
"""
|
||||
|
||||
"""
|
||||
Create a fake tensor with the given shape, type, and layout.
|
||||
|
||||
:param dtype: Data type of the tensor elements
|
||||
:type dtype: Type[Numeric]
|
||||
:param shape: Shape of the tensor, consists of int (static) or SymInt (dynamic)
|
||||
:type shape: tuple[int, ...]
|
||||
:param stride: Stride of the tensor, defaults to None, consists of int (static) or SymInt (dynamic)
|
||||
:type stride: tuple[int, ...], optional
|
||||
:param assumed_align: Assumed alignment of the tensor (bytes), defaults to None. If None, uses the element size bytes as the assumed alignment.
|
||||
:type assumed_align: int, optional
|
||||
:param use_32bit_stride: Whether to use 32-bit stride. Defaults to False. When True, the dynamic stride bitwidth
|
||||
will be set to 32 for small problem sizes (cosize(layout) <= Int32_max) for better performance. This is only applied
|
||||
when the dimension is dynamic.
|
||||
:type use_32bit_stride: bool, optional
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dtype, shape, *, stride, memspace=None, assumed_align=None):
|
||||
self._dtype = dtype
|
||||
self._shape = shape
|
||||
self._stride = stride
|
||||
# cannot use memspace or AddressSpace.generic because AddressSpace.generic is 0
|
||||
self._memspace = memspace if memspace is not None else AddressSpace.gmem
|
||||
self._assumed_align = assumed_align
|
||||
if assumed_align is None:
|
||||
# use the bytes width of the element dtype. The alignment is at least one byte align.
|
||||
self._assumed_align = (self._dtype.width + 7) // 8
|
||||
|
||||
if not isinstance(shape, (tuple, list)):
|
||||
raise ValueError(f"Expected tuple or list but got {type(shape)}")
|
||||
|
||||
if not all(isinstance(s, (int, SymInt)) for s in self._shape):
|
||||
raise ValueError("All shape elements must be int or SymInt")
|
||||
|
||||
if stride is not None and not all(
|
||||
isinstance(s, (int, SymInt)) for s in self._stride
|
||||
):
|
||||
raise ValueError("All stride elements must be int or SymInt")
|
||||
@property
|
||||
def mlir_type(self) -> ir.Type:
|
||||
shape_str = _get_cute_type_str(self._shape)
|
||||
stride_str = _get_cute_type_str(self._stride)
|
||||
layout_ty = ir.Type.parse(f'!cute.layout<"{shape_str}:{stride_str}">')
|
||||
ptr_ty = _cute_ir.PtrType.get(
|
||||
self._dtype.mlir_type, self._memspace, self._assumed_align
|
||||
)
|
||||
return _cute_ir.MemRefType.get(ptr_ty, layout_ty)
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
return [self.mlir_type]
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
assert len(values) == 1
|
||||
assert isinstance(values[0], CoreTensor)
|
||||
return CoreTensor(values[0].value, self._dtype)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FakeTensor<{self._dtype}, {self._shape}, {self._stride}>"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def __setitem__(self, crd, value):
|
||||
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
|
||||
|
||||
def __getitem__(self, crd):
|
||||
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
|
||||
|
||||
@property
|
||||
def element_type(self) -> Type[Numeric]:
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def memspace(self):
|
||||
return self._memspace
|
||||
|
||||
@property
|
||||
def iterator(self):
|
||||
raise DSLRuntimeError("runtime._FakeTensor has dummy iterator")
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._shape
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
return self._stride
|
||||
return core.leading_dim(self._shape, self._stride)
|
||||
|
||||
@property
|
||||
def dynamic_shapes_mask(self):
|
||||
@@ -617,35 +602,36 @@ class _FakeTensor(Tensor):
|
||||
|
||||
|
||||
def make_fake_compact_tensor(
|
||||
dtype,
|
||||
shape,
|
||||
dtype: Type[Numeric],
|
||||
shape: tuple[Union[int, SymInt], ...],
|
||||
*,
|
||||
stride_order=None,
|
||||
memspace=None,
|
||||
assumed_align=None,
|
||||
use_32bit_stride=False,
|
||||
stride_order: Optional[tuple[int, ...]] = None,
|
||||
memspace: AddressSpace = AddressSpace.gmem,
|
||||
assumed_align: Optional[int] = None,
|
||||
use_32bit_stride: bool = False,
|
||||
):
|
||||
"""
|
||||
Create a fake tensor with the specified shape, element type, and a compact memory layout.
|
||||
|
||||
:param dtype: Data type of the tensor elements.
|
||||
:type dtype: Type[Numeric]
|
||||
:param shape: Shape of the tensor.
|
||||
:type shape: tuple[int, ...]
|
||||
:param shape: Shape of the tensor, consisting of static (int) or dynamic (SymInt) dimensions.
|
||||
:type shape: tuple[Union[int, SymInt], ...]
|
||||
:param stride_order: Order in which strides (memory layout) are assigned to the tensor dimensions.
|
||||
If None, the default layout is left-to-right order (known as column-major order for flatten layout).
|
||||
Otherwise, it should be a permutation order of the dimension indices.
|
||||
The mode with stride_order 0 is the fastest changing (leading) dimension, and N-1 is the slowest changing.
|
||||
:type stride_order: tuple[int, ...], optional
|
||||
:param memspace: Memory space where the fake tensor resides. Optional.
|
||||
:type memspace: str, optional
|
||||
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is used.
|
||||
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
|
||||
:type memspace: AddressSpace, optional
|
||||
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is the dtype width, & at least 1 byte.
|
||||
:type assumed_align: int, optional
|
||||
:param use_32bit_stride: Whether to use 32-bit stride for dynamic dimensions. If True and the total size of the
|
||||
layout (cosize(layout)) fits within int32, then dynamic strides will use 32-bit integers for improved performance.
|
||||
Only applies when dimensions are dynamic. Defaults to False.
|
||||
:type use_32bit_stride: bool, optional
|
||||
:return: An instance of a fake tensor with the given properties and compact layout.
|
||||
:rtype: _FakeCompactTensor
|
||||
:rtype: _FakeTensor
|
||||
|
||||
**Examples:**
|
||||
|
||||
@@ -663,31 +649,68 @@ def make_fake_compact_tensor(
|
||||
# tensor<ptr<f32, generic> o (100,?{div=8}):(?{i32 div=8},1)>
|
||||
compiled_foo = cute.compile(foo, x)
|
||||
|
||||
# Default stride order is left-to-right order: (1, 8)
|
||||
y = make_fake_compact_tensor(cutlass.Float32, (8, 3))
|
||||
# Default stride order is left-to-right order (0, 1, ..., n-1)
|
||||
y = make_fake_compact_tensor(cutlass.Float32, (8, 3, 2)) # y.stride == (1, 8, 24)
|
||||
"""
|
||||
|
||||
return _FakeCompactTensor(
|
||||
if stride_order is not None:
|
||||
if len(stride_order) != len(shape):
|
||||
raise ValueError(
|
||||
f"stride_order ({stride_order}) must be empty or have same length as shape ({shape})."
|
||||
)
|
||||
else:
|
||||
# Default stride order is left-to-right
|
||||
stride_order = stride_order or tuple(range(len(shape)))
|
||||
|
||||
# Make compact strides (possibly symbolic) from shape & stride_order
|
||||
stride = [None] * len(stride_order)
|
||||
stride_product = 1
|
||||
for order in range(len(stride_order)):
|
||||
idx = stride_order.index(order)
|
||||
stride[idx] = stride_product
|
||||
stride_product *= shape[idx]
|
||||
|
||||
stride_width = 32 if use_32bit_stride else 64
|
||||
stride = tuple(
|
||||
(
|
||||
SymInt(width=stride_width, divisibility=s.divisibility)
|
||||
if isinstance(s, SymInt)
|
||||
else s
|
||||
)
|
||||
for s in stride
|
||||
)
|
||||
|
||||
return _FakeTensor(
|
||||
dtype,
|
||||
shape,
|
||||
stride_order=stride_order,
|
||||
stride=stride,
|
||||
memspace=memspace,
|
||||
assumed_align=assumed_align,
|
||||
use_32bit_stride=use_32bit_stride,
|
||||
compact=True,
|
||||
)
|
||||
|
||||
|
||||
def make_fake_tensor(dtype, shape, stride, *, memspace=None, assumed_align=None):
|
||||
def make_fake_tensor(
|
||||
dtype: Type[Numeric],
|
||||
shape: tuple[Union[int, SymInt], ...],
|
||||
stride: tuple[Union[int, SymInt], ...],
|
||||
*,
|
||||
memspace: AddressSpace = AddressSpace.gmem,
|
||||
assumed_align: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Create a fake tensor with the specified element type, shape, and stride.
|
||||
|
||||
:param dtype: Data type of the tensor elements.
|
||||
:type dtype: Type[Numeric]
|
||||
:param shape: Shape of the tensor.
|
||||
:type shape: tuple[int, ...]
|
||||
:param stride: Stride of the tensor.
|
||||
:type stride: tuple[int, ...]
|
||||
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is used. Defaults to None.
|
||||
:param shape: Shape of the tensor, consisting of static (int) or dynamic (SymInt) dimensions.
|
||||
:type shape: tuple[Union[int, SymInt], ...]
|
||||
:param stride: Stride of the tensor, consisting of static (int) or dynamic (SymInt) values.
|
||||
:type stride: tuple[Union[int, SymInt], ...]
|
||||
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
|
||||
:type memspace: AddressSpace, optional
|
||||
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is the dtype width, & at least 1 byte.
|
||||
:type assumed_align: int, optional
|
||||
:return: An instance of a fake tensor with the given properties.
|
||||
:rtype: _FakeTensor
|
||||
@@ -953,22 +976,7 @@ def load_module(file_path: str, *, enable_tvm_ffi: bool = False):
|
||||
if Path(path).exists():
|
||||
_LOAD_MODULE_LIBS_CACHE.append(ctypes.CDLL(path))
|
||||
|
||||
if enable_tvm_ffi:
|
||||
import tvm_ffi
|
||||
|
||||
try:
|
||||
# keep_module_alive=False means the module will be unloaded
|
||||
# after the returned module goes out of scope, this is useful
|
||||
# for frequent loading and unloading of modules. The only requirement
|
||||
# is that the module do not return object that have deleter in the module
|
||||
# and the returned object lives longer than the module.
|
||||
# DSL functions to not have such issue so it is desirable to set this to False.
|
||||
return tvm_ffi.load_module(file_path, keep_module_alive=False)
|
||||
except TypeError:
|
||||
# compatible with tvm-ffi < 0.1.6
|
||||
return tvm_ffi.load_module(file_path)
|
||||
else:
|
||||
return ExternalBinaryModule(file_path)
|
||||
return ExternalBinaryModule(file_path, enable_tvm_ffi=enable_tvm_ffi)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Try to register_jit_arg_adapter for TensorAdapter
|
||||
|
||||
@@ -139,6 +139,19 @@ class _Tensor(Tensor):
|
||||
def __init__(
|
||||
self, value, dtype: Optional[Type[Numeric]] = None, *, loc=None, ip=None
|
||||
):
|
||||
"""Initialize a Tensor from an MLIR value.
|
||||
|
||||
:param value: The MLIR operation result value or another Tensor to initialize from
|
||||
:type value: Union[ir.Value, _Tensor]
|
||||
:param dtype: The user specified data type of the tensor elements, defaults to None
|
||||
:type dtype: Optional[Type[Numeric]]
|
||||
:param loc: The source location for the operation, defaults to None
|
||||
:type loc: Optional[Location]
|
||||
:param ip: The insertion point for the operation, defaults to None
|
||||
:type ip: Optional[InsertionPoint]
|
||||
:raises TypeError: If value is not ir.Value or _Tensor
|
||||
:raises TypeError: If iterator type is not supported
|
||||
"""
|
||||
self._dtype = dtype
|
||||
if isinstance(value, ir.Value):
|
||||
self.value = value
|
||||
@@ -952,6 +965,37 @@ def make_fragment_like(src, dtype=None, *, loc=None, ip=None):
|
||||
def recast_tensor(
|
||||
src: Tensor, dtype: Type[Numeric], swizzle_=None, *, loc=None, ip=None
|
||||
):
|
||||
"""Recast a tensor to a different data type by changing the element interpretation.
|
||||
|
||||
This function reinterprets the memory of a tensor with a different element type,
|
||||
adjusting both the iterator pointer type and the layout to maintain consistency.
|
||||
|
||||
:param src: The source tensor to recast
|
||||
:type src: Tensor
|
||||
:param dtype: The target data type for tensor elements
|
||||
:type dtype: Type[Numeric]
|
||||
:param swizzle_: Optional swizzle parameter (reserved for future use), defaults to None
|
||||
:type swizzle_: Optional, unused
|
||||
:param loc: Source location for MLIR operation tracking, defaults to None
|
||||
:type loc: Optional[Location]
|
||||
:param ip: Insertion point for MLIR operation, defaults to None
|
||||
:type ip: Optional[InsertionPoint]
|
||||
:return: A new tensor with the same memory but reinterpreted as dtype
|
||||
:rtype: Tensor
|
||||
:raises TypeError: If dtype is not a subclass of Numeric
|
||||
|
||||
**Examples:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Create a Float32 tensor
|
||||
tensor_f32 = make_rmem_tensor((4, 8), Float32)
|
||||
|
||||
# Recast to Int32 to manipulate bits
|
||||
tensor_i32 = recast_tensor(tensor_f32, Int32)
|
||||
|
||||
# Both tensors share the same memory, but interpret it differently
|
||||
"""
|
||||
if not isclass(dtype) or not issubclass(dtype, Numeric):
|
||||
raise TypeError(f"dtype must be a type of Numeric, but got {dtype}")
|
||||
|
||||
@@ -972,6 +1016,36 @@ def recast_tensor(
|
||||
|
||||
@dsl_user_op
|
||||
def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor:
|
||||
"""Offset the tensor domain by the given coordinate.
|
||||
|
||||
This function creates a new tensor by offsetting the iterator/pointer of the input tensor
|
||||
by the amount corresponding to the given coordinate in its layout.
|
||||
|
||||
:param coord: The coordinate offset to apply
|
||||
:type coord: Coord
|
||||
:param tensor: The source tensor to offset
|
||||
:type tensor: Tensor
|
||||
:param loc: Source location for MLIR operation tracking, defaults to None
|
||||
:type loc: Optional[Location]
|
||||
:param ip: Insertion point for MLIR operation, defaults to None
|
||||
:type ip: Optional[InsertionPoint]
|
||||
:return: A new tensor with the offset iterator
|
||||
:rtype: Tensor
|
||||
:raises ValueError: If the tensor type doesn't support domain offsetting
|
||||
|
||||
**Examples:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Create a tensor with a row-major layout
|
||||
ptr = make_ptr(Float32, base_ptr, AddressSpace.gmem)
|
||||
layout = make_layout((64, 128), stride=(128, 1))
|
||||
tensor = make_tensor(ptr, layout)
|
||||
|
||||
# Offset by coordinate (3, 5)
|
||||
offset_tensor = domain_offset((3, 5), tensor)
|
||||
# offset_tensor now points to element at (3, 5)
|
||||
"""
|
||||
offset = crd2idx(coord, tensor.layout, loc=loc, ip=ip)
|
||||
if isinstance(tensor.iterator, Pointer):
|
||||
return make_tensor(
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import ctypes
|
||||
from typing import ForwardRef, Tuple, Union, Any, Type, List, Optional, Literal
|
||||
@@ -24,15 +26,18 @@ Int = Union[int, Integer]
|
||||
|
||||
|
||||
class SymInt:
|
||||
def __init__(self, width: Literal[32, 64] = 32, *, divisibility=1):
|
||||
def __init__(
|
||||
self, width: Literal[32, 64] = 32, *, divisibility=1, symbol: str | None = None
|
||||
):
|
||||
if width not in [32, 64]:
|
||||
raise ValueError(f"Unsupported width: {width}")
|
||||
|
||||
self._width = width
|
||||
self._divisibility = divisibility
|
||||
self._symbol = symbol
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self._width, self._divisibility))
|
||||
return hash((self._width, self._divisibility, self._symbol))
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
@@ -42,8 +47,16 @@ class SymInt:
|
||||
def divisibility(self):
|
||||
return self._divisibility
|
||||
|
||||
@property
|
||||
def symbol(self):
|
||||
return self._symbol
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"?{{i{self._width} div={self._divisibility}}}"
|
||||
prefix = "" if self._symbol is None else self._symbol + " "
|
||||
if self._width == 32:
|
||||
return f"{prefix}?{{div={self._divisibility}}}"
|
||||
else:
|
||||
return f"{prefix}?{{i{self._width} div={self._divisibility}}}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
@@ -51,18 +64,52 @@ class SymInt:
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, SymInt):
|
||||
return False
|
||||
|
||||
return all(
|
||||
[self._width == other._width, self._divisibility == other._divisibility]
|
||||
[
|
||||
self._width == other._width,
|
||||
self._divisibility == other._divisibility,
|
||||
self._symbol == other._symbol,
|
||||
]
|
||||
)
|
||||
|
||||
def __mod__(self, other: int) -> Union["SymInt", int]:
|
||||
if self._divisibility % other != 0:
|
||||
def __mod__(self, other: int | SymInt) -> SymInt | int:
|
||||
if isinstance(other, int):
|
||||
other_div, result_width = other, self._width
|
||||
elif isinstance(other, SymInt):
|
||||
other_div, result_width = (
|
||||
other._divisibility,
|
||||
max(self._width, other._width),
|
||||
)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
if self._divisibility % other_div == 0:
|
||||
return 0
|
||||
else:
|
||||
from math import gcd
|
||||
|
||||
div = gcd(self._divisibility, other)
|
||||
return SymInt(self._width, divisibility=div)
|
||||
return SymInt(result_width, divisibility=gcd(self._divisibility, other_div))
|
||||
|
||||
def __rmod__(self, other: int) -> int:
|
||||
"""int % SymInt: check if the int conforms to this SymInt's divisibility"""
|
||||
if isinstance(other, int):
|
||||
return other % self._divisibility
|
||||
return NotImplemented
|
||||
|
||||
def __mul__(self, other: int | SymInt) -> SymInt:
|
||||
if isinstance(other, int):
|
||||
return SymInt(self._width, divisibility=self._divisibility * other)
|
||||
elif isinstance(other, SymInt):
|
||||
return SymInt(
|
||||
width=max(self._width, other._width),
|
||||
divisibility=self._divisibility * other._divisibility,
|
||||
)
|
||||
else:
|
||||
return 0
|
||||
return NotImplemented
|
||||
|
||||
def __rmul__(self, other: int | SymInt) -> SymInt:
|
||||
return self.__mul__(other)
|
||||
|
||||
def __c_pointers__(self):
|
||||
return [ctypes.c_void_p(0).value]
|
||||
@@ -73,7 +120,7 @@ class SymInt:
|
||||
)
|
||||
return [res_ty]
|
||||
|
||||
def __new_from_mlir_values__(self, values) -> "SymInt":
|
||||
def __new_from_mlir_values__(self, values) -> SymInt:
|
||||
from .core import IntValue
|
||||
|
||||
if self.width == 32:
|
||||
@@ -84,16 +131,18 @@ class SymInt:
|
||||
assert False, f"Unsupported width: {self.width}"
|
||||
return self
|
||||
|
||||
def sym_int(width: Literal[32, 64] = 32, *, divisibility=1) -> SymInt:
|
||||
return SymInt(width, divisibility=divisibility)
|
||||
def sym_int(
|
||||
width: Literal[32, 64] = 32, *, divisibility=1, symbol: str | None = None
|
||||
) -> SymInt:
|
||||
return SymInt(width, divisibility=divisibility, symbol=symbol)
|
||||
|
||||
|
||||
def sym_int32(divisibility=1) -> SymInt:
|
||||
return sym_int(32, divisibility=divisibility)
|
||||
def sym_int32(divisibility=1, symbol: str | None = None) -> SymInt:
|
||||
return sym_int(32, divisibility=divisibility, symbol=symbol)
|
||||
|
||||
|
||||
def sym_int64(divisibility=1) -> SymInt:
|
||||
return sym_int(64, divisibility=divisibility)
|
||||
def sym_int64(divisibility=1, symbol: str | None = None) -> SymInt:
|
||||
return sym_int(64, divisibility=divisibility, symbol=symbol)
|
||||
|
||||
|
||||
ScaledBasis = ForwardRef("ScaledBasis")
|
||||
|
||||
@@ -1199,7 +1199,6 @@ class KernelLauncher:
|
||||
return self.dsl._get_smem_usage()
|
||||
|
||||
def launch(self, *args, **kwargs):
|
||||
self.dsl.frame = inspect.currentframe().f_back
|
||||
self.dsl._preprocess_launch_config_args(args, kwargs)
|
||||
config = self.dsl.LaunchConfig(*args, **kwargs)
|
||||
kernel_attrs = _build_kernel_attrs(config)
|
||||
@@ -1216,7 +1215,6 @@ class KernelLauncher:
|
||||
|
||||
ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config)
|
||||
self.dsl.kernel_info[name] = kernel_attrs
|
||||
self.dsl.frame = None
|
||||
return ret.launch_op_ret
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
||||
@@ -35,7 +35,6 @@ if is_available():
|
||||
)
|
||||
from .compile import (
|
||||
release_compile_cache,
|
||||
initialize_cutlass_dsl,
|
||||
)
|
||||
from .ffi import (
|
||||
get_export_disabled_safety_checks,
|
||||
@@ -48,10 +47,6 @@ if is_available():
|
||||
# This is a legacy name for TensorSpec. It will be removed eventually.
|
||||
TensorMode = TensorSpec
|
||||
|
||||
# This explicit init method ensures that we avoid initialization at
|
||||
# unexpected times in jax tracing.
|
||||
initialize_cutlass_dsl()
|
||||
|
||||
__all__ = [
|
||||
"cutlass_call",
|
||||
"jax_to_cutlass_dtype",
|
||||
|
||||
@@ -267,36 +267,4 @@ def release_compile_cache():
|
||||
_CUTLASS_COMPILE_CACHE.clear()
|
||||
dsl = CuTeDSL._get_dsl()
|
||||
dsl.jit_cache.clear()
|
||||
# TODO: This is needed to release frames being held in the DSL
|
||||
# We should avoid holding such references as they unexpectedly
|
||||
# extend object lifetime.
|
||||
dsl.frame = None
|
||||
gc.collect()
|
||||
|
||||
|
||||
class _DummyInitKernel:
|
||||
@cute.kernel
|
||||
def kernel(self):
|
||||
pass
|
||||
|
||||
@cute.jit
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
|
||||
_CUTLASS_DSL_INITIALIZED = False
|
||||
|
||||
|
||||
def initialize_cutlass_dsl():
|
||||
"""Initializes cutlass DSL."""
|
||||
global _CUTLASS_DSL_INITIALIZED
|
||||
if _CUTLASS_DSL_INITIALIZED:
|
||||
return
|
||||
|
||||
# Call compiler to ensure we've pre-processed any kernels inside cutedsl.
|
||||
kernel = _DummyInitKernel()
|
||||
with _compile_lock:
|
||||
logger.debug("Initializing cutlass dsl...")
|
||||
_ = cutlass.cute.compile(kernel.init)
|
||||
|
||||
_CUTLASS_DSL_INITIALIZED = True
|
||||
|
||||
@@ -28,9 +28,13 @@ logger = logging.getLogger(__name__)
|
||||
_CUTE_DSL_RUNTIME_LIBRARY_NAME = "cute_dsl_runtime"
|
||||
|
||||
_CUTLASS_CALL_TARGETS = {
|
||||
"CuteDSLRT_NvJaxCutlassCall": {"execute": "CuteDSLRT_NvJaxCutlassCallExecute"},
|
||||
"CuteDSLRT_NvJaxCutlassCall": {
|
||||
"execute": "CuteDSLRT_NvJaxCutlassCallExecute",
|
||||
"prepare": "CuteDSLRT_NvJaxCutlassCallPrepare",
|
||||
},
|
||||
"CuteDSLRT_NvJaxCutlassCallNoCudaGraph": {
|
||||
"execute": "CuteDSLRT_NvJaxCutlassCallExecuteNoCudaGraph"
|
||||
"execute": "CuteDSLRT_NvJaxCutlassCallExecuteNoCudaGraph",
|
||||
"prepare": "CuteDSLRT_NvJaxCutlassCallPrepare",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements-cu13.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl[cu13]==4.4.1
|
||||
nvidia-cutlass-dsl[cu13]==4.4.2
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl==4.4.1
|
||||
nvidia-cutlass-dsl==4.4.2
|
||||
|
||||
@@ -133,7 +133,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '4.4.1'
|
||||
this.__version__ = '4.4.2'
|
||||
|
||||
from cutlass_cppgen.backend import create_memory_pool
|
||||
from cutlass_cppgen.emit.pytorch import pytorch
|
||||
|
||||
@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
|
||||
|
||||
setup(
|
||||
name='cutlass_cppgen',
|
||||
version='4.4.1',
|
||||
version='4.4.2',
|
||||
description='CUTLASS Pythonic Interface',
|
||||
package_dir={'': '.'},
|
||||
packages=[
|
||||
|
||||
@@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='cutlass_library',
|
||||
version='4.4.1',
|
||||
version='4.4.2',
|
||||
description='CUTLASS library generation scripts',
|
||||
packages=['cutlass_library']
|
||||
)
|
||||
|
||||
@@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='pycute',
|
||||
version='4.4.1',
|
||||
version='4.4.2',
|
||||
description='Python implementation of CuTe',
|
||||
packages=['pycute'],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user