mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
v4.4.2 update. (#3104)
This commit is contained in:
@@ -68,7 +68,8 @@ class ExternalBinaryModule:
|
||||
|
||||
load_provider: LoadProvider = None
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
def __init__(self, file_path: str, enable_tvm_ffi: bool = False):
|
||||
self.enable_tvm_ffi = enable_tvm_ffi
|
||||
assert self.load_provider is not None, (
|
||||
"Load provider is not set for ExternalBinaryModule."
|
||||
)
|
||||
@@ -82,13 +83,28 @@ class ExternalBinaryModule:
|
||||
object_file_content = f.read()
|
||||
except Exception as e:
|
||||
raise DSLRuntimeError(f"Failed to read object file {file_path}: {e}")
|
||||
|
||||
useJitLink = not enable_tvm_ffi
|
||||
# Lifetime of the engine is same as the ExternalBinaryModule.
|
||||
self.engine = self.load_provider.execution_engine_constructor(
|
||||
object_file_content, shared_libs
|
||||
object_file_content, shared_libs, useJitLink
|
||||
)
|
||||
|
||||
def __getattr__(self, function_prefix: str) -> "JitCompiledFunction":
|
||||
"""Get the jit_function from the `function_prefix`. The `function_prefix` is specified when users dump the object file. When there is no function_prefix found in the module, the function will raise an error."""
|
||||
if self.enable_tvm_ffi:
|
||||
try:
|
||||
import tvm_ffi
|
||||
|
||||
function_ptr = self.engine.lookup("__tvm_ffi_" + function_prefix)
|
||||
return tvm_ffi.Function.__from_extern_c__(
|
||||
function_ptr, keep_alive_object=self.engine
|
||||
)
|
||||
except Exception as e:
|
||||
raise DSLRuntimeError(
|
||||
f"Failed to load TVM FFI function {function_prefix}: {e}"
|
||||
)
|
||||
|
||||
try:
|
||||
args_spec, function_name, kernel_info, version_str = (
|
||||
decode_metadata_from_execution_engine(
|
||||
@@ -124,3 +140,7 @@ class ExternalBinaryModule:
|
||||
load_from_binary=True,
|
||||
)
|
||||
return jit_function
|
||||
|
||||
def __getitem__(self, function_prefix: str) -> "JitCompiledFunction":
|
||||
"""Get the jit_function from the `function_prefix`. The `function_prefix` is specified when users dump the object file. When there is no function_prefix found in the module, the function will raise an error."""
|
||||
return self.__getattr__(function_prefix)
|
||||
|
||||
@@ -202,6 +202,8 @@ from .math import *
|
||||
# Used as internal symbol
|
||||
from .. import cutlass_dsl as _dsl
|
||||
|
||||
from .ffi import ffi
|
||||
|
||||
# Aliases
|
||||
jit = _dsl.CuTeDSL.jit
|
||||
kernel = _dsl.CuTeDSL.kernel
|
||||
@@ -312,4 +314,5 @@ __all__ = [
|
||||
"kernel",
|
||||
"register_jit_arg_adapter",
|
||||
"compile",
|
||||
"ffi",
|
||||
]
|
||||
|
||||
@@ -96,7 +96,6 @@ __all__ = [
|
||||
"fma_packed_f32x2",
|
||||
"mul_packed_f32x2",
|
||||
"add_packed_f32x2",
|
||||
"sub_packed_f32x2",
|
||||
"fmax",
|
||||
"rcp_approx",
|
||||
"exp2",
|
||||
|
||||
@@ -23,7 +23,6 @@ from ..typing import Int32, Pointer, Int128
|
||||
def issue_clc_query(
|
||||
mbar_ptr: Pointer,
|
||||
clc_response_ptr: Pointer,
|
||||
multicast: bool = True,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
@@ -40,20 +39,12 @@ def issue_clc_query(
|
||||
"""
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
clc_response_llvm_ptr = clc_response_ptr.llvm_ptr
|
||||
if multicast:
|
||||
nvvm.clusterlaunchcontrol_try_cancel_multicast(
|
||||
clc_response_llvm_ptr,
|
||||
mbar_llvm_ptr,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
else:
|
||||
nvvm.clusterlaunchcontrol_try_cancel(
|
||||
clc_response_llvm_ptr,
|
||||
mbar_llvm_ptr,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
nvvm.clusterlaunchcontrol_try_cancel_multicast(
|
||||
clc_response_llvm_ptr,
|
||||
mbar_llvm_ptr,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
|
||||
@@ -604,7 +604,6 @@ def fence_proxy(
|
||||
],
|
||||
*,
|
||||
space: Optional[Literal["cta", "cluster"]] = None,
|
||||
use_intrinsic=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
@@ -623,7 +622,6 @@ def fence_proxy(
|
||||
- "cta" : CTA (Cooperative Thread Array) scope
|
||||
- "cluster" : Cluster scope
|
||||
:type space: Optional[Literal["cta", "cluster"]]
|
||||
:param use_intrinsic: Whether to use intrinsic version
|
||||
"""
|
||||
from cutlass._mlir.dialects.nvvm import (
|
||||
SharedSpace,
|
||||
@@ -640,7 +638,6 @@ def fence_proxy(
|
||||
nvvm.fence_proxy(
|
||||
kind=kind,
|
||||
space=space,
|
||||
use_intrinsic=use_intrinsic,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
@@ -940,9 +937,6 @@ mul_packed_f32x2 = partial(
|
||||
add_packed_f32x2 = partial(
|
||||
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2
|
||||
)
|
||||
sub_packed_f32x2 = partial(
|
||||
calc_packed_f32x2_op, src_c=None, calc_func=nvvm.sub_packed_f32x2
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
@@ -959,20 +953,6 @@ def fmax(
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def fmin(
|
||||
a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None
|
||||
) -> Float32:
|
||||
return Float32(
|
||||
nvvm.fmin(
|
||||
Float32(a).ir_value(loc=loc, ip=ip),
|
||||
Float32(b).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None):
|
||||
return Float32(
|
||||
|
||||
@@ -1587,7 +1587,7 @@ def pretty_str(arg) -> str:
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def printf(*args, loc=None, ip=None, end="\n") -> None:
|
||||
def printf(*args, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Print one or more values with optional formatting.
|
||||
|
||||
@@ -1607,8 +1607,6 @@ def printf(*args, loc=None, ip=None, end="\n") -> None:
|
||||
:type loc: Optional[Location]
|
||||
:param ip: Insertion point for code generation, defaults to None
|
||||
:type ip: Optional[InsertionPoint]
|
||||
:param end: Suffix for the printed value, defaults to newline
|
||||
:type end: Optional[str]
|
||||
:raises ValueError: If no arguments are provided
|
||||
:raises TypeError: If an unsupported argument type is passed
|
||||
|
||||
@@ -1638,10 +1636,10 @@ def printf(*args, loc=None, ip=None, end="\n") -> None:
|
||||
raise ValueError("expects at least one argument to print")
|
||||
|
||||
if isinstance(args[0], str):
|
||||
fmt = args[0] + end
|
||||
fmt = args[0] + "\n"
|
||||
args = args[1:]
|
||||
else:
|
||||
fmt = "{}" + ", {}" * (len(args) - 1) + end
|
||||
fmt = "{}" + ", {}" * (len(args) - 1) + "\n"
|
||||
|
||||
def process_arg(arg):
|
||||
arg0 = arg.value if isinstance(arg, Numeric) else arg
|
||||
@@ -3762,6 +3760,35 @@ def zipped_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor:
|
||||
|
||||
@dsl_user_op
|
||||
def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None):
|
||||
"""
|
||||
``zipped_divide`` is ``logical_divide`` with Tiler modes and Rest modes gathered together: ``(Tiler,Rest)``
|
||||
|
||||
- When Tiler is Layout, this has no effect as ``logical_divide`` results in the same.
|
||||
- When Tiler is ``Tile`` (nested tuple of ``Layout``) or ``Shape``, this zips modes into standard form
|
||||
``((BLK_A,BLK_B),(a,b,x,y))``
|
||||
|
||||
For example, if ``target`` has shape ``(s, t, r)`` and ``tiler`` has shape ``(BLK_A, BLK_B)``,
|
||||
then the result will have shape ``((BLK_A, BLK_B), (ceil_div(s, BLK_A), ceil_div(t, BLK_B), r))``.
|
||||
|
||||
:param target: The layout or tensor to partition.
|
||||
:type target: Layout or Tensor
|
||||
:param tiler: The tiling specification (can be a Layout, Shape, Tile).
|
||||
:type tiler: Tiler
|
||||
:param loc: Optional MLIR IR location information.
|
||||
:type loc: optional
|
||||
:param ip: Optional MLIR IR insertion point.
|
||||
:type ip: optional
|
||||
:return: A zipped (partitioned) version of the target.
|
||||
:rtype: Layout or Tensor
|
||||
|
||||
**Example:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
layout = cute.make_layout((128, 64), stride=(64, 1))
|
||||
tiler = (8, 8)
|
||||
result = cute.zipped_divide(layout, tiler) # result shape: ((8, 8), (16, 8))
|
||||
"""
|
||||
if isinstance(tiler, tuple):
|
||||
tiler = _pack_tile(tiler, loc=loc, ip=ip) # type: ignore
|
||||
return _op_wrapper(
|
||||
@@ -3904,6 +3931,73 @@ def local_tile(
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Partition a tensor into tiles using a tiler and extract a single tile at the provided coordinate.
|
||||
|
||||
The ``local_tile`` operation applies a ``zipped_divide`` to split the ``input`` tensor by the ``tiler``
|
||||
and then slices out a single tile using the provided `coord`. This is commonly used for extracting block-,
|
||||
thread-, or CTA-level tiles for parallel operations.
|
||||
|
||||
.. math::
|
||||
|
||||
\\text{local_tile}(input, tiler, coord) = \\text{zipped_divide}(input, tiler)[coord]
|
||||
|
||||
This function corresponds to the CUTE/C++ `local_tile` utility:
|
||||
https://docs.nvidia.com/cutlass/media/docs/cpp/cute/03_tensor.html#local-tile
|
||||
|
||||
:param input: The input tensor to partition into tiles.
|
||||
:type input: Tensor
|
||||
:param tiler: The tiling specification (can be a Layout, Shape, Tile).
|
||||
:type tiler: Tiler
|
||||
:param coord: The coordinate to select within the remainder ("rest") modes after tiling.
|
||||
This selects which tile to extract.
|
||||
:type coord: Coord
|
||||
:param proj: (Optional) Projection onto tiling modes; specify to project out unused tiler modes,
|
||||
e.g., when working with projections of tilers in multi-mode partitioning.
|
||||
Default is None for no projection.
|
||||
:type proj: XTuple, optional
|
||||
:param loc: (Optional) MLIR location, for diagnostic/debugging.
|
||||
:type loc: Any, optional
|
||||
:param ip: (Optional) MLIR insertion point, used in IR building context.
|
||||
:type ip: Any, optional
|
||||
|
||||
:return: A new tensor representing the local tile selected at the given coordinate.
|
||||
:rtype: Tensor
|
||||
|
||||
**Examples**
|
||||
|
||||
1. Tiling a 2D tensor and extracting a tile:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# input: (16, 24)
|
||||
tensor : cute.Tensor
|
||||
tiler = (2, 4)
|
||||
coord = (1, 1)
|
||||
|
||||
# output: (8, 6)
|
||||
# - zipped_divide(tensor, tiler) -> ((2, 4), (8, 6))
|
||||
# - local_tile(tensor, tiler, coord) -> (8, 6)
|
||||
result = cute.local_tile(tensor, tiler=tiler, coord=coord)
|
||||
|
||||
2. Using a stride projection for specialized tiling:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# input: (16, 24)
|
||||
tensor : cute.Tensor
|
||||
tiler = (2, 2, 4)
|
||||
coord = (0, 1, 1)
|
||||
proj = (1, None, 1)
|
||||
|
||||
# output: (8, 6)
|
||||
# projected_tiler: (2, 4)
|
||||
# projected_coord: (0, 1)
|
||||
# - zipped_divide(tensor, projected_tiler) -> ((2, 4), (8, 6))
|
||||
# - local_tile(tensor, projected_tiler, projected_coord) -> (8, 6)
|
||||
result = cute.local_tile(tensor, tiler=tiler, coord=coord, proj=proj)
|
||||
"""
|
||||
|
||||
tiler_val = _pack_tile(tiler, loc=loc, ip=ip)
|
||||
coord_val = _pack_coord(coord, loc=loc, ip=ip)
|
||||
if proj is not None:
|
||||
|
||||
206
python/CuTeDSL/cutlass/cute/ffi.py
Normal file
206
python/CuTeDSL/cutlass/cute/ffi.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import func
|
||||
from cutlass.base_dsl.typing import get_mlir_types, NumericMeta, Numeric, as_numeric
|
||||
from cutlass.base_dsl.dsl import extract_mlir_values
|
||||
|
||||
from cutlass import DSLRuntimeError
|
||||
|
||||
|
||||
class ffi:
|
||||
"""
|
||||
Foreign Function Interface (FFI) wrapper for external function invocation in the CUTLASS Python DSL.
|
||||
|
||||
This class enables calling external MLIR function prototypes from Python code, handling type conversion,
|
||||
prototype registration, and dynamic insertion of function symbols into MLIR modules as needed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
Name of the external function. This will be used as the symbol name when calling or registering a prototype in the MLIR module.
|
||||
params_types : list, optional
|
||||
List of argument types for the external function. These can be CUTLASS numeric types, numeric meta types, or types convertible via `get_mlir_types`.
|
||||
return_type : optional
|
||||
The return type of the external function. If not specified, the function is assumed to have no return value.
|
||||
|
||||
Methods
|
||||
-------
|
||||
__call__(*args)
|
||||
Calls the external function with the given arguments, ensuring argument and result types match the prototype.
|
||||
"""
|
||||
|
||||
def __init__(self, *, name: str, params_types: list = [], return_type=None):
|
||||
self.name = name
|
||||
self.params_types = params_types
|
||||
self.return_type = [return_type] if return_type else []
|
||||
|
||||
def _get_prototype_region(self, current_op):
|
||||
"""
|
||||
Helper method to determine the appropriate MLIR module and region for inserting a function prototype.
|
||||
|
||||
This method recursively traverses the current operation's parent hierarchy to find the correct module
|
||||
and region where the function prototype should be inserted. It supports both builtin.module and gpu.module.
|
||||
:param current_op: The current operation to check.
|
||||
:type current_op: Operation
|
||||
|
||||
:returns:
|
||||
A tuple containing the module operation and the insertion region.
|
||||
:rtype: tuple
|
||||
"""
|
||||
if current_op is None:
|
||||
raise DSLRuntimeError("current operation is unknown")
|
||||
op_name = current_op.name
|
||||
if op_name in ["builtin.module", "gpu.module"]:
|
||||
return current_op, current_op.regions[0].blocks[0]
|
||||
else:
|
||||
return self._get_prototype_region(current_op.parent)
|
||||
|
||||
@staticmethod
|
||||
def _to_mlir_types(args):
|
||||
"""
|
||||
Helper method to convert a list of arguments to their corresponding MLIR types.
|
||||
|
||||
This method converts CUTLASS numeric types, numeric meta types, and types convertible via `get_mlir_types`
|
||||
to their corresponding MLIR types.
|
||||
:param args: The list of arguments to convert to MLIR types.
|
||||
:type args: list
|
||||
|
||||
:returns:
|
||||
A list of MLIR types.
|
||||
:rtype: list
|
||||
"""
|
||||
types = []
|
||||
for param in args:
|
||||
if isinstance(param, NumericMeta):
|
||||
types.append(param.mlir_type)
|
||||
elif isinstance(param, Numeric):
|
||||
types.append(param.mlir_type)
|
||||
else:
|
||||
types.extend(get_mlir_types(param))
|
||||
return types
|
||||
|
||||
@staticmethod
|
||||
def _type_check(callee, exec_types, returns_types):
|
||||
"""
|
||||
Helper method to check if the function prototype types match the expected types.
|
||||
|
||||
This method compares the input and output types of the function prototype with the provided expected types.
|
||||
:param callee: The function prototype operation to check.
|
||||
:type callee: func.FuncOp
|
||||
:param exec_types: The expected input types.
|
||||
:type exec_types: list
|
||||
:param returns_types: The expected output types.
|
||||
:type returns_types: list
|
||||
"""
|
||||
if callee.type.inputs != exec_types or callee.type.results != returns_types:
|
||||
raise DSLRuntimeError(
|
||||
f"External prototype types mismatch, trying to call with ({exec_types}) -> ({returns_types}), got {callee.type}"
|
||||
)
|
||||
|
||||
def _create_prototype_in_region(self, op, region, exec_args):
|
||||
"""
|
||||
Helper method to create or retrieve a function prototype in the current module.
|
||||
|
||||
This method checks if a function prototype with the given name already exists in the symbol table of the current module.
|
||||
If it does, it checks if the prototype's types match the expected types. If it does not, it raises an error.
|
||||
If it does not exist, it creates a new function prototype and inserts it into the current region.
|
||||
:param op: The module operation to check.
|
||||
:type op: Operation
|
||||
:param region: The region to insert the function prototype into.
|
||||
:type region: Region
|
||||
:param exec_args: The arguments to pass to the function prototype.
|
||||
:type exec_args: list
|
||||
"""
|
||||
symbol_table = ir.SymbolTable(op.operation)
|
||||
|
||||
if self.name in symbol_table:
|
||||
callee = symbol_table[self.name]
|
||||
else:
|
||||
with ir.InsertionPoint(region):
|
||||
callee = func.FuncOp(
|
||||
self.name,
|
||||
(
|
||||
ffi._to_mlir_types(self.params_types),
|
||||
ffi._to_mlir_types(self.return_type),
|
||||
),
|
||||
)
|
||||
callee.sym_visibility = ir.StringAttr.get("private")
|
||||
|
||||
# Sanity check the function prototype types match the expected types
|
||||
self._type_check(
|
||||
callee,
|
||||
ffi._to_mlir_types(exec_args),
|
||||
ffi._to_mlir_types(self.return_type),
|
||||
)
|
||||
|
||||
return callee
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
Calls the FFI function prototype with the provided arguments.
|
||||
|
||||
This method ensures that an IR-level function prototype (external declaration)
|
||||
with the given name and type signature exists in the current module. If it does not
|
||||
exist, it will be created and inserted into the module. A call operation to this
|
||||
function is then emitted using the arguments supplied by the caller.
|
||||
|
||||
:param args:
|
||||
The runtime arguments to pass to the FFI function. These will be converted to
|
||||
their corresponding numeric types and lowered to MLIR values before being used as arguments.
|
||||
:type args: tuple
|
||||
|
||||
:returns:
|
||||
The MLIR call operation created for this invocation.
|
||||
:rtype: func.CallOp
|
||||
|
||||
:raises DSLRuntimeError:
|
||||
If there is no active MLIR insertion point or if the current operation
|
||||
context cannot be determined.
|
||||
"""
|
||||
|
||||
if kwargs:
|
||||
raise DSLRuntimeError(
|
||||
"Keyword arguments are not supported for FFI calls",
|
||||
suggestion="Use positional arguments only",
|
||||
)
|
||||
|
||||
# Get the current insertion point and operation
|
||||
try:
|
||||
current_ip = ir.InsertionPoint.current
|
||||
except Exception:
|
||||
raise DSLRuntimeError(
|
||||
"Failed to determine current insertion point",
|
||||
suggestion="Make sure this is called under a jit context",
|
||||
)
|
||||
current_op = current_ip.block.owner
|
||||
module_op, insertion_region = self._get_prototype_region(current_op)
|
||||
|
||||
# Extract the arguments to MLIR values
|
||||
exec_args = []
|
||||
for arg in args:
|
||||
exec_arg = extract_mlir_values(arg)
|
||||
if not exec_arg:
|
||||
exec_arg = [as_numeric(arg).ir_value()]
|
||||
exec_args.extend(exec_arg)
|
||||
|
||||
# Create the function prototype in module, so if it's under kernel function, prototype will be inserted into gpu.module
|
||||
# If it's under gpu.module, prototype will be inserted into builtin.module
|
||||
callee = self._create_prototype_in_region(
|
||||
module_op, insertion_region, exec_args
|
||||
)
|
||||
|
||||
# Emit the call operation
|
||||
result = func.call(callee.type.results, self.name, exec_args)
|
||||
|
||||
if self.return_type:
|
||||
return result
|
||||
@@ -333,7 +333,7 @@ class MmaF16BF16Trait(MmaTraits):
|
||||
@dataclass(frozen=True)
|
||||
class MmaF8Op(MmaOp):
|
||||
"""
|
||||
FP8 warpgroup MMA Operation.
|
||||
F8 warpgroup MMA Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-mma>`__.
|
||||
This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands.
|
||||
|
||||
@@ -17,6 +17,7 @@ import itertools
|
||||
import operator
|
||||
from typing import Union, Optional, Type, List
|
||||
|
||||
|
||||
# MLIR modules imports
|
||||
from cutlass._mlir import ir
|
||||
from cutlass.base_dsl.env_manager import get_prefix_dsl_libs
|
||||
@@ -128,6 +129,10 @@ class _Pointer(Pointer):
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
@property
|
||||
def __cache_key__(self) -> tuple:
|
||||
return (self.dtype, self._addr_space, self._assumed_align)
|
||||
|
||||
|
||||
class _Tensor(Tensor):
|
||||
def __init__(
|
||||
@@ -144,7 +149,7 @@ class _Tensor(Tensor):
|
||||
elif enable_tvm_ffi:
|
||||
import tvm_ffi
|
||||
|
||||
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor, stream=-1)
|
||||
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor)
|
||||
self._dlpack_data = self._tvm_ffi_tensor.__dlpack__()
|
||||
else:
|
||||
try:
|
||||
@@ -185,9 +190,17 @@ class _Tensor(Tensor):
|
||||
:param leading_dim: The leading dimension of the layout, defaults to None
|
||||
:type leading_dim: int, optional
|
||||
|
||||
When ``leading_dim`` is None, automatically deduces the leading dimension from the tensor layout.
|
||||
The layout can be deduced only when exactly one dimension has a stride of 1. Raises an error
|
||||
if the layout cannot be automatically deduced.
|
||||
When ``leading_dim`` is None, the leading dimension is deduced as follows.
|
||||
|
||||
(1) If exactly one dimension has stride 1, that dimension is used.
|
||||
|
||||
(2) If multiple dimensions have stride 1 but exactly one of them has size > 1,
|
||||
that dimension is used.
|
||||
|
||||
(3) If multiple dimensions have stride 1 but none or more than one has size > 1,
|
||||
an error is raised.
|
||||
|
||||
(4) If no dimension has stride 1, all strides remain dynamic.
|
||||
|
||||
When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the
|
||||
stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent
|
||||
@@ -304,6 +317,13 @@ class _Tensor(Tensor):
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
@property
|
||||
def __cache_key__(self) -> tuple:
|
||||
self.load_dltensor()
|
||||
if self._dtype is None:
|
||||
self._dtype = self._dltensor_wrapper.dtype
|
||||
return (self._dtype, self._assumed_align, self._dltensor_wrapper.cache_key())
|
||||
|
||||
def __setitem__(self, crd, value):
|
||||
raise TypeError("runtime._Tensor is not indexable")
|
||||
|
||||
@@ -417,42 +437,76 @@ def _get_cute_type_str(inp):
|
||||
return "(" + ",".join(elems) + ")"
|
||||
|
||||
|
||||
class _FakeCompactTensor(Tensor):
|
||||
class _FakeTensor(Tensor):
|
||||
"""Fake Tensor implementation as a placeholder.
|
||||
It mimics the interface of Tensor, but does not hold real data or allow indexing.
|
||||
Used for compilation or testing situations where only shape/type/layout information is needed.
|
||||
All attempts to access or mutate data will raise errors.
|
||||
"""
|
||||
|
||||
"""
|
||||
Create a fake tensor with the given shape, type, and layout.
|
||||
|
||||
:param dtype: Data type of the tensor elements
|
||||
:type dtype: Type[Numeric]
|
||||
:param shape: Shape of the tensor, consists of int (static) or SymInt (dynamic)
|
||||
:type shape: tuple[Union[int, SymInt], ...]
|
||||
:param stride: Stride of the tensor, defaults to None, consists of int (static) or SymInt (dynamic)
|
||||
:type stride: tuple[Union[int, SymInt], ...], optional
|
||||
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
|
||||
:type memspace: AddressSpace, optional
|
||||
:param assumed_align: Assumed alignment of the tensor (bytes), defaults to None. If None, uses the element size bytes as the assumed alignment.
|
||||
:type assumed_align: int, optional
|
||||
:param use_32bit_stride: Whether to use 32-bit stride. Defaults to False. When True, the dynamic stride bitwidth
|
||||
will be set to 32 for small problem sizes (cosize(layout) <= Int32_max) for better performance. This is only applied
|
||||
when the dimension is dynamic.
|
||||
:type use_32bit_stride: bool, optional
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dtype,
|
||||
shape,
|
||||
stride_order,
|
||||
memspace=None,
|
||||
assumed_align=None,
|
||||
use_32bit_stride=False,
|
||||
dtype: Type[Numeric],
|
||||
shape: tuple[Union[int, SymInt], ...],
|
||||
*,
|
||||
stride: tuple[Union[int, SymInt], ...],
|
||||
memspace: AddressSpace = AddressSpace.gmem,
|
||||
assumed_align: int | None = None,
|
||||
use_32bit_stride: bool = False,
|
||||
compact: bool = False,
|
||||
):
|
||||
self._dtype = dtype
|
||||
self._shape = shape
|
||||
self._stride_order = stride_order or tuple(range(len(shape)))
|
||||
# cannot use memspace or AddressSpace.gmem because AddressSpace.generic is 0
|
||||
self._memspace = memspace if memspace is not None else AddressSpace.gmem
|
||||
self._assumed_align = assumed_align or -(-dtype.width // 8)
|
||||
self._stride = stride
|
||||
self._use_32bit_stride = use_32bit_stride
|
||||
self._compact = compact
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FakeTensorOrdered<{self._dtype}, {self._shape}, {self._stride_order}>"
|
||||
if not isinstance(shape, (tuple, list)):
|
||||
raise ValueError(f"Expected tuple or list but got {type(shape)}")
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
if not all(isinstance(s, (int, SymInt)) for s in self._shape):
|
||||
raise ValueError("All shape elements must be int or SymInt")
|
||||
|
||||
if stride is not None and not all(
|
||||
isinstance(s, (int, SymInt)) for s in self._stride
|
||||
):
|
||||
raise ValueError("All stride elements must be int or SymInt")
|
||||
self._memspace = memspace
|
||||
self._assumed_align = assumed_align
|
||||
if assumed_align is None:
|
||||
# use the bytes width of the element dtype. The alignment is at least one byte align.
|
||||
self._assumed_align = (self._dtype.width + 7) // 8
|
||||
|
||||
@property
|
||||
def mlir_type(self) -> ir.Type:
|
||||
shape_ty = ir.Type.parse(
|
||||
'!cute.shape<"' + _get_cute_type_str(self._shape) + '">'
|
||||
)
|
||||
layout_ty = _cute_ir.LayoutType.get_ordered(
|
||||
shape_ty, self._stride_order, self._use_32bit_stride
|
||||
)
|
||||
self._stride = layout_ty.stride
|
||||
ptr_ty = _cute_ir.PtrType.get(
|
||||
self._dtype.mlir_type, self._memspace, self._assumed_align
|
||||
)
|
||||
shape_str = _get_cute_type_str(self._shape)
|
||||
stride_str = _get_cute_type_str(self._stride)
|
||||
layout_ty = ir.Type.parse(f'!cute.layout<"{shape_str}:{stride_str}">')
|
||||
|
||||
# Boolean types are stored as i8 in memory
|
||||
elem_type = T.i8() if self._dtype.width == 1 else self._dtype.mlir_type
|
||||
ptr_ty = _cute_ir.PtrType.get(elem_type, self._memspace, self._assumed_align)
|
||||
return _cute_ir.MemRefType.get(ptr_ty, layout_ty)
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
@@ -463,11 +517,53 @@ class _FakeCompactTensor(Tensor):
|
||||
assert isinstance(values[0], CoreTensor)
|
||||
return CoreTensor(values[0].value, self._dtype)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FakeTensor<{self._dtype}, {self._shape}, {self._stride}>"
|
||||
|
||||
@property
|
||||
def __cache_key__(self) -> tuple:
|
||||
# Check if any shape or stride element is a SymInt without a symbol
|
||||
import warnings
|
||||
|
||||
has_unnamed_symint = False
|
||||
for dim in self._shape:
|
||||
if isinstance(dim, SymInt) and dim.symbol is None:
|
||||
has_unnamed_symint = True
|
||||
break
|
||||
if not self._compact:
|
||||
if not has_unnamed_symint:
|
||||
for stride in self._stride:
|
||||
if isinstance(stride, SymInt) and stride.symbol is None:
|
||||
has_unnamed_symint = True
|
||||
break
|
||||
|
||||
if has_unnamed_symint:
|
||||
warnings.warn(
|
||||
"FakeTensor cache_key contains unnamed symbolic dimensions. "
|
||||
"Different variables with the same shape/stride pattern will have "
|
||||
"identical cache keys, which may cause incorrect cache hits. "
|
||||
"Consider using 'symbol' parameter to distinguish variables: "
|
||||
"cute.sym_int32(symbol='M'), cute.sym_int32(symbol='N')",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return (
|
||||
self._dtype,
|
||||
self._memspace,
|
||||
self._assumed_align,
|
||||
self._shape,
|
||||
self._stride,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def __setitem__(self, crd, value):
|
||||
raise DSLRuntimeError("runtime._FakeCompactTensor is not indexable")
|
||||
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
|
||||
|
||||
def __getitem__(self, crd):
|
||||
raise DSLRuntimeError("runtime._FakeCompactTensor is not indexable")
|
||||
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
|
||||
|
||||
@property
|
||||
def element_type(self) -> Type[Numeric]:
|
||||
@@ -491,118 +587,7 @@ class _FakeCompactTensor(Tensor):
|
||||
|
||||
@property
|
||||
def leading_dim(self):
|
||||
for dim, order in enumerate(self._stride_order):
|
||||
if order == 0:
|
||||
return dim
|
||||
|
||||
@property
|
||||
def dynamic_shapes_mask(self):
|
||||
return tuple(1 if isinstance(e, SymInt) else 0 for e in self._shape)
|
||||
|
||||
@property
|
||||
def dynamic_strides_mask(self):
|
||||
return tuple(1 if isinstance(e, SymInt) else 0 for e in self._stride)
|
||||
|
||||
def fill(self, value: Numeric):
|
||||
raise DSLRuntimeError("runtime._FakeCompactTensor is not writable")
|
||||
|
||||
|
||||
class _FakeTensor(Tensor):
|
||||
"""Fake Tensor implementation as a placeholder.
|
||||
It mimics the interface of Tensor, but does not hold real data or allow indexing.
|
||||
Used for compilation or testing situations where only shape/type/layout information is needed.
|
||||
All attempts to access or mutate data will raise errors.
|
||||
"""
|
||||
|
||||
"""
|
||||
Create a fake tensor with the given shape, type, and layout.
|
||||
|
||||
:param dtype: Data type of the tensor elements
|
||||
:type dtype: Type[Numeric]
|
||||
:param shape: Shape of the tensor, consists of int (static) or SymInt (dynamic)
|
||||
:type shape: tuple[int, ...]
|
||||
:param stride: Stride of the tensor, defaults to None, consists of int (static) or SymInt (dynamic)
|
||||
:type stride: tuple[int, ...], optional
|
||||
:param assumed_align: Assumed alignment of the tensor (bytes), defaults to None. If None, uses the element size bytes as the assumed alignment.
|
||||
:type assumed_align: int, optional
|
||||
:param use_32bit_stride: Whether to use 32-bit stride. Defaults to False. When True, the dynamic stride bitwidth
|
||||
will be set to 32 for small problem sizes (cosize(layout) <= Int32_max) for better performance. This is only applied
|
||||
when the dimension is dynamic.
|
||||
:type use_32bit_stride: bool, optional
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dtype, shape, *, stride, memspace=None, assumed_align=None):
|
||||
self._dtype = dtype
|
||||
self._shape = shape
|
||||
self._stride = stride
|
||||
# cannot use memspace or AddressSpace.generic because AddressSpace.generic is 0
|
||||
self._memspace = memspace if memspace is not None else AddressSpace.gmem
|
||||
self._assumed_align = assumed_align
|
||||
if assumed_align is None:
|
||||
# use the bytes width of the element dtype. The alignment is at least one byte align.
|
||||
self._assumed_align = (self._dtype.width + 7) // 8
|
||||
|
||||
if not isinstance(shape, (tuple, list)):
|
||||
raise ValueError(f"Expected tuple or list but got {type(shape)}")
|
||||
|
||||
if not all(isinstance(s, (int, SymInt)) for s in self._shape):
|
||||
raise ValueError("All shape elements must be int or SymInt")
|
||||
|
||||
if stride is not None and not all(
|
||||
isinstance(s, (int, SymInt)) for s in self._stride
|
||||
):
|
||||
raise ValueError("All stride elements must be int or SymInt")
|
||||
@property
|
||||
def mlir_type(self) -> ir.Type:
|
||||
shape_str = _get_cute_type_str(self._shape)
|
||||
stride_str = _get_cute_type_str(self._stride)
|
||||
layout_ty = ir.Type.parse(f'!cute.layout<"{shape_str}:{stride_str}">')
|
||||
ptr_ty = _cute_ir.PtrType.get(
|
||||
self._dtype.mlir_type, self._memspace, self._assumed_align
|
||||
)
|
||||
return _cute_ir.MemRefType.get(ptr_ty, layout_ty)
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
return [self.mlir_type]
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
assert len(values) == 1
|
||||
assert isinstance(values[0], CoreTensor)
|
||||
return CoreTensor(values[0].value, self._dtype)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"FakeTensor<{self._dtype}, {self._shape}, {self._stride}>"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def __setitem__(self, crd, value):
|
||||
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
|
||||
|
||||
def __getitem__(self, crd):
|
||||
raise DSLRuntimeError("runtime._FakeTensor is not indexable")
|
||||
|
||||
@property
|
||||
def element_type(self) -> Type[Numeric]:
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def memspace(self):
|
||||
return self._memspace
|
||||
|
||||
@property
|
||||
def iterator(self):
|
||||
raise DSLRuntimeError("runtime._FakeTensor has dummy iterator")
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._shape
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
return self._stride
|
||||
return core.leading_dim(self._shape, self._stride)
|
||||
|
||||
@property
|
||||
def dynamic_shapes_mask(self):
|
||||
@@ -617,35 +602,36 @@ class _FakeTensor(Tensor):
|
||||
|
||||
|
||||
def make_fake_compact_tensor(
|
||||
dtype,
|
||||
shape,
|
||||
dtype: Type[Numeric],
|
||||
shape: tuple[Union[int, SymInt], ...],
|
||||
*,
|
||||
stride_order=None,
|
||||
memspace=None,
|
||||
assumed_align=None,
|
||||
use_32bit_stride=False,
|
||||
stride_order: Optional[tuple[int, ...]] = None,
|
||||
memspace: AddressSpace = AddressSpace.gmem,
|
||||
assumed_align: Optional[int] = None,
|
||||
use_32bit_stride: bool = False,
|
||||
):
|
||||
"""
|
||||
Create a fake tensor with the specified shape, element type, and a compact memory layout.
|
||||
|
||||
:param dtype: Data type of the tensor elements.
|
||||
:type dtype: Type[Numeric]
|
||||
:param shape: Shape of the tensor.
|
||||
:type shape: tuple[int, ...]
|
||||
:param shape: Shape of the tensor, consisting of static (int) or dynamic (SymInt) dimensions.
|
||||
:type shape: tuple[Union[int, SymInt], ...]
|
||||
:param stride_order: Order in which strides (memory layout) are assigned to the tensor dimensions.
|
||||
If None, the default layout is left-to-right order (known as column-major order for flatten layout).
|
||||
Otherwise, it should be a permutation order of the dimension indices.
|
||||
The mode with stride_order 0 is the fastest changing (leading) dimension, and N-1 is the slowest changing.
|
||||
:type stride_order: tuple[int, ...], optional
|
||||
:param memspace: Memory space where the fake tensor resides. Optional.
|
||||
:type memspace: str, optional
|
||||
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is used.
|
||||
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
|
||||
:type memspace: AddressSpace, optional
|
||||
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is the dtype width, & at least 1 byte.
|
||||
:type assumed_align: int, optional
|
||||
:param use_32bit_stride: Whether to use 32-bit stride for dynamic dimensions. If True and the total size of the
|
||||
layout (cosize(layout)) fits within int32, then dynamic strides will use 32-bit integers for improved performance.
|
||||
Only applies when dimensions are dynamic. Defaults to False.
|
||||
:type use_32bit_stride: bool, optional
|
||||
:return: An instance of a fake tensor with the given properties and compact layout.
|
||||
:rtype: _FakeCompactTensor
|
||||
:rtype: _FakeTensor
|
||||
|
||||
**Examples:**
|
||||
|
||||
@@ -663,31 +649,68 @@ def make_fake_compact_tensor(
|
||||
# tensor<ptr<f32, generic> o (100,?{div=8}):(?{i32 div=8},1)>
|
||||
compiled_foo = cute.compile(foo, x)
|
||||
|
||||
# Default stride order is left-to-right order: (1, 8)
|
||||
y = make_fake_compact_tensor(cutlass.Float32, (8, 3))
|
||||
# Default stride order is left-to-right order (0, 1, ..., n-1)
|
||||
y = make_fake_compact_tensor(cutlass.Float32, (8, 3, 2)) # y.stride == (1, 8, 24)
|
||||
"""
|
||||
|
||||
return _FakeCompactTensor(
|
||||
if stride_order is not None:
|
||||
if len(stride_order) != len(shape):
|
||||
raise ValueError(
|
||||
f"stride_order ({stride_order}) must be empty or have same length as shape ({shape})."
|
||||
)
|
||||
else:
|
||||
# Default stride order is left-to-right
|
||||
stride_order = stride_order or tuple(range(len(shape)))
|
||||
|
||||
# Make compact strides (possibly symbolic) from shape & stride_order
|
||||
stride = [None] * len(stride_order)
|
||||
stride_product = 1
|
||||
for order in range(len(stride_order)):
|
||||
idx = stride_order.index(order)
|
||||
stride[idx] = stride_product
|
||||
stride_product *= shape[idx]
|
||||
|
||||
stride_width = 32 if use_32bit_stride else 64
|
||||
stride = tuple(
|
||||
(
|
||||
SymInt(width=stride_width, divisibility=s.divisibility)
|
||||
if isinstance(s, SymInt)
|
||||
else s
|
||||
)
|
||||
for s in stride
|
||||
)
|
||||
|
||||
return _FakeTensor(
|
||||
dtype,
|
||||
shape,
|
||||
stride_order=stride_order,
|
||||
stride=stride,
|
||||
memspace=memspace,
|
||||
assumed_align=assumed_align,
|
||||
use_32bit_stride=use_32bit_stride,
|
||||
compact=True,
|
||||
)
|
||||
|
||||
|
||||
def make_fake_tensor(dtype, shape, stride, *, memspace=None, assumed_align=None):
|
||||
def make_fake_tensor(
|
||||
dtype: Type[Numeric],
|
||||
shape: tuple[Union[int, SymInt], ...],
|
||||
stride: tuple[Union[int, SymInt], ...],
|
||||
*,
|
||||
memspace: AddressSpace = AddressSpace.gmem,
|
||||
assumed_align: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Create a fake tensor with the specified element type, shape, and stride.
|
||||
|
||||
:param dtype: Data type of the tensor elements.
|
||||
:type dtype: Type[Numeric]
|
||||
:param shape: Shape of the tensor.
|
||||
:type shape: tuple[int, ...]
|
||||
:param stride: Stride of the tensor.
|
||||
:type stride: tuple[int, ...]
|
||||
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is used. Defaults to None.
|
||||
:param shape: Shape of the tensor, consisting of static (int) or dynamic (SymInt) dimensions.
|
||||
:type shape: tuple[Union[int, SymInt], ...]
|
||||
:param stride: Stride of the tensor, consisting of static (int) or dynamic (SymInt) values.
|
||||
:type stride: tuple[Union[int, SymInt], ...]
|
||||
:param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem.
|
||||
:type memspace: AddressSpace, optional
|
||||
:param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is the dtype width, & at least 1 byte.
|
||||
:type assumed_align: int, optional
|
||||
:return: An instance of a fake tensor with the given properties.
|
||||
:rtype: _FakeTensor
|
||||
@@ -953,22 +976,7 @@ def load_module(file_path: str, *, enable_tvm_ffi: bool = False):
|
||||
if Path(path).exists():
|
||||
_LOAD_MODULE_LIBS_CACHE.append(ctypes.CDLL(path))
|
||||
|
||||
if enable_tvm_ffi:
|
||||
import tvm_ffi
|
||||
|
||||
try:
|
||||
# keep_module_alive=False means the module will be unloaded
|
||||
# after the returned module goes out of scope, this is useful
|
||||
# for frequent loading and unloading of modules. The only requirement
|
||||
# is that the module do not return object that have deleter in the module
|
||||
# and the returned object lives longer than the module.
|
||||
# DSL functions to not have such issue so it is desirable to set this to False.
|
||||
return tvm_ffi.load_module(file_path, keep_module_alive=False)
|
||||
except TypeError:
|
||||
# compatible with tvm-ffi < 0.1.6
|
||||
return tvm_ffi.load_module(file_path)
|
||||
else:
|
||||
return ExternalBinaryModule(file_path)
|
||||
return ExternalBinaryModule(file_path, enable_tvm_ffi=enable_tvm_ffi)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Try to register_jit_arg_adapter for TensorAdapter
|
||||
|
||||
@@ -139,6 +139,19 @@ class _Tensor(Tensor):
|
||||
def __init__(
|
||||
self, value, dtype: Optional[Type[Numeric]] = None, *, loc=None, ip=None
|
||||
):
|
||||
"""Initialize a Tensor from an MLIR value.
|
||||
|
||||
:param value: The MLIR operation result value or another Tensor to initialize from
|
||||
:type value: Union[ir.Value, _Tensor]
|
||||
:param dtype: The user specified data type of the tensor elements, defaults to None
|
||||
:type dtype: Optional[Type[Numeric]]
|
||||
:param loc: The source location for the operation, defaults to None
|
||||
:type loc: Optional[Location]
|
||||
:param ip: The insertion point for the operation, defaults to None
|
||||
:type ip: Optional[InsertionPoint]
|
||||
:raises TypeError: If value is not ir.Value or _Tensor
|
||||
:raises TypeError: If iterator type is not supported
|
||||
"""
|
||||
self._dtype = dtype
|
||||
if isinstance(value, ir.Value):
|
||||
self.value = value
|
||||
@@ -952,6 +965,37 @@ def make_fragment_like(src, dtype=None, *, loc=None, ip=None):
|
||||
def recast_tensor(
|
||||
src: Tensor, dtype: Type[Numeric], swizzle_=None, *, loc=None, ip=None
|
||||
):
|
||||
"""Recast a tensor to a different data type by changing the element interpretation.
|
||||
|
||||
This function reinterprets the memory of a tensor with a different element type,
|
||||
adjusting both the iterator pointer type and the layout to maintain consistency.
|
||||
|
||||
:param src: The source tensor to recast
|
||||
:type src: Tensor
|
||||
:param dtype: The target data type for tensor elements
|
||||
:type dtype: Type[Numeric]
|
||||
:param swizzle_: Optional swizzle parameter (reserved for future use), defaults to None
|
||||
:type swizzle_: Optional, unused
|
||||
:param loc: Source location for MLIR operation tracking, defaults to None
|
||||
:type loc: Optional[Location]
|
||||
:param ip: Insertion point for MLIR operation, defaults to None
|
||||
:type ip: Optional[InsertionPoint]
|
||||
:return: A new tensor with the same memory but reinterpreted as dtype
|
||||
:rtype: Tensor
|
||||
:raises TypeError: If dtype is not a subclass of Numeric
|
||||
|
||||
**Examples:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Create a Float32 tensor
|
||||
tensor_f32 = make_rmem_tensor((4, 8), Float32)
|
||||
|
||||
# Recast to Int32 to manipulate bits
|
||||
tensor_i32 = recast_tensor(tensor_f32, Int32)
|
||||
|
||||
# Both tensors share the same memory, but interpret it differently
|
||||
"""
|
||||
if not isclass(dtype) or not issubclass(dtype, Numeric):
|
||||
raise TypeError(f"dtype must be a type of Numeric, but got {dtype}")
|
||||
|
||||
@@ -972,6 +1016,36 @@ def recast_tensor(
|
||||
|
||||
@dsl_user_op
|
||||
def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor:
|
||||
"""Offset the tensor domain by the given coordinate.
|
||||
|
||||
This function creates a new tensor by offsetting the iterator/pointer of the input tensor
|
||||
by the amount corresponding to the given coordinate in its layout.
|
||||
|
||||
:param coord: The coordinate offset to apply
|
||||
:type coord: Coord
|
||||
:param tensor: The source tensor to offset
|
||||
:type tensor: Tensor
|
||||
:param loc: Source location for MLIR operation tracking, defaults to None
|
||||
:type loc: Optional[Location]
|
||||
:param ip: Insertion point for MLIR operation, defaults to None
|
||||
:type ip: Optional[InsertionPoint]
|
||||
:return: A new tensor with the offset iterator
|
||||
:rtype: Tensor
|
||||
:raises ValueError: If the tensor type doesn't support domain offsetting
|
||||
|
||||
**Examples:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Create a tensor with a row-major layout
|
||||
ptr = make_ptr(Float32, base_ptr, AddressSpace.gmem)
|
||||
layout = make_layout((64, 128), stride=(128, 1))
|
||||
tensor = make_tensor(ptr, layout)
|
||||
|
||||
# Offset by coordinate (3, 5)
|
||||
offset_tensor = domain_offset((3, 5), tensor)
|
||||
# offset_tensor now points to element at (3, 5)
|
||||
"""
|
||||
offset = crd2idx(coord, tensor.layout, loc=loc, ip=ip)
|
||||
if isinstance(tensor.iterator, Pointer):
|
||||
return make_tensor(
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import ctypes
|
||||
from typing import ForwardRef, Tuple, Union, Any, Type, List, Optional, Literal
|
||||
@@ -24,15 +26,18 @@ Int = Union[int, Integer]
|
||||
|
||||
|
||||
class SymInt:
|
||||
def __init__(self, width: Literal[32, 64] = 32, *, divisibility=1):
|
||||
def __init__(
|
||||
self, width: Literal[32, 64] = 32, *, divisibility=1, symbol: str | None = None
|
||||
):
|
||||
if width not in [32, 64]:
|
||||
raise ValueError(f"Unsupported width: {width}")
|
||||
|
||||
self._width = width
|
||||
self._divisibility = divisibility
|
||||
self._symbol = symbol
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self._width, self._divisibility))
|
||||
return hash((self._width, self._divisibility, self._symbol))
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
@@ -42,8 +47,16 @@ class SymInt:
|
||||
def divisibility(self):
|
||||
return self._divisibility
|
||||
|
||||
@property
|
||||
def symbol(self):
|
||||
return self._symbol
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"?{{i{self._width} div={self._divisibility}}}"
|
||||
prefix = "" if self._symbol is None else self._symbol + " "
|
||||
if self._width == 32:
|
||||
return f"{prefix}?{{div={self._divisibility}}}"
|
||||
else:
|
||||
return f"{prefix}?{{i{self._width} div={self._divisibility}}}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
@@ -51,18 +64,52 @@ class SymInt:
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, SymInt):
|
||||
return False
|
||||
|
||||
return all(
|
||||
[self._width == other._width, self._divisibility == other._divisibility]
|
||||
[
|
||||
self._width == other._width,
|
||||
self._divisibility == other._divisibility,
|
||||
self._symbol == other._symbol,
|
||||
]
|
||||
)
|
||||
|
||||
def __mod__(self, other: int) -> Union["SymInt", int]:
|
||||
if self._divisibility % other != 0:
|
||||
def __mod__(self, other: int | SymInt) -> SymInt | int:
|
||||
if isinstance(other, int):
|
||||
other_div, result_width = other, self._width
|
||||
elif isinstance(other, SymInt):
|
||||
other_div, result_width = (
|
||||
other._divisibility,
|
||||
max(self._width, other._width),
|
||||
)
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
if self._divisibility % other_div == 0:
|
||||
return 0
|
||||
else:
|
||||
from math import gcd
|
||||
|
||||
div = gcd(self._divisibility, other)
|
||||
return SymInt(self._width, divisibility=div)
|
||||
return SymInt(result_width, divisibility=gcd(self._divisibility, other_div))
|
||||
|
||||
def __rmod__(self, other: int) -> int:
|
||||
"""int % SymInt: check if the int conforms to this SymInt's divisibility"""
|
||||
if isinstance(other, int):
|
||||
return other % self._divisibility
|
||||
return NotImplemented
|
||||
|
||||
def __mul__(self, other: int | SymInt) -> SymInt:
|
||||
if isinstance(other, int):
|
||||
return SymInt(self._width, divisibility=self._divisibility * other)
|
||||
elif isinstance(other, SymInt):
|
||||
return SymInt(
|
||||
width=max(self._width, other._width),
|
||||
divisibility=self._divisibility * other._divisibility,
|
||||
)
|
||||
else:
|
||||
return 0
|
||||
return NotImplemented
|
||||
|
||||
def __rmul__(self, other: int | SymInt) -> SymInt:
|
||||
return self.__mul__(other)
|
||||
|
||||
def __c_pointers__(self):
|
||||
return [ctypes.c_void_p(0).value]
|
||||
@@ -73,7 +120,7 @@ class SymInt:
|
||||
)
|
||||
return [res_ty]
|
||||
|
||||
def __new_from_mlir_values__(self, values) -> "SymInt":
|
||||
def __new_from_mlir_values__(self, values) -> SymInt:
|
||||
from .core import IntValue
|
||||
|
||||
if self.width == 32:
|
||||
@@ -84,16 +131,18 @@ class SymInt:
|
||||
assert False, f"Unsupported width: {self.width}"
|
||||
return self
|
||||
|
||||
def sym_int(width: Literal[32, 64] = 32, *, divisibility=1) -> SymInt:
|
||||
return SymInt(width, divisibility=divisibility)
|
||||
def sym_int(
|
||||
width: Literal[32, 64] = 32, *, divisibility=1, symbol: str | None = None
|
||||
) -> SymInt:
|
||||
return SymInt(width, divisibility=divisibility, symbol=symbol)
|
||||
|
||||
|
||||
def sym_int32(divisibility=1) -> SymInt:
|
||||
return sym_int(32, divisibility=divisibility)
|
||||
def sym_int32(divisibility=1, symbol: str | None = None) -> SymInt:
|
||||
return sym_int(32, divisibility=divisibility, symbol=symbol)
|
||||
|
||||
|
||||
def sym_int64(divisibility=1) -> SymInt:
|
||||
return sym_int(64, divisibility=divisibility)
|
||||
def sym_int64(divisibility=1, symbol: str | None = None) -> SymInt:
|
||||
return sym_int(64, divisibility=divisibility, symbol=symbol)
|
||||
|
||||
|
||||
ScaledBasis = ForwardRef("ScaledBasis")
|
||||
|
||||
@@ -1199,7 +1199,6 @@ class KernelLauncher:
|
||||
return self.dsl._get_smem_usage()
|
||||
|
||||
def launch(self, *args, **kwargs):
|
||||
self.dsl.frame = inspect.currentframe().f_back
|
||||
self.dsl._preprocess_launch_config_args(args, kwargs)
|
||||
config = self.dsl.LaunchConfig(*args, **kwargs)
|
||||
kernel_attrs = _build_kernel_attrs(config)
|
||||
@@ -1216,7 +1215,6 @@ class KernelLauncher:
|
||||
|
||||
ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config)
|
||||
self.dsl.kernel_info[name] = kernel_attrs
|
||||
self.dsl.frame = None
|
||||
return ret.launch_op_ret
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
||||
@@ -35,7 +35,6 @@ if is_available():
|
||||
)
|
||||
from .compile import (
|
||||
release_compile_cache,
|
||||
initialize_cutlass_dsl,
|
||||
)
|
||||
from .ffi import (
|
||||
get_export_disabled_safety_checks,
|
||||
@@ -48,10 +47,6 @@ if is_available():
|
||||
# This is a legacy name for TensorSpec. It will be removed eventually.
|
||||
TensorMode = TensorSpec
|
||||
|
||||
# This explicit init method ensures that we avoid initialization at
|
||||
# unexpected times in jax tracing.
|
||||
initialize_cutlass_dsl()
|
||||
|
||||
__all__ = [
|
||||
"cutlass_call",
|
||||
"jax_to_cutlass_dtype",
|
||||
|
||||
@@ -267,36 +267,4 @@ def release_compile_cache():
|
||||
_CUTLASS_COMPILE_CACHE.clear()
|
||||
dsl = CuTeDSL._get_dsl()
|
||||
dsl.jit_cache.clear()
|
||||
# TODO: This is needed to release frames being held in the DSL
|
||||
# We should avoid holding such references as they unexpectedly
|
||||
# extend object lifetime.
|
||||
dsl.frame = None
|
||||
gc.collect()
|
||||
|
||||
|
||||
class _DummyInitKernel:
|
||||
@cute.kernel
|
||||
def kernel(self):
|
||||
pass
|
||||
|
||||
@cute.jit
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
|
||||
_CUTLASS_DSL_INITIALIZED = False
|
||||
|
||||
|
||||
def initialize_cutlass_dsl():
|
||||
"""Initializes cutlass DSL."""
|
||||
global _CUTLASS_DSL_INITIALIZED
|
||||
if _CUTLASS_DSL_INITIALIZED:
|
||||
return
|
||||
|
||||
# Call compiler to ensure we've pre-processed any kernels inside cutedsl.
|
||||
kernel = _DummyInitKernel()
|
||||
with _compile_lock:
|
||||
logger.debug("Initializing cutlass dsl...")
|
||||
_ = cutlass.cute.compile(kernel.init)
|
||||
|
||||
_CUTLASS_DSL_INITIALIZED = True
|
||||
|
||||
@@ -28,9 +28,13 @@ logger = logging.getLogger(__name__)
|
||||
_CUTE_DSL_RUNTIME_LIBRARY_NAME = "cute_dsl_runtime"
|
||||
|
||||
_CUTLASS_CALL_TARGETS = {
|
||||
"CuteDSLRT_NvJaxCutlassCall": {"execute": "CuteDSLRT_NvJaxCutlassCallExecute"},
|
||||
"CuteDSLRT_NvJaxCutlassCall": {
|
||||
"execute": "CuteDSLRT_NvJaxCutlassCallExecute",
|
||||
"prepare": "CuteDSLRT_NvJaxCutlassCallPrepare",
|
||||
},
|
||||
"CuteDSLRT_NvJaxCutlassCallNoCudaGraph": {
|
||||
"execute": "CuteDSLRT_NvJaxCutlassCallExecuteNoCudaGraph"
|
||||
"execute": "CuteDSLRT_NvJaxCutlassCallExecuteNoCudaGraph",
|
||||
"prepare": "CuteDSLRT_NvJaxCutlassCallPrepare",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements-cu13.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl[cu13]==4.4.1
|
||||
nvidia-cutlass-dsl[cu13]==4.4.2
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl==4.4.1
|
||||
nvidia-cutlass-dsl==4.4.2
|
||||
|
||||
@@ -133,7 +133,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '4.4.1'
|
||||
this.__version__ = '4.4.2'
|
||||
|
||||
from cutlass_cppgen.backend import create_memory_pool
|
||||
from cutlass_cppgen.emit.pytorch import pytorch
|
||||
|
||||
@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
|
||||
|
||||
setup(
|
||||
name='cutlass_cppgen',
|
||||
version='4.4.1',
|
||||
version='4.4.2',
|
||||
description='CUTLASS Pythonic Interface',
|
||||
package_dir={'': '.'},
|
||||
packages=[
|
||||
|
||||
@@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='cutlass_library',
|
||||
version='4.4.1',
|
||||
version='4.4.2',
|
||||
description='CUTLASS library generation scripts',
|
||||
packages=['cutlass_library']
|
||||
)
|
||||
|
||||
@@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='pycute',
|
||||
version='4.4.1',
|
||||
version='4.4.2',
|
||||
description='Python implementation of CuTe',
|
||||
packages=['pycute'],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user