v4.4.2 update. (#3104)

This commit is contained in:
Junkai-Wu
2026-03-17 12:58:19 +08:00
committed by GitHub
parent 772fbb264e
commit 1b741cabaa
31 changed files with 996 additions and 355 deletions

View File

@@ -68,7 +68,8 @@ class ExternalBinaryModule:
load_provider: LoadProvider = None
def __init__(self, file_path: str):
def __init__(self, file_path: str, enable_tvm_ffi: bool = False):
self.enable_tvm_ffi = enable_tvm_ffi
assert self.load_provider is not None, (
"Load provider is not set for ExternalBinaryModule."
)
@@ -82,13 +83,28 @@ class ExternalBinaryModule:
object_file_content = f.read()
except Exception as e:
raise DSLRuntimeError(f"Failed to read object file {file_path}: {e}")
useJitLink = not enable_tvm_ffi
# Lifetime of the engine is same as the ExternalBinaryModule.
self.engine = self.load_provider.execution_engine_constructor(
object_file_content, shared_libs
object_file_content, shared_libs, useJitLink
)
def __getattr__(self, function_prefix: str) -> "JitCompiledFunction":
"""Get the jit_function from the `function_prefix`. The `function_prefix` is specified when users dump the object file. When there is no function_prefix found in the module, the function will raise an error."""
if self.enable_tvm_ffi:
try:
import tvm_ffi
function_ptr = self.engine.lookup("__tvm_ffi_" + function_prefix)
return tvm_ffi.Function.__from_extern_c__(
function_ptr, keep_alive_object=self.engine
)
except Exception as e:
raise DSLRuntimeError(
f"Failed to load TVM FFI function {function_prefix}: {e}"
)
try:
args_spec, function_name, kernel_info, version_str = (
decode_metadata_from_execution_engine(
@@ -124,3 +140,7 @@ class ExternalBinaryModule:
load_from_binary=True,
)
return jit_function
def __getitem__(self, function_prefix: str) -> "JitCompiledFunction":
"""Get the jit_function from the `function_prefix`. The `function_prefix` is specified when users dump the object file. When there is no function_prefix found in the module, the function will raise an error."""
return self.__getattr__(function_prefix)

View File

@@ -202,6 +202,8 @@ from .math import *
# Used as internal symbol
from .. import cutlass_dsl as _dsl
from .ffi import ffi
# Aliases
jit = _dsl.CuTeDSL.jit
kernel = _dsl.CuTeDSL.kernel
@@ -312,4 +314,5 @@ __all__ = [
"kernel",
"register_jit_arg_adapter",
"compile",
"ffi",
]

View File

@@ -96,7 +96,6 @@ __all__ = [
"fma_packed_f32x2",
"mul_packed_f32x2",
"add_packed_f32x2",
"sub_packed_f32x2",
"fmax",
"rcp_approx",
"exp2",

View File

@@ -23,7 +23,6 @@ from ..typing import Int32, Pointer, Int128
def issue_clc_query(
mbar_ptr: Pointer,
clc_response_ptr: Pointer,
multicast: bool = True,
loc=None,
ip=None,
) -> None:
@@ -40,20 +39,12 @@ def issue_clc_query(
"""
mbar_llvm_ptr = mbar_ptr.llvm_ptr
clc_response_llvm_ptr = clc_response_ptr.llvm_ptr
if multicast:
nvvm.clusterlaunchcontrol_try_cancel_multicast(
clc_response_llvm_ptr,
mbar_llvm_ptr,
loc=loc,
ip=ip,
)
else:
nvvm.clusterlaunchcontrol_try_cancel(
clc_response_llvm_ptr,
mbar_llvm_ptr,
loc=loc,
ip=ip,
)
nvvm.clusterlaunchcontrol_try_cancel_multicast(
clc_response_llvm_ptr,
mbar_llvm_ptr,
loc=loc,
ip=ip,
)
@dsl_user_op

View File

@@ -604,7 +604,6 @@ def fence_proxy(
],
*,
space: Optional[Literal["cta", "cluster"]] = None,
use_intrinsic=None,
loc=None,
ip=None,
) -> None:
@@ -623,7 +622,6 @@ def fence_proxy(
- "cta" : CTA (Cooperative Thread Array) scope
- "cluster" : Cluster scope
:type space: Optional[Literal["cta", "cluster"]]
:param use_intrinsic: Whether to use intrinsic version
"""
from cutlass._mlir.dialects.nvvm import (
SharedSpace,
@@ -640,7 +638,6 @@ def fence_proxy(
nvvm.fence_proxy(
kind=kind,
space=space,
use_intrinsic=use_intrinsic,
loc=loc,
ip=ip,
)
@@ -940,9 +937,6 @@ mul_packed_f32x2 = partial(
add_packed_f32x2 = partial(
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2
)
sub_packed_f32x2 = partial(
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.sub_packed_f32x2
)
@dsl_user_op
@@ -959,20 +953,6 @@ def fmax(
)
@dsl_user_op
def fmin(
a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None
) -> Float32:
return Float32(
nvvm.fmin(
Float32(a).ir_value(loc=loc, ip=ip),
Float32(b).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
)
@dsl_user_op
def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None):
return Float32(

View File

@@ -1587,7 +1587,7 @@ def pretty_str(arg) -> str:
@dsl_user_op
def printf(*args, loc=None, ip=None, end="\n") -> None:
def printf(*args, loc=None, ip=None) -> None:
"""
Print one or more values with optional formatting.
@@ -1607,8 +1607,6 @@ def printf(*args, loc=None, ip=None, end="\n") -> None:
:type loc: Optional[Location]
:param ip: Insertion point for code generation, defaults to None
:type ip: Optional[InsertionPoint]
:param end: Suffix for the printed value, defaults to newline
:type end: Optional[str]
:raises ValueError: If no arguments are provided
:raises TypeError: If an unsupported argument type is passed
@@ -1638,10 +1636,10 @@ def printf(*args, loc=None, ip=None, end="\n") -> None:
raise ValueError("expects at least one argument to print")
if isinstance(args[0], str):
fmt = args[0] + end
fmt = args[0] + "\n"
args = args[1:]
else:
fmt = "{}" + ", {}" * (len(args) - 1) + end
fmt = "{}" + ", {}" * (len(args) - 1) + "\n"
def process_arg(arg):
arg0 = arg.value if isinstance(arg, Numeric) else arg
@@ -3762,6 +3760,35 @@ def zipped_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor:
@dsl_user_op
def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None):
"""
``zipped_divide`` is ``logical_divide`` with Tiler modes and Rest modes gathered together: ``(Tiler,Rest)``
- When Tiler is Layout, this has no effect as ``logical_divide`` results in the same.
- When Tiler is ``Tile`` (nested tuple of ``Layout``) or ``Shape``, this zips modes into standard form
``((BLK_A,BLK_B),(a,b,x,y))``
For example, if ``target`` has shape ``(s, t, r)`` and ``tiler`` has shape ``(BLK_A, BLK_B)``,
then the result will have shape ``((BLK_A, BLK_B), (ceil_div(s, BLK_A), ceil_div(t, BLK_B), r))``.
:param target: The layout or tensor to partition.
:type target: Layout or Tensor
:param tiler: The tiling specification (can be a Layout, Shape, Tile).
:type tiler: Tiler
:param loc: Optional MLIR IR location information.
:type loc: optional
:param ip: Optional MLIR IR insertion point.
:type ip: optional
:return: A zipped (partitioned) version of the target.
:rtype: Layout or Tensor
**Example:**
.. code-block:: python
layout = cute.make_layout((128, 64), stride=(64, 1))
tiler = (8, 8)
result = cute.zipped_divide(layout, tiler) # result shape: ((8, 8), (16, 8))
"""
if isinstance(tiler, tuple):
tiler = _pack_tile(tiler, loc=loc, ip=ip) # type: ignore
return _op_wrapper(
@@ -3904,6 +3931,73 @@ def local_tile(
loc=None,
ip=None,
) -> Tensor:
"""
Partition a tensor into tiles using a tiler and extract a single tile at the provided coordinate.
The ``local_tile`` operation applies a ``zipped_divide`` to split the ``input`` tensor by the ``tiler``
and then slices out a single tile using the provided `coord`. This is commonly used for extracting block-,
thread-, or CTA-level tiles for parallel operations.
.. math::
\\text{local_tile}(input, tiler, coord) = \\text{zipped_divide}(input, tiler)[coord]
This function corresponds to the CUTE/C++ `local_tile` utility:
https://docs.nvidia.com/cutlass/media/docs/cpp/cute/03_tensor.html#local-tile
:param input: The input tensor to partition into tiles.
:type input: Tensor
:param tiler: The tiling specification (can be a Layout, Shape, Tile).
:type tiler: Tiler
:param coord: The coordinate to select within the remainder ("rest") modes after tiling.
This selects which tile to extract.
:type coord: Coord
:param proj: (Optional) Projection onto tiling modes; specify to project out unused tiler modes,
e.g., when working with projections of tilers in multi-mode partitioning.
Default is None for no projection.
:type proj: XTuple, optional
:param loc: (Optional) MLIR location, for diagnostic/debugging.
:type loc: Any, optional
:param ip: (Optional) MLIR insertion point, used in IR building context.
:type ip: Any, optional
:return: A new tensor representing the local tile selected at the given coordinate.
:rtype: Tensor
**Examples**
1. Tiling a 2D tensor and extracting a tile:
.. code-block:: python
# input: (16, 24)
tensor : cute.Tensor
tiler = (2, 4)
coord = (1, 1)
# output: (8, 6)
# - zipped_divide(tensor, tiler) -> ((2, 4), (8, 6))
# - local_tile(tensor, tiler, coord) -> (8, 6)
result = cute.local_tile(tensor, tiler=tiler, coord=coord)
2. Using a stride projection for specialized tiling:
.. code-block:: python
# input: (16, 24)
tensor : cute.Tensor
tiler = (2, 2, 4)
coord = (0, 1, 1)
proj = (1, None, 1)
# output: (8, 6)
# projected_tiler: (2, 4)
# projected_coord: (0, 1)
# - zipped_divide(tensor, projected_tiler) -> ((2, 4), (8, 6))
# - local_tile(tensor, projected_tiler, projected_coord) -> (8, 6)
result = cute.local_tile(tensor, tiler=tiler, coord=coord, proj=proj)
"""
tiler_val = _pack_tile(tiler, loc=loc, ip=ip)
coord_val = _pack_coord(coord, loc=loc, ip=ip)
if proj is not None:

View File

@@ -0,0 +1,206 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from cutlass._mlir import ir
from cutlass._mlir.dialects import func
from cutlass.base_dsl.typing import get_mlir_types, NumericMeta, Numeric, as_numeric
from cutlass.base_dsl.dsl import extract_mlir_values
from cutlass import DSLRuntimeError
class ffi:
    """
    Foreign Function Interface (FFI) wrapper for external function invocation in the CUTLASS Python DSL.
    This class enables calling external MLIR function prototypes from Python code, handling type conversion,
    prototype registration, and dynamic insertion of function symbols into MLIR modules as needed.
    Parameters
    ----------
    name : str
        Name of the external function. This will be used as the symbol name when calling or registering a prototype in the MLIR module.
    params_types : list, optional
        List of argument types for the external function. These can be CUTLASS numeric types, numeric meta types, or types convertible via `get_mlir_types`. Defaults to an empty list (no parameters).
    return_type : optional
        The return type of the external function. If not specified, the function is assumed to have no return value.
    Methods
    -------
    __call__(*args)
        Calls the external function with the given arguments, ensuring argument and result types match the prototype.
    """
    def __init__(self, *, name: str, params_types=None, return_type=None):
        # Default params_types to None rather than [] so every instance gets its
        # own fresh list (a shared mutable default would leak state between
        # unrelated ffi objects).
        self.name = name
        self.params_types = [] if params_types is None else params_types
        # Normalize return_type to a (possibly empty) list of result types.
        self.return_type = [return_type] if return_type else []
    def _get_prototype_region(self, current_op):
        """
        Helper method to determine the appropriate MLIR module and region for inserting a function prototype.
        This method recursively traverses the current operation's parent hierarchy to find the correct module
        and region where the function prototype should be inserted. It supports both builtin.module and gpu.module.
        :param current_op: The current operation to check.
        :type current_op: Operation
        :returns:
            A tuple containing the module operation and the insertion region.
        :rtype: tuple
        :raises DSLRuntimeError: If the parent chain is exhausted without finding a module.
        """
        if current_op is None:
            raise DSLRuntimeError("current operation is unknown")
        op_name = current_op.name
        if op_name in ["builtin.module", "gpu.module"]:
            # Prototypes are declared in the module's first block (its body).
            return current_op, current_op.regions[0].blocks[0]
        else:
            return self._get_prototype_region(current_op.parent)
    @staticmethod
    def _to_mlir_types(args):
        """
        Helper method to convert a list of arguments to their corresponding MLIR types.
        This method converts CUTLASS numeric types, numeric meta types, and types convertible via `get_mlir_types`
        to their corresponding MLIR types.
        :param args: The list of arguments to convert to MLIR types.
        :type args: list
        :returns:
            A list of MLIR types.
        :rtype: list
        """
        types = []
        for param in args:
            if isinstance(param, NumericMeta):
                types.append(param.mlir_type)
            elif isinstance(param, Numeric):
                types.append(param.mlir_type)
            else:
                # Non-numeric arguments may expand to multiple MLIR types.
                types.extend(get_mlir_types(param))
        return types
    @staticmethod
    def _type_check(callee, exec_types, returns_types):
        """
        Helper method to check if the function prototype types match the expected types.
        This method compares the input and output types of the function prototype with the provided expected types.
        :param callee: The function prototype operation to check.
        :type callee: func.FuncOp
        :param exec_types: The expected input types.
        :type exec_types: list
        :param returns_types: The expected output types.
        :type returns_types: list
        :raises DSLRuntimeError: If the prototype signature does not match the expected types.
        """
        if callee.type.inputs != exec_types or callee.type.results != returns_types:
            raise DSLRuntimeError(
                f"External prototype types mismatch, trying to call with ({exec_types}) -> ({returns_types}), got {callee.type}"
            )
    def _create_prototype_in_region(self, op, region, exec_args):
        """
        Helper method to create or retrieve a function prototype in the current module.
        This method checks if a function prototype with the given name already exists in the symbol table of the current module.
        If it does, it checks if the prototype's types match the expected types. If it does not, it raises an error.
        If it does not exist, it creates a new function prototype and inserts it into the current region.
        :param op: The module operation to check.
        :type op: Operation
        :param region: The region to insert the function prototype into.
        :type region: Region
        :param exec_args: The arguments to pass to the function prototype.
        :type exec_args: list
        :returns: The existing or newly created function prototype.
        :rtype: func.FuncOp
        """
        symbol_table = ir.SymbolTable(op.operation)
        if self.name in symbol_table:
            callee = symbol_table[self.name]
        else:
            with ir.InsertionPoint(region):
                callee = func.FuncOp(
                    self.name,
                    (
                        ffi._to_mlir_types(self.params_types),
                        ffi._to_mlir_types(self.return_type),
                    ),
                )
                # External declaration: body is provided outside this module.
                callee.sym_visibility = ir.StringAttr.get("private")
        # Sanity check the function prototype types match the expected types
        self._type_check(
            callee,
            ffi._to_mlir_types(exec_args),
            ffi._to_mlir_types(self.return_type),
        )
        return callee
    def __call__(self, *args, **kwargs):
        """
        Calls the FFI function prototype with the provided arguments.
        This method ensures that an IR-level function prototype (external declaration)
        with the given name and type signature exists in the current module. If it does not
        exist, it will be created and inserted into the module. A call operation to this
        function is then emitted using the arguments supplied by the caller.
        :param args:
            The runtime arguments to pass to the FFI function. These will be converted to
            their corresponding numeric types and lowered to MLIR values before being used as arguments.
        :type args: tuple
        :returns:
            The MLIR call operation created for this invocation, or None when the
            prototype declares no return value.
        :rtype: func.CallOp
        :raises DSLRuntimeError:
            If there is no active MLIR insertion point or if the current operation
            context cannot be determined.
        """
        if kwargs:
            raise DSLRuntimeError(
                "Keyword arguments are not supported for FFI calls",
                suggestion="Use positional arguments only",
            )
        # Get the current insertion point and operation
        try:
            current_ip = ir.InsertionPoint.current
        except Exception as e:
            # Chain the original failure so the root cause stays visible.
            raise DSLRuntimeError(
                "Failed to determine current insertion point",
                suggestion="Make sure this is called under a jit context",
            ) from e
        current_op = current_ip.block.owner
        module_op, insertion_region = self._get_prototype_region(current_op)
        # Extract the arguments to MLIR values
        exec_args = []
        for arg in args:
            exec_arg = extract_mlir_values(arg)
            if not exec_arg:
                # Plain Python scalars carry no MLIR values; lower them via as_numeric.
                exec_arg = [as_numeric(arg).ir_value()]
            exec_args.extend(exec_arg)
        # Create the function prototype in module, so if it's under kernel function, prototype will be inserted into gpu.module
        # If it's under gpu.module, prototype will be inserted into builtin.module
        callee = self._create_prototype_in_region(
            module_op, insertion_region, exec_args
        )
        # Emit the call operation
        result = func.call(callee.type.results, self.name, exec_args)
        if self.return_type:
            return result

View File

@@ -333,7 +333,7 @@ class MmaF16BF16Trait(MmaTraits):
@dataclass(frozen=True)
class MmaF8Op(MmaOp):
"""
FP8 warpgroup MMA Operation.
F8 warpgroup MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-mma>`__.
This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands.

View File

@@ -17,6 +17,7 @@ import itertools
import operator
from typing import Union, Optional, Type, List
# MLIR modules imports
from cutlass._mlir import ir
from cutlass.base_dsl.env_manager import get_prefix_dsl_libs
@@ -128,6 +129,10 @@ class _Pointer(Pointer):
def __repr__(self):
return self.__str__()
@property
def __cache_key__(self) -> tuple:
return (self.dtype, self._addr_space, self._assumed_align)
class _Tensor(Tensor):
def __init__(
@@ -144,7 +149,7 @@ class _Tensor(Tensor):
elif enable_tvm_ffi:
import tvm_ffi
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor, stream=-1)
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor)
self._dlpack_data = self._tvm_ffi_tensor.__dlpack__()
else:
try:
@@ -185,9 +190,17 @@ class _Tensor(Tensor):
:param leading_dim: The leading dimension of the layout, defaults to None
:type leading_dim: int, optional
When ``leading_dim`` is None, automatically deduces the leading dimension from the tensor layout.
The layout can be deduced only when exactly one dimension has a stride of 1. Raises an error
if the layout cannot be automatically deduced.
When ``leading_dim`` is None, the leading dimension is deduced as follows.
(1) If exactly one dimension has stride 1, that dimension is used.
(2) If multiple dimensions have stride 1 but exactly one of them has size > 1,
that dimension is used.
(3) If multiple dimensions have stride 1 but none or more than one has size > 1,
an error is raised.
(4) If no dimension has stride 1, all strides remain dynamic.
When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the
stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent
@@ -304,6 +317,13 @@ class _Tensor(Tensor):
def __repr__(self):
return self.__str__()
@property
def __cache_key__(self) -> tuple:
self.load_dltensor()
if self._dtype is None:
self._dtype = self._dltensor_wrapper.dtype
return (self._dtype, self._assumed_align, self._dltensor_wrapper.cache_key())
def __setitem__(self, crd, value):
raise TypeError("runtime._Tensor is not indexable")
@@ -417,42 +437,76 @@ def _get_cute_type_str(inp):
return "(" + ",".join(elems) + ")"
class _FakeCompactTensor(Tensor):
class _FakeTensor(Tensor):
"""Fake Tensor implementation as a placeholder.
It mimics the interface of Tensor, but does not hold real data or allow indexing.
Used for compilation or testing situations where only shape/type/layout information is needed.
All attempts to access or mutate data will raise errors.
"""
"""
Create a fake tensor with the given shape, type, and layout.
:param dtype: Data type of the tensor elements
:type dtype: Type[Numeric]
:param shape: Shape of the tensor, consists of int (static) or SymInt (dynamic)
:type shape: tuple[Union[int, SymInt], ...]
:param stride: Stride of the tensor, defaults to None, consists of int (static) or SymInt (dynamic)
:type stride: tuple[Union[int, SymInt], ...], optional
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
:type memspace: AddressSpace, optional
:param assumed_align: Assumed alignment of the tensor (bytes), defaults to None. If None, uses the element size bytes as the assumed alignment.
:type assumed_align: int, optional
:param use_32bit_stride: Whether to use 32-bit stride. Defaults to False. When True, the dynamic stride bitwidth
will be set to 32 for small problem sizes (cosize(layout) <= Int32_max) for better performance. This is only applied
when the dimension is dynamic.
:type use_32bit_stride: bool, optional
"""
def __init__(
self,
dtype,
shape,
stride_order,
memspace=None,
assumed_align=None,
use_32bit_stride=False,
dtype: Type[Numeric],
shape: tuple[Union[int, SymInt], ...],
*,
stride: tuple[Union[int, SymInt], ...],
memspace: AddressSpace = AddressSpace.gmem,
assumed_align: int | None = None,
use_32bit_stride: bool = False,
compact: bool = False,
):
self._dtype = dtype
self._shape = shape
self._stride_order = stride_order or tuple(range(len(shape)))
# cannot use memspace or AddressSpace.gmem because AddressSpace.generic is 0
self._memspace = memspace if memspace is not None else AddressSpace.gmem
self._assumed_align = assumed_align or -(-dtype.width // 8)
self._stride = stride
self._use_32bit_stride = use_32bit_stride
self._compact = compact
def __str__(self) -> str:
return f"FakeTensorOrdered<{self._dtype}, {self._shape}, {self._stride_order}>"
if not isinstance(shape, (tuple, list)):
raise ValueError(f"Expected tuple or list but got {type(shape)}")
def __repr__(self):
return self.__str__()
if not all(isinstance(s, (int, SymInt)) for s in self._shape):
raise ValueError("All shape elements must be int or SymInt")
if stride is not None and not all(
isinstance(s, (int, SymInt)) for s in self._stride
):
raise ValueError("All stride elements must be int or SymInt")
self._memspace = memspace
self._assumed_align = assumed_align
if assumed_align is None:
# use the bytes width of the element dtype. The alignment is at least one byte align.
self._assumed_align = (self._dtype.width + 7) // 8
@property
def mlir_type(self) -> ir.Type:
shape_ty = ir.Type.parse(
'!cute.shape<"' + _get_cute_type_str(self._shape) + '">'
)
layout_ty = _cute_ir.LayoutType.get_ordered(
shape_ty, self._stride_order, self._use_32bit_stride
)
self._stride = layout_ty.stride
ptr_ty = _cute_ir.PtrType.get(
self._dtype.mlir_type, self._memspace, self._assumed_align
)
shape_str = _get_cute_type_str(self._shape)
stride_str = _get_cute_type_str(self._stride)
layout_ty = ir.Type.parse(f'!cute.layout<"{shape_str}:{stride_str}">')
# Boolean types are stored as i8 in memory
elem_type = T.i8() if self._dtype.width == 1 else self._dtype.mlir_type
ptr_ty = _cute_ir.PtrType.get(elem_type, self._memspace, self._assumed_align)
return _cute_ir.MemRefType.get(ptr_ty, layout_ty)
def __get_mlir_types__(self):
@@ -463,11 +517,53 @@ class _FakeCompactTensor(Tensor):
assert isinstance(values[0], CoreTensor)
return CoreTensor(values[0].value, self._dtype)
def __str__(self) -> str:
return f"FakeTensor<{self._dtype}, {self._shape}, {self._stride}>"
@property
def __cache_key__(self) -> tuple:
# Check if any shape or stride element is a SymInt without a symbol
import warnings
has_unnamed_symint = False
for dim in self._shape:
if isinstance(dim, SymInt) and dim.symbol is None:
has_unnamed_symint = True
break
if not self._compact:
if not has_unnamed_symint:
for stride in self._stride:
if isinstance(stride, SymInt) and stride.symbol is None:
has_unnamed_symint = True
break
if has_unnamed_symint:
warnings.warn(
"FakeTensor cache_key contains unnamed symbolic dimensions. "
"Different variables with the same shape/stride pattern will have "
"identical cache keys, which may cause incorrect cache hits. "
"Consider using 'symbol' parameter to distinguish variables: "
"cute.sym_int32(symbol='M'), cute.sym_int32(symbol='N')",
UserWarning,
stacklevel=2,
)
return (
self._dtype,
self._memspace,
self._assumed_align,
self._shape,
self._stride,
)
def __repr__(self):
return self.__str__()
def __setitem__(self, crd, value):
raise DSLRuntimeError("runtime._FakeCompactTensor is not indexable")
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
def __getitem__(self, crd):
raise DSLRuntimeError("runtime._FakeCompactTensor is not indexable")
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
@property
def element_type(self) -> Type[Numeric]:
@@ -491,118 +587,7 @@ class _FakeCompactTensor(Tensor):
@property
def leading_dim(self):
for dim, order in enumerate(self._stride_order):
if order == 0:
return dim
@property
def dynamic_shapes_mask(self):
return tuple(1 if isinstance(e, SymInt) else 0 for e in self._shape)
@property
def dynamic_strides_mask(self):
return tuple(1 if isinstance(e, SymInt) else 0 for e in self._stride)
def fill(self, value: Numeric):
raise DSLRuntimeError("runtime._FakeCompactTensor is not writable")
class _FakeTensor(Tensor):
"""Fake Tensor implementation as a placeholder.
It mimics the interface of Tensor, but does not hold real data or allow indexing.
Used for compilation or testing situations where only shape/type/layout information is needed.
All attempts to access or mutate data will raise errors.
"""
"""
Create a fake tensor with the given shape, type, and layout.
:param dtype: Data type of the tensor elements
:type dtype: Type[Numeric]
:param shape: Shape of the tensor, consists of int (static) or SymInt (dynamic)
:type shape: tuple[int, ...]
:param stride: Stride of the tensor, defaults to None, consists of int (static) or SymInt (dynamic)
:type stride: tuple[int, ...], optional
:param assumed_align: Assumed alignment of the tensor (bytes), defaults to None. If None, uses the element size bytes as the assumed alignment.
:type assumed_align: int, optional
:param use_32bit_stride: Whether to use 32-bit stride. Defaults to False. When True, the dynamic stride bitwidth
will be set to 32 for small problem sizes (cosize(layout) <= Int32_max) for better performance. This is only applied
when the dimension is dynamic.
:type use_32bit_stride: bool, optional
"""
def __init__(self, dtype, shape, *, stride, memspace=None, assumed_align=None):
self._dtype = dtype
self._shape = shape
self._stride = stride
# cannot use memspace or AddressSpace.generic because AddressSpace.generic is 0
self._memspace = memspace if memspace is not None else AddressSpace.gmem
self._assumed_align = assumed_align
if assumed_align is None:
# use the bytes width of the element dtype. The alignment is at least one byte align.
self._assumed_align = (self._dtype.width + 7) // 8
if not isinstance(shape, (tuple, list)):
raise ValueError(f"Expected tuple or list but got {type(shape)}")
if not all(isinstance(s, (int, SymInt)) for s in self._shape):
raise ValueError("All shape elements must be int or SymInt")
if stride is not None and not all(
isinstance(s, (int, SymInt)) for s in self._stride
):
raise ValueError("All stride elements must be int or SymInt")
@property
def mlir_type(self) -> ir.Type:
shape_str = _get_cute_type_str(self._shape)
stride_str = _get_cute_type_str(self._stride)
layout_ty = ir.Type.parse(f'!cute.layout<"{shape_str}:{stride_str}">')
ptr_ty = _cute_ir.PtrType.get(
self._dtype.mlir_type, self._memspace, self._assumed_align
)
return _cute_ir.MemRefType.get(ptr_ty, layout_ty)
def __get_mlir_types__(self):
return [self.mlir_type]
def __new_from_mlir_values__(self, values):
assert len(values) == 1
assert isinstance(values[0], CoreTensor)
return CoreTensor(values[0].value, self._dtype)
def __str__(self) -> str:
return f"FakeTensor<{self._dtype}, {self._shape}, {self._stride}>"
def __repr__(self):
return self.__str__()
def __setitem__(self, crd, value):
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
def __getitem__(self, crd):
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
@property
def element_type(self) -> Type[Numeric]:
return self._dtype
@property
def memspace(self):
return self._memspace
@property
def iterator(self):
raise DSLRuntimeError("runtime._FakeTensor has dummy iterator")
@property
def shape(self):
return self._shape
@property
def stride(self):
return self._stride
return core.leading_dim(self._shape, self._stride)
@property
def dynamic_shapes_mask(self):
@@ -617,35 +602,36 @@ class _FakeTensor(Tensor):
def make_fake_compact_tensor(
dtype,
shape,
dtype: Type[Numeric],
shape: tuple[Union[int, SymInt], ...],
*,
stride_order=None,
memspace=None,
assumed_align=None,
use_32bit_stride=False,
stride_order: Optional[tuple[int, ...]] = None,
memspace: AddressSpace = AddressSpace.gmem,
assumed_align: Optional[int] = None,
use_32bit_stride: bool = False,
):
"""
Create a fake tensor with the specified shape, element type, and a compact memory layout.
:param dtype: Data type of the tensor elements.
:type dtype: Type[Numeric]
:param shape: Shape of the tensor.
:type shape: tuple[int, ...]
:param shape: Shape of the tensor, consisting of static (int) or dynamic (SymInt) dimensions.
:type shape: tuple[Union[int, SymInt], ...]
:param stride_order: Order in which strides (memory layout) are assigned to the tensor dimensions.
If None, the default layout is left-to-right order (known as column-major order for flatten layout).
Otherwise, it should be a permutation order of the dimension indices.
The mode with stride_order 0 is the fastest changing (leading) dimension, and N-1 is the slowest changing.
:type stride_order: tuple[int, ...], optional
:param memspace: Memory space where the fake tensor resides. Optional.
:type memspace: str, optional
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is used.
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
:type memspace: AddressSpace, optional
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is the dtype element size in bytes, and at least 1 byte.
:type assumed_align: int, optional
:param use_32bit_stride: Whether to use 32-bit stride for dynamic dimensions. If True and the total size of the
layout (cosize(layout)) fits within int32, then dynamic strides will use 32-bit integers for improved performance.
Only applies when dimensions are dynamic. Defaults to False.
:type use_32bit_stride: bool, optional
:return: An instance of a fake tensor with the given properties and compact layout.
:rtype: _FakeCompactTensor
:rtype: _FakeTensor
**Examples:**
@@ -663,31 +649,68 @@ def make_fake_compact_tensor(
# tensor<ptr<f32, generic> o (100,?{div=8}):(?{i32 div=8},1)>
compiled_foo = cute.compile(foo, x)
# Default stride order is left-to-right order: (1, 8)
y = make_fake_compact_tensor(cutlass.Float32, (8, 3))
# Default stride order is left-to-right order (0, 1, ..., n-1)
y = make_fake_compact_tensor(cutlass.Float32, (8, 3, 2)) # y.stride == (1, 8, 24)
"""
return _FakeCompactTensor(
if stride_order is not None:
if len(stride_order) != len(shape):
raise ValueError(
f"stride_order ({stride_order}) must be empty or have same length as shape ({shape})."
)
else:
# Default stride order is left-to-right
stride_order = stride_order or tuple(range(len(shape)))
# Make compact strides (possibly symbolic) from shape & stride_order
stride = [None] * len(stride_order)
stride_product = 1
for order in range(len(stride_order)):
idx = stride_order.index(order)
stride[idx] = stride_product
stride_product *= shape[idx]
stride_width = 32 if use_32bit_stride else 64
stride = tuple(
(
SymInt(width=stride_width, divisibility=s.divisibility)
if isinstance(s, SymInt)
else s
)
for s in stride
)
return _FakeTensor(
dtype,
shape,
stride_order=stride_order,
stride=stride,
memspace=memspace,
assumed_align=assumed_align,
use_32bit_stride=use_32bit_stride,
compact=True,
)
def make_fake_tensor(dtype, shape, stride, *, memspace=None, assumed_align=None):
def make_fake_tensor(
dtype: Type[Numeric],
shape: tuple[Union[int, SymInt], ...],
stride: tuple[Union[int, SymInt], ...],
*,
memspace: AddressSpace = AddressSpace.gmem,
assumed_align: Optional[int] = None,
):
"""
Create a fake tensor with the specified element type, shape, and stride.
:param dtype: Data type of the tensor elements.
:type dtype: Type[Numeric]
:param shape: Shape of the tensor.
:type shape: tuple[int, ...]
:param stride: Stride of the tensor.
:type stride: tuple[int, ...]
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is used. Defaults to None.
:param shape: Shape of the tensor, consisting of static (int) or dynamic (SymInt) dimensions.
:type shape: tuple[Union[int, SymInt], ...]
:param stride: Stride of the tensor, consisting of static (int) or dynamic (SymInt) values.
:type stride: tuple[Union[int, SymInt], ...]
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
:type memspace: AddressSpace, optional
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is the dtype element size in bytes, and at least 1 byte.
:type assumed_align: int, optional
:return: An instance of a fake tensor with the given properties.
:rtype: _FakeTensor
@@ -953,22 +976,7 @@ def load_module(file_path: str, *, enable_tvm_ffi: bool = False):
if Path(path).exists():
_LOAD_MODULE_LIBS_CACHE.append(ctypes.CDLL(path))
if enable_tvm_ffi:
import tvm_ffi
try:
# keep_module_alive=False means the module will be unloaded
# after the returned module goes out of scope, this is useful
# for frequent loading and unloading of modules. The only requirement
# is that the module do not return object that have deleter in the module
# and the returned object lives longer than the module.
# DSL functions to not have such issue so it is desirable to set this to False.
return tvm_ffi.load_module(file_path, keep_module_alive=False)
except TypeError:
# compatible with tvm-ffi < 0.1.6
return tvm_ffi.load_module(file_path)
else:
return ExternalBinaryModule(file_path)
return ExternalBinaryModule(file_path, enable_tvm_ffi=enable_tvm_ffi)
# -------------------------------------------------------------------------
# Try to register_jit_arg_adapter for TensorAdapter

View File

@@ -139,6 +139,19 @@ class _Tensor(Tensor):
def __init__(
self, value, dtype: Optional[Type[Numeric]] = None, *, loc=None, ip=None
):
"""Initialize a Tensor from an MLIR value.
:param value: The MLIR operation result value or another Tensor to initialize from
:type value: Union[ir.Value, _Tensor]
:param dtype: The user specified data type of the tensor elements, defaults to None
:type dtype: Optional[Type[Numeric]]
:param loc: The source location for the operation, defaults to None
:type loc: Optional[Location]
:param ip: The insertion point for the operation, defaults to None
:type ip: Optional[InsertionPoint]
:raises TypeError: If value is not ir.Value or _Tensor
:raises TypeError: If iterator type is not supported
"""
self._dtype = dtype
if isinstance(value, ir.Value):
self.value = value
@@ -952,6 +965,37 @@ def make_fragment_like(src, dtype=None, *, loc=None, ip=None):
def recast_tensor(
src: Tensor, dtype: Type[Numeric], swizzle_=None, *, loc=None, ip=None
):
"""Recast a tensor to a different data type by changing the element interpretation.
This function reinterprets the memory of a tensor with a different element type,
adjusting both the iterator pointer type and the layout to maintain consistency.
:param src: The source tensor to recast
:type src: Tensor
:param dtype: The target data type for tensor elements
:type dtype: Type[Numeric]
:param swizzle_: Optional swizzle parameter (reserved for future use), defaults to None
:type swizzle_: Optional, unused
:param loc: Source location for MLIR operation tracking, defaults to None
:type loc: Optional[Location]
:param ip: Insertion point for MLIR operation, defaults to None
:type ip: Optional[InsertionPoint]
:return: A new tensor with the same memory but reinterpreted as dtype
:rtype: Tensor
:raises TypeError: If dtype is not a subclass of Numeric
**Examples:**
.. code-block:: python
# Create a Float32 tensor
tensor_f32 = make_rmem_tensor((4, 8), Float32)
# Recast to Int32 to manipulate bits
tensor_i32 = recast_tensor(tensor_f32, Int32)
# Both tensors share the same memory, but interpret it differently
"""
if not isclass(dtype) or not issubclass(dtype, Numeric):
raise TypeError(f"dtype must be a type of Numeric, but got {dtype}")
@@ -972,6 +1016,36 @@ def recast_tensor(
@dsl_user_op
def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor:
"""Offset the tensor domain by the given coordinate.
This function creates a new tensor by offsetting the iterator/pointer of the input tensor
by the amount corresponding to the given coordinate in its layout.
:param coord: The coordinate offset to apply
:type coord: Coord
:param tensor: The source tensor to offset
:type tensor: Tensor
:param loc: Source location for MLIR operation tracking, defaults to None
:type loc: Optional[Location]
:param ip: Insertion point for MLIR operation, defaults to None
:type ip: Optional[InsertionPoint]
:return: A new tensor with the offset iterator
:rtype: Tensor
:raises ValueError: If the tensor type doesn't support domain offsetting
**Examples:**
.. code-block:: python
# Create a tensor with a row-major layout
ptr = make_ptr(Float32, base_ptr, AddressSpace.gmem)
layout = make_layout((64, 128), stride=(128, 1))
tensor = make_tensor(ptr, layout)
# Offset by coordinate (3, 5)
offset_tensor = domain_offset((3, 5), tensor)
# offset_tensor now points to element at (3, 5)
"""
offset = crd2idx(coord, tensor.layout, loc=loc, ip=ip)
if isinstance(tensor.iterator, Pointer):
return make_tensor(

View File

@@ -9,6 +9,8 @@
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from __future__ import annotations
from abc import ABC, abstractmethod
import ctypes
from typing import ForwardRef, Tuple, Union, Any, Type, List, Optional, Literal
@@ -24,15 +26,18 @@ Int = Union[int, Integer]
class SymInt:
def __init__(
    self, width: Literal[32, 64] = 32, *, divisibility=1, symbol: str | None = None
):
    """Initialize a symbolic (dynamic) integer descriptor.

    :param width: Bit width of the symbolic value; only 32 or 64 is accepted.
    :param divisibility: Known divisor of the runtime value, defaults to 1.
    :param symbol: Optional symbolic name used for printing/diagnostics,
        defaults to None.
    :raises ValueError: If ``width`` is not 32 or 64.
    """
    # NOTE: a stale pre-update `def __init__` line (without `symbol`) was left
    # directly above this signature by a bad merge; it has been removed.
    if width not in [32, 64]:
        raise ValueError(f"Unsupported width: {width}")
    self._width = width
    self._divisibility = divisibility
    self._symbol = symbol
def __hash__(self):
    """Hash over (width, divisibility, symbol) so it stays consistent with __eq__."""
    # The stale two-field hash (without _symbol) was unreachable dead residue
    # above this line and has been removed; symbol-aware equality requires a
    # symbol-aware hash.
    return hash((self._width, self._divisibility, self._symbol))
@property
def width(self):
@@ -42,8 +47,16 @@ class SymInt:
def divisibility(self):
return self._divisibility
@property
def symbol(self):
    """Optional symbolic name of this SymInt (``None`` when unnamed)."""
    return self._symbol
def __str__(self) -> str:
    """Render as '<symbol> ?{div=d}' (32-bit) or '<symbol> ?{i<w> div=d}' otherwise."""
    # A stale pre-update `return f"?{{i{self._width} ...}}"` line sat first in
    # this body and made the prefix/width logic below unreachable; removed.
    prefix = "" if self._symbol is None else self._symbol + " "
    if self._width == 32:
        return f"{prefix}?{{div={self._divisibility}}}"
    else:
        return f"{prefix}?{{i{self._width} div={self._divisibility}}}"
def __repr__(self) -> str:
    """Debug representation; identical to the str() form."""
    return str(self)
@@ -51,18 +64,52 @@ class SymInt:
def __eq__(self, other) -> bool:
    """Two SymInts are equal iff width, divisibility, and symbol all match."""
    if not isinstance(other, SymInt):
        return False
    # A stale one-line list from the pre-update version was left adjacent to
    # the new multi-line list, turning the `all(...)` argument into a
    # list-indexed-by-list expression (TypeError at runtime); removed.
    return all(
        [
            self._width == other._width,
            self._divisibility == other._divisibility,
            self._symbol == other._symbol,
        ]
    )
def __mod__(self, other: int) -> Union["SymInt", int]:
if self._divisibility % other != 0:
def __mod__(self, other: int | SymInt) -> SymInt | int:
if isinstance(other, int):
other_div, result_width = other, self._width
elif isinstance(other, SymInt):
other_div, result_width = (
other._divisibility,
max(self._width, other._width),
)
else:
return NotImplemented
if self._divisibility % other_div == 0:
return 0
else:
from math import gcd
div = gcd(self._divisibility, other)
return SymInt(self._width, divisibility=div)
return SymInt(result_width, divisibility=gcd(self._divisibility, other_div))
def __rmod__(self, other: int) -> int:
    """int % SymInt: reduce the concrete int by this SymInt's divisibility."""
    # Guard-clause form: bail out for unsupported operand types first.
    if not isinstance(other, int):
        return NotImplemented
    return other % self._divisibility
def __mul__(self, other: int | SymInt) -> SymInt:
    """Multiply: divisibilities multiply; width is max of the two for SymInt*SymInt."""
    if isinstance(other, int):
        return SymInt(self._width, divisibility=self._divisibility * other)
    elif isinstance(other, SymInt):
        return SymInt(
            width=max(self._width, other._width),
            divisibility=self._divisibility * other._divisibility,
        )
    else:
        # A stray `return 0` (residue of the old __mod__ body) sat in this
        # else-branch and made NotImplemented unreachable; per the Python data
        # model, unsupported operands must yield NotImplemented so the
        # reflected operation can be tried.
        return NotImplemented
def __rmul__(self, other: int | SymInt) -> SymInt:
    """Reflected multiply: SymInt multiplication is commutative, so reuse __mul__."""
    return self * other
def __c_pointers__(self):
    # NOTE(review): ctypes.c_void_p(0).value evaluates to None (NULL), so this
    # returns [None] — presumably a placeholder pointer slot for the runtime
    # argument-packing ABI, with the real value supplied at launch time.
    # TODO confirm against the call-side packing code.
    return [ctypes.c_void_p(0).value]
@@ -73,7 +120,7 @@ class SymInt:
)
return [res_ty]
def __new_from_mlir_values__(self, values) -> "SymInt":
def __new_from_mlir_values__(self, values) -> SymInt:
from .core import IntValue
if self.width == 32:
@@ -84,16 +131,18 @@ class SymInt:
assert False, f"Unsupported width: {self.width}"
return self
def sym_int(
    width: Literal[32, 64] = 32, *, divisibility=1, symbol: str | None = None
) -> SymInt:
    """Create a SymInt with the given width, divisibility, and optional symbol name."""
    # The stale pre-update overload (without `symbol`) that this definition
    # shadowed has been removed.
    return SymInt(width, divisibility=divisibility, symbol=symbol)
def sym_int32(divisibility=1, symbol: str | None = None) -> SymInt:
    """Create a 32-bit SymInt (convenience wrapper over sym_int)."""
    # Stale symbol-less duplicate definition removed.
    return sym_int(32, divisibility=divisibility, symbol=symbol)
def sym_int64(divisibility=1, symbol: str | None = None) -> SymInt:
    """Create a 64-bit SymInt (convenience wrapper over sym_int)."""
    # Stale symbol-less duplicate definition removed.
    return sym_int(64, divisibility=divisibility, symbol=symbol)
ScaledBasis = ForwardRef("ScaledBasis")

View File

@@ -1199,7 +1199,6 @@ class KernelLauncher:
return self.dsl._get_smem_usage()
def launch(self, *args, **kwargs):
self.dsl.frame = inspect.currentframe().f_back
self.dsl._preprocess_launch_config_args(args, kwargs)
config = self.dsl.LaunchConfig(*args, **kwargs)
kernel_attrs = _build_kernel_attrs(config)
@@ -1216,7 +1215,6 @@ class KernelLauncher:
ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config)
self.dsl.kernel_info[name] = kernel_attrs
self.dsl.frame = None
return ret.launch_op_ret
def __call__(self, *args, **kwargs):

View File

@@ -35,7 +35,6 @@ if is_available():
)
from .compile import (
release_compile_cache,
initialize_cutlass_dsl,
)
from .ffi import (
get_export_disabled_safety_checks,
@@ -48,10 +47,6 @@ if is_available():
# This is a legacy name for TensorSpec. It will be removed eventually.
TensorMode = TensorSpec
# This explicit init method ensures that we avoid initialization at
# unexpected times in jax tracing.
initialize_cutlass_dsl()
__all__ = [
"cutlass_call",
"jax_to_cutlass_dtype",

View File

@@ -267,36 +267,4 @@ def release_compile_cache():
_CUTLASS_COMPILE_CACHE.clear()
dsl = CuTeDSL._get_dsl()
dsl.jit_cache.clear()
# TODO: This is needed to release frames being held in the DSL
# We should avoid holding such references as they unexpectedly
# extend object lifetime.
dsl.frame = None
gc.collect()
class _DummyInitKernel:
    """Minimal no-op kernel used only to warm up the DSL compiler pipeline."""

    # Empty device kernel; exists solely so the @cute.jit entry point below
    # has something to launch through the compiler.
    @cute.kernel
    def kernel(self):
        pass

    # No-op host entry point; compiling it forces one full pass through the
    # JIT pipeline.
    @cute.jit
    def init(self):
        pass
# Module-level guard: ensures initialization runs at most once per process.
_CUTLASS_DSL_INITIALIZED = False


def initialize_cutlass_dsl():
    """Initializes cutlass DSL.

    Idempotent: subsequent calls return immediately once the first
    initialization has completed. Compiles a trivial no-op kernel so the DSL
    compiler pipeline is warmed up at a controlled point rather than at an
    unexpected time later (e.g. mid-trace).
    """
    global _CUTLASS_DSL_INITIALIZED
    if _CUTLASS_DSL_INITIALIZED:
        return

    # Call compiler to ensure we've pre-processed any kernels inside cutedsl.
    kernel = _DummyInitKernel()
    # _compile_lock serializes against other compilations in this module.
    with _compile_lock:
        logger.debug("Initializing cutlass dsl...")
        _ = cutlass.cute.compile(kernel.init)
    _CUTLASS_DSL_INITIALIZED = True

View File

@@ -28,9 +28,13 @@ logger = logging.getLogger(__name__)
_CUTE_DSL_RUNTIME_LIBRARY_NAME = "cute_dsl_runtime"
# Maps each cutlass-call custom-call target to its runtime entry points.
# Reconstructed from merge residue: the old single-entry values were left
# interleaved with the new ones, producing a duplicate dict key and two
# adjacent "execute" items with a missing comma (a syntax error).
_CUTLASS_CALL_TARGETS = {
    "CuteDSLRT_NvJaxCutlassCall": {
        "execute": "CuteDSLRT_NvJaxCutlassCallExecute",
        "prepare": "CuteDSLRT_NvJaxCutlassCallPrepare",
    },
    "CuteDSLRT_NvJaxCutlassCallNoCudaGraph": {
        "execute": "CuteDSLRT_NvJaxCutlassCallExecuteNoCudaGraph",
        "prepare": "CuteDSLRT_NvJaxCutlassCallPrepare",
    },
}

View File

@@ -1,3 +1,3 @@
# Use `pip install -r requirements-cu13.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl[cu13]==4.4.2

View File

@@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.4.2

View File

@@ -133,7 +133,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
# Package version; the stale duplicate '4.4.1' assignment left by the version
# bump has been removed (the later assignment was the only effective one).
this.__version__ = '4.4.2'
from cutlass_cppgen.backend import create_memory_pool
from cutlass_cppgen.emit.pytorch import pytorch

View File

@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
setup(
name='cutlass_cppgen',
version='4.4.1',
version='4.4.2',
description='CUTLASS Pythonic Interface',
package_dir={'': '.'},
packages=[

View File

@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
    """Run setuptools setup for the cutlass_library package."""
    # The stale version='4.4.1' line from the version bump duplicated the
    # `version` keyword argument, which is a SyntaxError; only 4.4.2 remains.
    setup(
        name='cutlass_library',
        version='4.4.2',
        description='CUTLASS library generation scripts',
        packages=['cutlass_library']
    )

View File

@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
    """Run setuptools setup for the pycute package."""
    # The stale version='4.4.1' line from the version bump duplicated the
    # `version` keyword argument, which is a SyntaxError; only 4.4.2 remains.
    setup(
        name='pycute',
        version='4.4.2',
        description='Python implementation of CuTe',
        packages=['pycute'],
    )