v4.3.1 update. (#2817)

This commit is contained in:
Junkai-Wu
2025-11-27 22:49:30 +08:00
committed by GitHub
parent 2052fd3885
commit 1de3a576cc
44 changed files with 3316 additions and 510 deletions

View File

@@ -9,6 +9,10 @@
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from ._mlir._mlir_libs import _cutlass_ir
_cutlass_ir.populate(_cutlass_ir)
from .cutlass_dsl import (
Constexpr,
dsl_user_op,

View File

@@ -32,6 +32,9 @@ class Arch(Enum):
sm_101 = (10, 1, "")
sm_101a = (10, 1, "a")
sm_101f = (10, 1, "f")
sm_103 = (10, 3, "")
sm_103a = (10, 3, "a")
sm_103f = (10, 3, "f")
sm_110 = (11, 0, "")
sm_110a = (11, 0, "a")
sm_110f = (11, 0, "f")
@@ -80,6 +83,9 @@ class Arch(Enum):
Arch.sm_101,
Arch.sm_101a,
Arch.sm_101f,
Arch.sm_103,
Arch.sm_103a,
Arch.sm_103f,
Arch.sm_110,
Arch.sm_110a,
Arch.sm_110f,

View File

@@ -529,6 +529,7 @@ def _get_self_module():
return inspect.getmodule(_get_self_module)
@lru_cache(maxsize=16)
def cf_symbol_check(symbol):
"""
Check if the symbol is control flow symbol from current module.

View File

@@ -1344,6 +1344,25 @@ class DSLPreprocessor(ast.NodeTransformer):
),
node,
)
elif func.id == "super" and node.args == [] and node.keywords == []:
# If it's a Python3 argument free super(), rewrite to old style super with args
# So if this call is under dynamic control flow, it still works.
return ast.copy_location(
ast.Call(
func=func,
args=node.args
+ [
ast.Attribute(
value=ast.Name(id="self", ctx=ast.Load()),
attr="__class__",
ctx=ast.Load(),
),
ast.Name(id="self", ctx=ast.Load()),
],
keywords=node.keywords,
),
node,
)
elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
def create_downcast_call(arg):
@@ -1853,21 +1872,41 @@ class DSLPreprocessor(ast.NodeTransformer):
# Handle elif case
elif_node = node.orelse[0]
nested_if_name = elif_region_name
# Recursion for nested elif
nested_if = self.create_if_function(
nested_if_name, elif_node, write_args, full_write_args_count
)
else_block = ast.FunctionDef(
name=else_block_name,
args=func_then_else_arguments,
body=[
nested_if,
ast.Return(
value=ast.Name(id=nested_if_name, ctx=ast.Load())
),
],
decorator_list=[],
)
# AST cannot distinguish between the following two cases:
# elif pred:
# and
# else:
# if pred:
# And under both cases, the `pred` can be a const_expr, so we need to handle it here.
if self.is_node_constexpr(elif_node):
self.generic_visit(elif_node)
check = self._insert_cf_symbol_check(elif_node.test.func)
else_block = ast.FunctionDef(
name=else_block_name,
args=func_then_else_arguments,
body=[
check,
elif_node,
ast.Return(value=return_list),
],
decorator_list=[],
)
else:
# Recursion for nested elif
nested_if = self.create_if_function(
nested_if_name, elif_node, write_args, full_write_args_count
)
else_block = ast.FunctionDef(
name=else_block_name,
args=func_then_else_arguments,
body=[
nested_if,
ast.Return(
value=ast.Name(id=nested_if_name, ctx=ast.Load())
),
],
decorator_list=[],
)
else:
else_body = []
for stmt in node.orelse:

View File

@@ -356,7 +356,11 @@ class GPUArch(StringCompileOption):
class EnableTVMFFI(EmptyCompileOption):
    """Empty compile option registered under the name ``enable-tvm-ffi``."""

    option_name = "enable-tvm-ffi"
class DumpDir(EmptyCompileOption):
    """Compile option carrying a dump output directory (``dump-dir``)."""

    # When set to a non-empty string this value takes precedence over the
    # environment-provided dump directory when building dump paths.
    option_name = "dump-dir"
class CompileOptions:
@@ -380,6 +384,7 @@ class CompileOptions:
GPUArch: GPUArch(""),
LinkLibraries: LinkLibraries(""),
EnableTVMFFI: EnableTVMFFI(False),
DumpDir: DumpDir(""),
}
if options is not None:
@@ -416,19 +421,24 @@ class CompileOptions:
if self.options[GPUArch].value == ""
else self.options[GPUArch].value
)
dump_dir = (
envar.dump_dir
if self.options[DumpDir].value == ""
else self.options[DumpDir].value
)
if self.options[KeepPTX].value:
self.options[KeepPTX].dump_path = os.path.join(
envar.dump_dir, f"{function_name}"
dump_dir, f"{function_name}"
)
self.options[KeepPTX].full_ptx_path = os.path.join(
envar.dump_dir, f"{function_name}.{arch}.ptx"
dump_dir, f"{function_name}.{arch}.ptx"
)
if self.options[KeepCUBIN].value:
self.options[KeepCUBIN].dump_path = os.path.join(
envar.dump_dir, f"{function_name}"
dump_dir, f"{function_name}"
)
self.options[KeepCUBIN].full_cubin_path = os.path.join(
envar.dump_dir, f"{function_name}.{arch}.cubin"
dump_dir, f"{function_name}.{arch}.cubin"
)
@property
@@ -504,6 +514,7 @@ def _parse_compile_options_from_str(options: str) -> CompileOptions:
"keep_ptx": KeepPTX,
"gpu_arch": GPUArch,
"enable_tvm_ffi": EnableTVMFFI,
"dump_dir": DumpDir,
}
return mapping[option_str]
@@ -520,6 +531,7 @@ def _parse_compile_options_from_str(options: str) -> CompileOptions:
parser.add_argument("--ptxas-options", type=str, default="")
parser.add_argument("--gpu-arch", type=str, default="")
parser.add_argument("--enable-tvm-ffi", action="store_true", default=False)
parser.add_argument("--dump-dir", type=str, default="")
compile_options = CompileOptions()
try:
# Use shlex to properly handle options with spaces
@@ -545,7 +557,7 @@ class CompileCallable:
def __init__(self, options=None):
def preprocess_options(option):
if type(option) is type and issubclass(
option, (BooleanCompileOption, BooleanBasedFileDumpOption)
option, (BooleanCompileOption, BooleanBasedFileDumpOption, EnableTVMFFI)
):
# Automatically creates a True instance of the option
return option(True)

View File

@@ -53,7 +53,7 @@ from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegist
from .ast_preprocessor import DSLPreprocessor
from .common import *
from .typing import get_c_pointers, get_mlir_types, Integer, arg_compatible_with_tvm_ffi
from .typing import get_c_pointers, get_mlir_types, Integer
from .arch import Arch
# =============================================================================
@@ -810,10 +810,6 @@ class BaseDSL:
if is_host:
if self.envar.enable_tvm_ffi:
if not arg_compatible_with_tvm_ffi(arg):
raise DSLRuntimeError(
f"Argument #{i + 1} ({arg_name}) is not a TVM FFI argument."
)
jit_exec_arg.extend([arg])
else:
jit_exec_arg.extend(get_c_pointers(arg))
@@ -1218,6 +1214,8 @@ class BaseDSL:
no_cache,
func_type=JitCompiledFunction,
*,
full_args=None,
full_kwargs=None,
dynamic_args=None,
dynamic_kwargs=None,
original_function_name=None,
@@ -1388,6 +1386,8 @@ class BaseDSL:
pipeline,
args_spec,
no_cache,
full_args=args,
full_kwargs=kwargs,
dynamic_args=dynamic_args,
dynamic_kwargs=dynamic_kwargs,
original_function_name=original_function_name,
@@ -1400,7 +1400,7 @@ class BaseDSL:
module_hash,
)
jit_function = self.jit_cache[module_hash]
finally:
self.post_compilation_cleanup()

View File

@@ -235,27 +235,20 @@ def get_prefix_dsl_libs(prefix: str):
return prefix_libs_existing
def get_libs_cand(start):
target_libs = {
"mlir_c_runner_utils",
"mlir_runner_utils",
"mlir_cuda_runtime",
target_cuda_dialect_libs = {
"cuda_dialect_runtime",
}
lib_folder_guesses = [
"lib",
]
libs_cand = find_libs_in_ancestors(start, target_libs, lib_folder_guesses)
optional_libs_cand = find_libs_in_ancestors(
start, {"cuda_dialect_runtime"}, lib_folder_guesses
)
if libs_cand:
dsl_libs = ":".join(libs_cand)
if optional_libs_cand:
dsl_libs += ":" + ":".join(optional_libs_cand)
return dsl_libs
for target_libs in [
target_cuda_dialect_libs,
]:
libs_cand = find_libs_in_ancestors(start, target_libs, lib_folder_guesses)
if libs_cand:
dsl_libs = ":".join(libs_cand)
return dsl_libs
return None
# find from install folder

View File

@@ -216,6 +216,69 @@ class ExecutionArgs:
return exe_args, adapted_args
def get_rectified_args_from_original_args(self, full_args, full_kwargs):
    """
    Rectify the original call arguments into the runtime arguments that
    match the original args_spec, dropping compile-time (constexpr)
    parameters and substituting declared defaults where applicable.

    :param full_args: The original full positional arguments to filter.
    :param full_kwargs: The original full keyword arguments to filter.
    :return: A single list: filtered positional arguments followed by the
        filtered keyword-only argument values (in declaration order).
    :raises DSLRuntimeError: If the filtered argument counts do not match
        the runtime function signature (``self.args_spec``).
    """
    arg_spec = self.original_args_spec
    # Index of the first positional parameter that has a declared default;
    # positional defaults align with the tail of arg_spec.args.
    if arg_spec.defaults:
        defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults)
    else:
        defaults_start_idx = len(arg_spec.args)
    runtime_args = []
    # Filter arguments and maintain their properties
    for i, arg_name in enumerate(arg_spec.args):
        arg_type = arg_spec.annotations.get(arg_name, None)
        # Skip compile-time arguments
        if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
            continue
        # Keep corresponding default if it exists
        # NOTE(review): a parameter with a declared default always takes the
        # default here, even if full_args might hold a caller-supplied value
        # at this index — confirm callers never pass defaulted args positionally.
        if i >= defaults_start_idx:
            default_idx = i - defaults_start_idx
            runtime_args.append(arg_spec.defaults[default_idx])
        else:
            runtime_args.append(full_args[i])
    # Filter keyword-only arguments
    runtime_kwargs = {}
    if arg_spec.kwonlyargs:
        for i, kwarg in enumerate(arg_spec.kwonlyargs):
            arg_type = arg_spec.annotations.get(kwarg, None)
            # Skip compile-time arguments
            if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name):
                continue
            # Keep runtime keyword-only arguments; fall back to kwonly defaults
            if kwarg in full_kwargs:
                runtime_kwargs[kwarg] = full_kwargs[kwarg]
            elif arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults:
                runtime_kwargs[kwarg] = arg_spec.kwonlydefaults[kwarg]
    # Sanity check: the rectified counts must match the runtime signature
    # (self.args_spec is the constexpr-filtered spec).
    if (len(runtime_args) != len(self.args_spec.args) or
        len(runtime_kwargs) != len(self.args_spec.kwonlyargs)):
        raise DSLRuntimeError(
            "input args/kwargs length does not match runtime function signature!",
            context={
                "input args length": len(runtime_args),
                "input kwargs length": len(runtime_kwargs),
                "function signature args length": len(self.args_spec.args),
                "function signature kwonlyargs length": len(self.args_spec.kwonlyargs),
            },
        )
    return runtime_args + list(runtime_kwargs.values())
def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec):
runtime_args = []
runtime_annotations = {}
@@ -562,6 +625,13 @@ class JitCompiledFunction:
kernel_modules[sym] = CudaModuleAndKernel(sym, cubin_module, kernel, attrs)
return list(kernel_modules.values())
def _validate_engine(self):
if self.engine is None:
raise DSLRuntimeError(
"The compiled function does not have a valid execution engine.",
suggestion="For cross-compilation, please use `cute.export.export_to_c` to serialize the compiled function and load/execute it on target device.",
)
def to(self, device=None) -> JitExecutor:
"""Returns an executable function bound to the given device.
@@ -573,6 +643,7 @@ class JitCompiledFunction:
:return: A callable executor function.
:rtype: JitExecutor
"""
self._validate_engine()
with self._executor_lock:
# We need to ensure that the modules are loaded if not already
if self.jit_module is None:
@@ -621,4 +692,4 @@ class JitCompiledFunction:
# object alive as it hold a reference to self.
proxy_self = weakref.proxy(self)
self._default_executor = proxy_self.to(None)
self._default_executor.run_compiled_program(exe_args)
return self._default_executor.run_compiled_program(exe_args)

View File

@@ -20,6 +20,18 @@ from ..._mlir.dialects import llvm
from .tvm_ffi_builder import CallContext, CallProvider, TVMFFIBuilder
def _flatten_tuple_params(params: list[spec.Param]) -> list[spec.Param]:
    """Expand every TupleParam (recursively) into a flat list of params."""
    result: list[spec.Param] = []
    for entry in params:
        if isinstance(entry, spec.TupleParam):
            # Nested tuples flatten depth-first into the same output list.
            result += _flatten_tuple_params(entry.params)
        else:
            result.append(entry)
    return result
class NopCallProvider(CallProvider):
"""No-op call provider for testing purposes."""
@@ -53,7 +65,11 @@ class DynamicParamPackCallProvider(CallProvider, TVMFFIBuilder):
"""
def __init__(
self, target_func: str, include_num_args: bool = False, struct_call: bool = False,
self,
target_func: str,
include_num_args: bool = False,
struct_call: bool = False,
flatten_tuple_params: bool = True,
) -> None:
import tvm_ffi
@@ -61,8 +77,12 @@ class DynamicParamPackCallProvider(CallProvider, TVMFFIBuilder):
self.target_func = target_func
self.include_num_args = include_num_args
self.struct_call = struct_call
self.flatten_tuple_params = flatten_tuple_params
self.float4x2_dtype = tvm_ffi.dtype("float4_e2m1fnx2")
if not self.flatten_tuple_params:
raise RuntimeError("flatten_tuple_params=False is not supported yet")
def get_callee_struct_for_param_tensor(
self,
param: spec.Tensor,
@@ -157,8 +177,15 @@ class DynamicParamPackCallProvider(CallProvider, TVMFFIBuilder):
self, current_block: ir.Block, context: CallContext
) -> list[tuple[ir.Type, ir.Value]]:
"""Pack a parameter to a struct."""
# Flatten TupleParam into list of params if enabled
if self.flatten_tuple_params:
flattened_params = _flatten_tuple_params(context.params)
else:
flattened_params = context.params
# Pack each parameter
packed_params = []
for param in context.params:
for param in flattened_params:
if isinstance(param, spec.Tensor):
packed_params.append(
self.pack_param_tensor(current_block, context, param)
@@ -177,6 +204,9 @@ class DynamicParamPackCallProvider(CallProvider, TVMFFIBuilder):
packed_params.append(
self.pack_param_var(current_block, context, param.var)
)
elif isinstance(param, spec.ConstNone):
# const none is not packed
continue
else:
raise NotImplementedError(f"Unsupported parameter type: {type(param)}")
return packed_params

View File

@@ -302,7 +302,10 @@ class MLIRBuilder(MLIRTypeBuilder):
cond: ir.Value,
true_block: ir.Block,
false_block: ir.Block,
branch_weights=None
*,
branch_weights=None,
true_dest_operands: Sequence[ir.Value] = (),
false_dest_operands: Sequence[ir.Value] = (),
) -> None:
"""Create a conditional branch.
@@ -318,6 +321,10 @@ class MLIRBuilder(MLIRTypeBuilder):
Optional branch weights [true_weight, false_weight] for optimization hints.
Higher values indicate higher probability. For example, (99, 1) indicates
the true branch is much more likely than the false branch.
true_dest_operands : Sequence[ir.Value]
Operands to pass to the true destination block.
false_dest_operands : Sequence[ir.Value]
Operands to pass to the false destination block.
"""
if branch_weights is not None:
# Branch weights should be a tuple/list of two integers [true_weight, false_weight]
@@ -325,8 +332,8 @@ class MLIRBuilder(MLIRTypeBuilder):
raise ValueError("branch_weights must have exactly 2 elements")
llvm.cond_br(
cond,
true_dest_operands=[],
false_dest_operands=[],
true_dest_operands=true_dest_operands,
false_dest_operands=false_dest_operands,
true_dest=true_block,
false_dest=false_block,
branch_weights=ir.DenseI32ArrayAttr.get(list(branch_weights))
@@ -334,8 +341,8 @@ class MLIRBuilder(MLIRTypeBuilder):
else:
llvm.cond_br(
cond,
true_dest_operands=[],
false_dest_operands=[],
true_dest_operands=true_dest_operands,
false_dest_operands=false_dest_operands,
true_dest=true_block,
false_dest=false_block,
)
@@ -401,6 +408,17 @@ class MLIRBuilder(MLIRTypeBuilder):
)
func_op.attributes["llvm.linkage"] = ir.StringAttr.get("external")
def create_alloca(self, entry_block: ir.Block, alloca_type: ir.Type, array_size: int) -> ir.Value:
    """Emit an ``llvm.alloca`` of ``alloca_type`` at the top of ``entry_block``.

    The instruction is inserted before the first existing operation of the
    entry block and returns the resulting pointer value.
    """
    insertion_point = ir.InsertionPoint(entry_block.operations[0])
    with insertion_point:
        return llvm.alloca(
            res=self.ptr_type,
            elem_type=alloca_type,
            array_size=self.i32(array_size),
        )
def pack_values_to_alloca(
self,
current_block: ir.Block,
@@ -423,14 +441,11 @@ class MLIRBuilder(MLIRTypeBuilder):
tuple[ir.Type, ir.Value]
The struct type and the alloca.
"""
with ir.InsertionPoint(entry_block.operations[0]):
# declare the struct type
struct_type = self.struct_type(fields=[value.type for value in values])
alloca = llvm.alloca(
res=self.ptr_type,
elem_type=struct_type,
array_size=self.i32(1),
)
# Declare the struct type from the values
struct_type = self.struct_type(fields=[value.type for value in values])
# Create alloca using the helper method
alloca = self.create_alloca(entry_block, struct_type, array_size=1)
with ir.InsertionPoint(current_block):
for index, value in enumerate(values):

View File

@@ -136,7 +136,10 @@ class Shape(Param):
name: str
shape: list[Union[int, Var]]
def __init__(self, name: str, shape: list[Union[int, Var]]) -> None:
def __init__(
self,
name: str, shape: list[Union[int, Var]],
) -> None:
"""Initialize a Shape parameter.
Parameters
@@ -146,6 +149,9 @@ class Shape(Param):
shape : list[int | Var]
The shape of the parameter.
unpack_shape: bool
Whether to unpack the shape into list of arguments when calling
the call provider function.
"""
self.name = name
self.shape = shape
@@ -301,6 +307,106 @@ class DataPointer(Param):
self.address_space = address_space
class ConstNone(Param):
    """ConstNone parameter.

    Marks an argument whose call-time value is expected to be ``None``.

    Parameters
    ----------
    name : str
        The parameter name.
    """

    name: str

    def __init__(self, name: str) -> None:
        """Initialize a ConstNone parameter.

        Parameters
        ----------
        name : str
            The parameter name.
        """
        self.name = name
class TupleParam(Param):
    """A parameter grouping an ordered list of nested ``Param`` objects.

    Parameters
    ----------
    name : str
        The parameter name.
    params : list[Param]
        The nested parameters that make up the tuple.
    """

    name: str
    params: list[Param]

    def __init__(self, name: str, params: list[Param]) -> None:
        """Record the tuple parameter's name and its nested params."""
        self.params = params
        self.name = name
def format_param_type(param: Param) -> str:
    """Format a parameter type as a string, recursively handling nested types.

    Parameters
    ----------
    param : Param
        The parameter to format.

    Returns
    -------
    str
        The formatted type string.

    Raises
    ------
    TypeError
        If an unsupported parameter type is encountered.
    """

    def render_dims(dims) -> str:
        # Each dimension is either a symbolic Var (rendered by name)
        # or a concrete value (rendered via str()).
        rendered = [d.name if isinstance(d, Var) else str(d) for d in dims]
        return "[" + ", ".join(rendered) + "]"

    if isinstance(param, Var):
        return str(param.dtype)
    if isinstance(param, Tensor):
        return f"Tensor({render_dims(param.shape)}, {param.dtype})"
    if isinstance(param, Shape):
        return f"Shape({render_dims(param.shape)})"
    if isinstance(param, Stream):
        return "Stream"
    if isinstance(param, DataPointer):
        return "DataPointer"
    if isinstance(param, ConstNone):
        return "None"
    if isinstance(param, TupleParam):
        # Recursively format tuple elements
        inner = ", ".join(format_param_type(p) for p in param.params)
        return f"Tuple[{inner}]"
    raise TypeError(f"Unsupported parameter type: {type(param)}")
def signature(name: str, params: list[Param]) -> str:
"""Generate a function signature string from name and parameters.
@@ -325,39 +431,13 @@ def signature(name: str, params: list[Param]) -> str:
param_strs = []
for param in params:
if isinstance(param, Var):
param_str = f"{param.name}: {param.dtype}"
elif isinstance(param, Tensor):
# Format tensor shape
shape_strs = []
for dim in param.shape:
if isinstance(dim, Var):
shape_strs.append(dim.name)
else:
shape_strs.append(str(dim))
shape_str = "[" + ", ".join(shape_strs) + "]"
param_str = f"{param.name}: Tensor({shape_str}, {param.dtype})"
elif isinstance(param, Shape):
# Format shape parameter
shape_strs = []
for dim in param.shape:
if isinstance(dim, Var):
shape_strs.append(dim.name)
else:
shape_strs.append(str(dim))
shape_str = "[" + ", ".join(shape_strs) + "]"
param_str = f"{param.name}: Shape({shape_str})"
elif isinstance(param, Stream):
param_str = f"{param.name}: Stream"
elif isinstance(param, DataPointer):
param_str = f"{param.name}: DataPointer"
elif isinstance(param, EnvStream):
if isinstance(param, EnvStream):
# env stream is not part of the FFI function signature
# continue to skip append
continue
else:
raise TypeError(f"Unsupported parameter type: {type(param)}")
param_type = format_param_type(param)
param_str = f"{param.name}: {param_type}"
param_strs.append(param_str)
return f"{name}({', '.join(param_strs)})"

View File

@@ -27,6 +27,64 @@ from .mlir_builder import MLIRBuilder
from dataclasses import dataclass
@dataclass
class ArgContext:
    """Context information for parameter decoding error messages.

    :ivar param_name: The name of the parameter.
    :vartype param_name: str
    :ivar arg_index: The index of the argument in the function call.
    :vartype arg_index: int
    :ivar tuple_indices: List of tuple indices for nested tuple access
        (e.g., ``[0, 1]`` for ``tuple[0][1]``).
    :vartype tuple_indices: list[int]
    """

    param_name: str
    arg_index: int
    tuple_indices: list[int]

    def _indices_suffix(self) -> str:
        # Render nested tuple access like "[0][1]"; empty for top-level args.
        return "".join(f"[{idx}]" for idx in self.tuple_indices)

    def get(self) -> list[str]:
        """Get the context as a list of strings for error messages.

        :returns: Context strings like ``["on argument ", "#0"]`` or
            ``["on my_tuple[0][1] in argument ", "#0"]``.
        :rtype: list[str]
        """
        suffix = self._indices_suffix()
        if suffix:
            # Nested tuple element: "on my_tuple[0][1] in argument #0"
            return [f"on {self.param_name}{suffix} in argument ", f"#{self.arg_index}"]
        # Top-level argument: "on argument #0"
        return ["on argument ", f"#{self.arg_index}"]

    def get_field_name(self, field_suffix: str) -> str:
        """Get the field name with tuple indices for shape/stride access.

        :param field_suffix: The field suffix (e.g., ``".shape"``, ``".strides"``).
        :type field_suffix: str
        :returns: Field name like ``"my_param.shape"`` or ``"my_tuple[0][1].shape"``.
        :rtype: str
        """
        return f"{self.param_name}{self._indices_suffix()}{field_suffix}"

    def get_element_context(self, element_index: int) -> "ArgContext":
        """Create a nested context for a tuple element.

        :param element_index: The index within the tuple.
        :type element_index: int
        :returns: New context for the nested element.
        :rtype: ArgContext
        """
        return ArgContext(
            param_name=self.param_name,
            arg_index=self.arg_index,
            tuple_indices=[*self.tuple_indices, element_index],
        )
@dataclass
class CallContext:
"""Call context that contains the information of the call."""
@@ -119,7 +177,7 @@ class TVMFFIBuilder(MLIRBuilder):
super().__init__()
# this is a number we can tune to minimize the register size
# it is 6 by default to minimize the register size
self.set_raised_from_cstr_parts_max_num_parts = 6
self.set_raised_from_cstr_parts_max_num_parts = 8
self.set_raised_from_cstr_parts_cache: dict[int, str] = {}
self.tvm_ffi_any_type = self.struct_type(
name="TVMFFIAny",
@@ -488,26 +546,28 @@ class TVMFFIBuilder(MLIRBuilder):
) -> ir.Value:
"""Downcast i64 to lower bits."""
overflow_flags = llvm.IntegerOverflowFlags.none
if (hasattr(tvm_ffi._dtype.DataTypeCode, "BOOL") and
target_dtype.type_code == tvm_ffi._dtype.DataTypeCode.BOOL):
# LLVM use i1 (boolean) for boolean
return llvm.icmp(llvm.ICmpPredicate.ne, v_int64, self.i64(0))
if target_dtype.bits == 64:
result = v_int64
elif target_dtype.bits == 32:
result = llvm.trunc(
return v_int64
if target_dtype.bits == 32:
return llvm.trunc(
res=self.i32_type, arg=v_int64, overflow_flags=overflow_flags
)
elif target_dtype.bits == 16:
result = llvm.trunc(
if target_dtype.bits == 16:
return llvm.trunc(
res=self.i16_type, arg=v_int64, overflow_flags=overflow_flags
)
elif target_dtype.bits == 8:
result = llvm.trunc(
if target_dtype.bits == 8:
return llvm.trunc(
res=self.i8_type, arg=v_int64, overflow_flags=overflow_flags
)
elif target_dtype.bits == 1:
if target_dtype.bits == 1:
# For i1 (boolean), convert i64 to boolean by checking if non-zero
result = llvm.icmp(llvm.ICmpPredicate.ne, v_int64, self.i64(0))
else:
raise ValueError(f"Unsupported Var dtype: {target_dtype}")
return result
return llvm.icmp(llvm.ICmpPredicate.ne, v_int64, self.i64(0))
raise ValueError(f"Unsupported Var dtype: {target_dtype}")
def is_contiguous(
self,
@@ -767,8 +827,40 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
self.matched_var_binding = {}
self.matched_var_source = {}
def find_or_declare_extern_func(
    self, name: str, params: Sequence[ir.Type], ret: ir.Type
) -> None:
    """Declare ``name`` as an extern function unless the module already has it.

    The symbol is looked up in ``self.module`` first; when present this is
    a no-op, otherwise an external declaration with the given parameter and
    return types is emitted.

    Parameters
    ----------
    name : str
        The name of the extern function.
    params : Sequence[ir.Type]
        The parameter types of the function.
    ret : ir.Type
        The return type of the function.
    """
    if self.find_func_in_module(self.module, name) is not None:
        # Already present in the module — avoid emitting a duplicate symbol.
        return
    self.declare_extern_func(name, params, ret)
def decode_param_int(
self, current_block: ir.Block, param: spec.Var, args: ir.Value, arg_index: int
self,
current_block: ir.Block,
param: spec.Var,
args: ir.Value,
arg_index: int,
arg_context: ArgContext,
) -> ir.Block:
"""Decode the integer parameter at the given index."""
# read the type index
@@ -785,8 +877,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
lambda: is_int_or_bool,
"TypeError",
[
"Mismatched type on argument ",
f"#{arg_index}",
"Mismatched type ",
*arg_context.get(),
self._fn_call_context,
", expected int",
],
@@ -801,14 +893,19 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
param,
v_int64,
[
"value on argument ",
f"#{arg_index}",
"value ",
*arg_context.get(),
self._fn_call_context,
],
)
def decode_param_float(
self, current_block: ir.Block, param: spec.Var, args: ir.Value, arg_index: int
self,
current_block: ir.Block,
param: spec.Var,
args: ir.Value,
arg_index: int,
arg_context: ArgContext,
) -> ir.Block:
"""Decode the float parameter at the given index."""
# read the type index
@@ -891,8 +988,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
self.raise_error_and_return(
"TypeError",
[
"Mismatched type on argument ",
f"#{arg_index}",
"Mismatched type ",
*arg_context.get(),
self._fn_call_context,
", expected float",
],
@@ -912,6 +1009,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
param: spec.Var,
args: ir.Value,
arg_index: int,
arg_context: ArgContext,
*,
allow_int_as_ptr: bool = False,
address_space: Optional[int] = None,
@@ -941,8 +1039,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
lambda: is_opaque_ptr_or_nullptr,
"TypeError",
[
"Mismatched type on argument ",
f"#{arg_index}",
"Mismatched type ",
*arg_context.get(),
self._fn_call_context,
expect_message,
],
@@ -959,6 +1057,37 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
return current_block
def decode_param_const_none(
    self,
    current_block: ir.Block,
    param: spec.ConstNone,
    args: ir.Value,
    arg_index: int,
    arg_context: ArgContext,
) -> ir.Block:
    """Decode the ConstNone parameter at the given index.

    Emits a check that the FFI argument at ``arg_index`` carries the
    ``None`` type index (``kTVMFFINone``) and a TypeError-raising branch
    otherwise; no value is extracted since the argument is constant None.
    """
    # read the type index
    with ir.InsertionPoint(current_block):
        type_index: ir.Value = self.load_ffi_any_array_item_type_index(args, arg_index)
        # Check if type is a nullptr
        is_nullptr = self.equal(type_index, self.i32(TVMFFITypeIndex.kTVMFFINone))
    expect_message = ", expected None"
    # Break error message into reusable parts for better string deduplication
    current_block = self.check_condition(
        current_block,
        lambda: is_nullptr,
        "TypeError",
        [
            "Mismatched type ",
            *arg_context.get(),
            self._fn_call_context,
            expect_message,
        ],
    )
    return current_block
def check_int_value_dtype_bound(
self,
current_block: ir.Block,
@@ -1105,9 +1234,16 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
*error_msg_context,
f", expected to be {var}"
]
def check_value_mismatch() -> ir.Value:
cond = self.equal(value, expected_value)
if skip_check_predicate is not None:
cond = self.or_(skip_check_predicate, cond)
return cond
return self.check_condition(
current_block,
lambda: self.equal(value, expected_value),
check_value_mismatch,
error_kind,
error_msg_mismatch
)
@@ -1117,17 +1253,18 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
current_block: ir.Block,
var: Union[spec.Var, int],
value: ir.Value,
field: str,
arg_index: int,
field_suffix: str,
shape_index: int,
arg_context: ArgContext,
*,
skip_check_predicate: Optional[ir.Value] = None,
) -> ir.Block:
"""Load the shape value from the argument or match the shape value from the parameter."""
field_name = arg_context.get_field_name(field_suffix)
error_msg = [
field,
f"[{shape_index}] on argument ",
f"#{arg_index}",
field_name,
f"[{shape_index}] ",
*arg_context.get(),
self._fn_call_context,
]
return self.set_or_check_matched_var_binding(
@@ -1135,7 +1272,11 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
)
def decode_param_shape_from_ffi_array(
self, current_block: ir.Block, param: spec.Shape, arg_index: int, array_cell: ir.Value
self,
current_block: ir.Block,
param: spec.Shape,
arg_context: ArgContext,
array_cell: ir.Value,
) -> tuple[ir.Block, list[ir.Value]]:
"""Decode the shape parameter from the TVMFFIArrayCell."""
with ir.InsertionPoint(current_block):
@@ -1148,8 +1289,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
lambda: self.equal(array_size, self.i64(len(param.shape))),
"ValueError",
[
"Mismatched Shape on argument ",
f"#{arg_index}",
"Mismatched Shape ",
*arg_context.get(),
self._fn_call_context,
f", expected shape size={len(param.shape)}",
],
@@ -1157,35 +1298,49 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
# Load and validate each element of the array
load_shapes = []
for i in range(len(param.shape)):
with ir.InsertionPoint(current_block):
type_index: ir.Value = self.load_ffi_any_array_item_type_index(array_data, i)
def validate_and_load_shape_element(
block: ir.Block, index: int
) -> tuple[ir.Block, ir.Value]:
"""Validate and load a single shape element from the array."""
with ir.InsertionPoint(block):
type_index: ir.Value = self.load_ffi_any_array_item_type_index(array_data, index)
# Check if type is int or bool (both use v_int64, bool can be converted to int)
is_int = self.equal(type_index, self.i32(TVMFFITypeIndex.kTVMFFIInt))
is_int_val = self.equal(type_index, self.i32(TVMFFITypeIndex.kTVMFFIInt))
# Check that the element is an integer
current_block = self.check_condition(
current_block,
lambda: is_int,
field_name = arg_context.get_field_name("")
block = self.check_condition(
block,
lambda: is_int_val,
"TypeError",
[
f"Invalid shape element type ",
f"{param.name}[{i}]",
f" on argument ",
f"#{arg_index}",
"Invalid shape element type ",
f"{field_name}[{index}]",
" ",
*arg_context.get(),
self._fn_call_context,
f", expected int",
", expected int",
],
)
with ir.InsertionPoint(current_block):
v_int64: ir.Value = self.load_ffi_any_array_item_v_int64(array_data, i)
load_shapes.append(v_int64)
with ir.InsertionPoint(block):
v_int64: ir.Value = self.load_ffi_any_array_item_v_int64(array_data, index)
return block, v_int64
for i in range(len(param.shape)):
current_block, v_int64 = validate_and_load_shape_element(current_block, i)
load_shapes.append(v_int64)
return (current_block, load_shapes)
def decode_param_shape_from_ffi_shape(
self, current_block: ir.Block, param: spec.Shape, arg_index: int, shape_cell: ir.Value
self,
current_block: ir.Block,
param: spec.Shape,
arg_context: ArgContext,
shape_cell: ir.Value,
) -> tuple[ir.Block, list[ir.Value]]:
"""Decode the shape parameter from the TVMFFIShapeCell."""
with ir.InsertionPoint(current_block):
@@ -1198,8 +1353,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
lambda: self.equal(shape_size, self.i64(len(param.shape))),
"ValueError",
[
"Mismatched Shape on argument ",
f"#{arg_index}",
"Mismatched Shape ",
*arg_context.get(),
self._fn_call_context,
f", expected shape size={len(param.shape)}",
],
@@ -1211,7 +1366,12 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
return (current_block, load_shapes)
def decode_param_shape(
self, current_block: ir.Block, param: spec.Shape, args: ir.Value, arg_index: int
self,
current_block: ir.Block,
param: spec.Shape,
args: ir.Value,
arg_index: int,
arg_context: ArgContext,
) -> ir.Block:
"""Decode the shape parameter at the given index."""
with ir.InsertionPoint(current_block):
@@ -1238,7 +1398,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
self.load_ffi_any_array_item_v_ptr(args, arg_index)
)
ffi_shape_block, load_shapes = self.decode_param_shape_from_ffi_shape(
ffi_shape_block, param, arg_index, shape_cell
ffi_shape_block, param, arg_context, shape_cell
)
with ir.InsertionPoint(ffi_shape_block):
self.br(subsequent_block, args=load_shapes)
@@ -1256,7 +1416,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
self.load_ffi_any_array_item_v_ptr(args, arg_index)
)
ffi_array_block, load_shapes = self.decode_param_shape_from_ffi_array(
ffi_array_block, param, arg_index, array_cell_ptr
ffi_array_block, param, arg_context, array_cell_ptr
)
with ir.InsertionPoint(ffi_array_block):
self.br(subsequent_block, args=load_shapes)
@@ -1267,8 +1427,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
self.raise_error_and_return(
"TypeError",
[
"Mismatched type on argument ",
f"#{arg_index}",
"Mismatched type ",
*arg_context.get(),
self._fn_call_context,
", expected ffi.Shape or ffi.Array",
],
@@ -1279,7 +1439,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
shape_values = list(subsequent_block.arguments)
for i, dim in enumerate(param.shape):
subsequent_block = self.set_or_check_matched_var_binding_from_shape(
subsequent_block, dim, shape_values[i], f"{param.name}", arg_index, i
subsequent_block, dim, shape_values[i], "", i, arg_context
)
return subsequent_block
@@ -1290,6 +1450,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
param: spec.Tensor,
args: ir.Value,
arg_index: int,
arg_context: ArgContext,
) -> tuple[ir.Block, ir.Value]:
"""Decode tensor step0: check index and find out DLTensor*."""
# read the type index
@@ -1333,8 +1494,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
self.raise_error_and_return(
"TypeError",
[
"Mismatched type on argument ",
f"#{arg_index}",
"Mismatched type ",
*arg_context.get(),
self._fn_call_context,
", expected Tensor",
],
@@ -1352,10 +1513,11 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
param: spec.Tensor,
args: ir.Value,
arg_index: int,
arg_context: ArgContext,
) -> ir.Block:
"""Decode the tensor parameter at the given index."""
current_block, dl_tensor_ptr = self.decode_param_tensor_dltensor_ptr(
current_block, param, args, arg_index
current_block, param, args, arg_index, arg_context
)
with ir.InsertionPoint(current_block):
data = self.load_dltensor_data_ptr(dl_tensor_ptr)
@@ -1381,8 +1543,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
check_alignment,
"ValueError",
[
"Misaligned Tensor data on argument ",
f"#{arg_index}",
"Misaligned Tensor data ",
*arg_context.get(),
self._fn_call_context,
f", expected data alignment={param.data_alignment} bytes",
],
@@ -1401,8 +1563,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
lambda: self.equal(ndim, self.i32(expected_ndim)),
"ValueError",
[
"Mismatched Tensor on argument ",
f"#{arg_index}",
"Mismatched Tensor ",
*arg_context.get(),
self._fn_call_context,
f", expected ndim={expected_ndim}",
],
@@ -1414,8 +1576,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
lambda: self.equal(device_type, self.i32(param.dlpack_device_type)),
"ValueError",
[
"Mismatched Tensor on argument ",
f"#{arg_index}",
"Mismatched Tensor ",
*arg_context.get(),
self._fn_call_context,
f", expected device_type={param.device_type_name}",
],
@@ -1437,8 +1599,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
dtype_equal,
"ValueError",
[
"Mismatched Tensor on argument ",
f"#{arg_index}",
"Mismatched Tensor ",
*arg_context.get(),
self._fn_call_context,
f", expected dtype={param.dtype}",
],
@@ -1450,8 +1612,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
lambda: self.equal(byte_offset, self.i64(0)),
"ValueError",
[
"Mismatched Tensor on argument ",
f"#{arg_index}",
"Mismatched Tensor ",
*arg_context.get(),
self._fn_call_context,
", expected byte_offset=0",
],
@@ -1474,9 +1636,9 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
current_block,
param.shape[index],
load_shapes[index],
f"{param.name}.shape",
arg_index,
".shape",
index,
arg_context,
)
if param.strides is not None:
@@ -1490,9 +1652,9 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
current_block,
param.strides[index],
load_strides[index],
f"{param.name}.strides",
arg_index,
".strides",
index,
arg_context,
skip_check_predicate=skip_check_predicate,
)
else:
@@ -1502,8 +1664,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
lambda: self.is_contiguous(param.shape, load_shapes, load_strides),
"ValueError",
[
"Mismatched Tensor on argument ",
f"#{arg_index}",
"Mismatched Tensor ",
*arg_context.get(),
self._fn_call_context,
", expected contiguous",
],
@@ -1516,11 +1678,12 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
param: spec.Stream,
args: ir.Value,
arg_index: int,
arg_context: ArgContext,
) -> ir.Block:
"""Decode the stream parameter at the given index."""
# stream is decoded as opaque handle
return self.decode_param_opaque_handle(
current_block, param.var, args, arg_index
current_block, param.var, args, arg_index, arg_context
)
def decode_param_data_pointer(
@@ -1529,6 +1692,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
param: spec.DataPointer,
args: ir.Value,
arg_index: int,
arg_context: ArgContext,
) -> ir.Block:
"""Decode the data pointer parameter at the given index."""
# data pointer is decoded as opaque handle
@@ -1537,6 +1701,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
param.var,
args,
arg_index,
arg_context,
allow_int_as_ptr=True,
address_space=param.address_space,
)
@@ -1578,35 +1743,117 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
expected_num_args += 1
return expected_num_args
def decode_param( # noqa: PLR0911
self, current_block: ir.Block, param: spec.Param, args: ir.Value, arg_index: int
def decode_param_tuple(
self,
current_block: ir.Block,
param: spec.TupleParam,
args: ir.Value,
arg_index: int,
arg_context: ArgContext,
) -> ir.Block:
"""Decode the parameter at the given index."""
"""Decode the tuple parameter at the given index."""
# Check if type is kTVMFFIArray
with ir.InsertionPoint(current_block):
type_index: ir.Value = self.load_ffi_any_array_item_type_index(args, arg_index)
is_ffi_array = self.equal(type_index, self.i32(TVMFFITypeIndex.kTVMFFIArray))
# Check that the type is an array
current_block = self.check_condition(
current_block,
lambda: is_ffi_array,
"TypeError",
[
"Mismatched type ",
*arg_context.get(),
self._fn_call_context,
", expected ffi.Array for tuple",
],
)
# Load the array cell
with ir.InsertionPoint(current_block):
array_cell_ptr = self.get_object_cell_ptr(
self.load_ffi_any_array_item_v_ptr(args, arg_index)
)
array_data = self.load_array_cell_data_ptr(array_cell_ptr)
array_size = self.load_array_cell_size_as_i64(array_cell_ptr)
# Check that the array size matches the expected tuple size
current_block = self.check_condition(
current_block,
lambda: self.equal(array_size, self.i64(len(param.params))),
"ValueError",
[
"Mismatched tuple size ",
*arg_context.get(),
self._fn_call_context,
f", expected tuple size={len(param.params)}",
],
)
# Recursively decode each element of the tuple
for i, tuple_param in enumerate(param.params):
# Create nested context for the tuple element
nested_context = arg_context.get_element_context(i)
current_block = self.decode_param(current_block, tuple_param, array_data, i, nested_context)
return current_block
def decode_param( # noqa: PLR0911
self,
current_block: ir.Block,
param: spec.Param,
args: ir.Value,
arg_index: int,
arg_context: ArgContext,
) -> ir.Block:
"""Decode the parameter at the given index.
Parameters
----------
current_block : ir.Block
The current IR block.
param : spec.Param
The parameter specification to decode.
args : ir.Value
The FFI arguments array.
arg_index : int
The index in the args array.
arg_context : ArgContext
Context information for error messages.
"""
if isinstance(param, spec.Var):
if param.dtype.type_code == tvm_ffi._dtype.DataTypeCode.INT:
return self.decode_param_int(current_block, param, args, arg_index)
return self.decode_param_int(current_block, param, args, arg_index, arg_context)
elif param.dtype.type_code == tvm_ffi._dtype.DataTypeCode.UINT:
# UINT uses the same logic as INT since both are stored in v_int64
return self.decode_param_int(current_block, param, args, arg_index)
return self.decode_param_int(current_block, param, args, arg_index, arg_context)
elif (hasattr(tvm_ffi._dtype.DataTypeCode, "BOOL") and
param.dtype.type_code == tvm_ffi._dtype.DataTypeCode.BOOL):
return self.decode_param_int(current_block, param, args, arg_index, arg_context)
elif param.dtype.type_code == tvm_ffi._dtype.DataTypeCode.FLOAT:
return self.decode_param_float(current_block, param, args, arg_index)
return self.decode_param_float(current_block, param, args, arg_index, arg_context)
elif param.dtype.type_code == tvm_ffi._dtype.DataTypeCode.HANDLE:
return self.decode_param_opaque_handle(
current_block, param, args, arg_index
current_block, param, args, arg_index, arg_context
)
else:
raise ValueError(f"Unsupported parameter type: {param.dtype.type_code}")
elif isinstance(param, spec.Shape):
return self.decode_param_shape(current_block, param, args, arg_index)
return self.decode_param_shape(current_block, param, args, arg_index, arg_context)
elif isinstance(param, spec.Tensor):
return self.decode_param_tensor(current_block, param, args, arg_index)
return self.decode_param_tensor(current_block, param, args, arg_index, arg_context)
elif isinstance(param, spec.Stream):
return self.decode_param_stream(current_block, param, args, arg_index)
return self.decode_param_stream(current_block, param, args, arg_index, arg_context)
elif isinstance(param, spec.EnvStream):
# decode of env stream is deferred after we go through all parameters
return current_block
elif isinstance(param, spec.DataPointer):
return self.decode_param_data_pointer(current_block, param, args, arg_index)
return self.decode_param_data_pointer(current_block, param, args, arg_index, arg_context)
elif isinstance(param, spec.ConstNone):
return self.decode_param_const_none(current_block, param, args, arg_index, arg_context)
elif isinstance(param, spec.TupleParam):
return self.decode_param_tuple(current_block, param, args, arg_index, arg_context)
else:
raise ValueError(f"Unsupported parameter type: {type(param)}")
@@ -1652,20 +1899,20 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
with ir.InsertionPoint(self.module.body): # type: ignore[union-attr]
# void TVMFFIErrorSetRaisedFromCStr(
# const char* error_kind, const char* message);
self.declare_extern_func(
self.find_or_declare_extern_func(
"TVMFFIErrorSetRaisedFromCStr",
[self.ptr_type, self.ptr_type],
self.void_type,
)
# void TVMFFIErrorSetRaisedFromCStrParts(
# const char* error_kind, const char* messages, int32_t num_parts);
self.declare_extern_func(
self.find_or_declare_extern_func(
"TVMFFIErrorSetRaisedFromCStrParts",
[self.ptr_type, self.ptr_type, self.i32_type],
self.void_type,
)
# void* TVMFFIEnvGetStream(int32_t device_type, int32_t device_id);
self.declare_extern_func(
self.find_or_declare_extern_func(
"TVMFFIEnvGetStream",
[self.i32_type, self.i32_type],
self.ptr_type,
@@ -1697,7 +1944,12 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
# decode parameters to populate the matched var binding
for arg_index, param in enumerate(params_list):
current_block = self.decode_param(current_block, param, args, arg_index)
arg_context = ArgContext(
param_name=param.name,
arg_index=arg_index,
tuple_indices=[],
)
current_block = self.decode_param(current_block, param, args, arg_index, arg_context)
with ir.InsertionPoint(current_block):
env_stream = self.find_env_stream(params_list)

View File

@@ -239,19 +239,6 @@ def get_c_pointers(obj):
return []
def arg_compatible_with_tvm_ffi(arg):
"""
Given the `arg`, check if it is a compatible argument for TVM FFI
"""
import tvm_ffi
return (
hasattr(arg, "__tvm_ffi_object__")
or isinstance(arg, (int, float, bool))
or isinstance(arg, tvm_ffi.Shape)
)
def get_mlir_types(obj):
"""
Given the `obj`, recursively go through it to extract all contained MLIR types

View File

@@ -205,6 +205,7 @@ KeepCUBIN = _dsl.KeepCUBIN
KeepPTX = _dsl.KeepPTX
GPUArch = _dsl.GPUArch
LinkLibraries = _dsl.LinkLibraries
EnableTVMFFI = _dsl.EnableTVMFFI
# attach the TVM FFI ABI interface postprocessor to the DSL
from . import _tvm_ffi_args_spec_converter

View File

@@ -42,7 +42,7 @@ from .typing import (
)
import cuda.bindings.driver as cuda
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Tuple, get_origin, get_args
import inspect
NumericToTVMFFIDtype = {
@@ -91,6 +91,11 @@ def _get_llvm_address_space_from_memspace(
return 1
return None
def _is_gpu_memspace(
memspace: _cute_ir.AddressSpace,
) -> bool:
return memspace != _cute_ir.AddressSpace.generic
class SymIntId:
def __init__(self, sym_int: SymInt):
@@ -103,118 +108,187 @@ class SymIntId:
return self.sym_int is other.sym_int
def _tvm_ffi_args_spec_converter(
function_name: str,
args_spec: inspect.FullArgSpec,
dynamic_args: List[Any],
dynamic_kwargs: Dict[str, Any],
):
"""Convert cute algebra args to tvm ffi spec params.
This function converts the cute arguments specs to tvm ffi spec params.
"""
exec_args = ExecutionArgs(args_spec, function_name)
rectified_args = exec_args.get_rectified_args(dynamic_args, dynamic_kwargs)
arg_names = exec_args.args_spec.args + exec_args.args_spec.kwonlyargs
class ConverterContext:
"""Context for managing variable allocation during TVM FFI args conversion."""
params = []
num_dyn_shape_vars = 0
num_dyn_stride_vars = 0
sym_int_id_mapping = {}
def __init__(self):
self.num_dyn_shape_vars = 0
self.num_dyn_stride_vars = 0
self.sym_int_id_mapping = {}
def alloc_shape_name():
nonlocal num_dyn_shape_vars
name = f"n{num_dyn_shape_vars}"
num_dyn_shape_vars += 1
def alloc_shape_name(self) -> str:
"""Allocate a new dynamic shape variable name."""
name = f"n{self.num_dyn_shape_vars}"
self.num_dyn_shape_vars += 1
return name
def alloc_stride_name():
nonlocal num_dyn_stride_vars
name = f"s{num_dyn_stride_vars}"
num_dyn_stride_vars += 1
def alloc_stride_name(self) -> str:
"""Allocate a new dynamic stride variable name."""
name = f"s{self.num_dyn_stride_vars}"
self.num_dyn_stride_vars += 1
return name
def alloc_or_reuse_symint_var(value, name_alloc_func):
nonlocal sym_int_id_mapping
def alloc_or_reuse_symint_var(self, value: SymInt, name_alloc_func):
"""Allocate or reuse a symbolic integer variable."""
sym_int_id = SymIntId(value)
if sym_int_id in sym_int_id_mapping:
return sym_int_id_mapping[sym_int_id]
if sym_int_id in self.sym_int_id_mapping:
return self.sym_int_id_mapping[sym_int_id]
name = name_alloc_func()
if value.width == 32:
dtype = NumericToTVMFFIDtype[Int32]
else:
dtype = NumericToTVMFFIDtype[Int64]
var = spec.Var(name, dtype, divisibility=value.divisibility)
sym_int_id_mapping[sym_int_id] = var
self.sym_int_id_mapping[sym_int_id] = var
return var
for arg, arg_name in zip(rectified_args, arg_names):
arg_type = args_spec.annotations.get(arg_name, None)
if isinstance(arg, Numeric) and arg.dtype in AcceptableNumericTypesForScalar:
params.append(spec.Var(arg_name, NumericToTVMFFIDtype[arg.dtype]))
elif is_cute_algebra_type(arg_type):
shape = []
for i in range(len(arg)):
if isinstance(arg[i], int):
shape.append(arg[i])
elif isinstance(arg[i], SymInt):
shape.append(alloc_or_reuse_symint_var(arg[i], alloc_shape_name))
else:
shape.append(spec.Var(alloc_shape_name(), NumericToTVMFFIDtype[arg[i].dtype]))
params.append(spec.Shape(arg_name, shape))
elif isinstance(arg, Tensor):
shapes = []
for i, dyn_mask in enumerate(arg.dynamic_shapes_mask):
if not dyn_mask:
shapes.append(arg.shape[i])
elif isinstance(arg.shape[i], SymInt):
shapes.append(alloc_or_reuse_symint_var(arg.shape[i], alloc_shape_name))
else:
shapes.append(spec.Var(alloc_shape_name(), NumericToTVMFFIDtype[Int32]))
strides = []
for i, dyn_mask in enumerate(arg.dynamic_strides_mask):
if not dyn_mask:
strides.append(arg.stride[i])
elif isinstance(arg.stride[i], SymInt):
strides.append(alloc_or_reuse_symint_var(arg.stride[i], alloc_stride_name))
else:
if hasattr(arg, "_use_32bit_stride") and arg._use_32bit_stride:
dtype = NumericToTVMFFIDtype[Int32]
else:
dtype = NumericToTVMFFIDtype[Int64]
strides.append(spec.Var(alloc_stride_name(), dtype))
def _convert_single_arg(
arg,
arg_name: str,
arg_type,
ctx: ConverterContext
) -> spec.Param:
"""Convert a single argument to a spec.Param.
Parameters
----------
arg : Any
The argument value to convert.
arg_name : str
The name of the argument.
arg_type : type
The type annotation of the argument.
ctx : ConverterContext
The converter context for managing variable allocation.
Returns
-------
spec.Param
The converted parameter specification.
"""
if arg is None:
return spec.ConstNone(arg_name)
elif (isinstance(arg, Numeric) and arg.dtype in AcceptableNumericTypesForScalar):
return spec.Var(arg_name, NumericToTVMFFIDtype[arg.dtype])
elif arg_type in AcceptableNumericTypesForScalar:
return spec.Var(arg_name, NumericToTVMFFIDtype[arg_type])
elif is_cute_algebra_type(arg_type):
shape = []
for i in range(len(arg)):
if isinstance(arg[i], int):
shape.append(arg[i])
elif isinstance(arg[i], SymInt):
shape.append(ctx.alloc_or_reuse_symint_var(arg[i], ctx.alloc_shape_name))
else:
shape.append(spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[arg[i].dtype]))
return spec.Shape(arg_name, shape)
elif isinstance(arg, Tensor):
shapes = []
for i, dyn_mask in enumerate(arg.dynamic_shapes_mask):
if not dyn_mask:
shapes.append(arg.shape[i])
elif isinstance(arg.shape[i], SymInt):
shapes.append(ctx.alloc_or_reuse_symint_var(arg.shape[i], ctx.alloc_shape_name))
else:
shapes.append(spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[Int32]))
strides = []
for i, dyn_mask in enumerate(arg.dynamic_strides_mask):
if not dyn_mask:
strides.append(arg.stride[i])
elif isinstance(arg.stride[i], SymInt):
strides.append(ctx.alloc_or_reuse_symint_var(arg.stride[i], ctx.alloc_stride_name))
else:
if hasattr(arg, "_use_32bit_stride") and arg._use_32bit_stride:
dtype = NumericToTVMFFIDtype[Int32]
else:
dtype = NumericToTVMFFIDtype[Int64]
strides.append(spec.Var(ctx.alloc_stride_name(), dtype))
if hasattr(arg, "_tvm_ffi_tensor"):
tvm_ffi_tensor = arg._tvm_ffi_tensor
dtype = tvm_ffi_tensor.dtype
tvm_ffi_cute_tensor = spec.Tensor(
arg_name,
shapes,
arg._tvm_ffi_tensor.dtype,
strides=strides,
data_alignment=arg._assumed_align,
device_type=tvm_ffi_tensor.device.type
)
else:
# for FakeTensor, strictly follow the shape and stride from the cute tensor
device_type = "cuda" if _is_gpu_memspace(arg.memspace) else "cpu"
tvm_ffi_cute_tensor = spec.Tensor(
arg_name,
shapes,
NumericToTVMFFIDtype[arg.element_type],
strides=strides,
data_alignment=arg._assumed_align,
device_type=device_type,
)
if arg.element_type == Float4E2M1FN:
tvm_ffi_cute_tensor = spec.create_map_tensor_dtype_f4x2_to_f4_spec(
tvm_ffi_cute_tensor
)
params.append(tvm_ffi_cute_tensor)
elif isinstance(arg, Pointer):
address_space = None
if hasattr(arg, "memspace"):
address_space = _get_llvm_address_space_from_memspace(arg.memspace)
params.append(spec.DataPointer(arg_name, address_space=address_space))
elif isinstance(arg, _FakeStream):
if arg.use_tvm_ffi_env_stream:
params.append(spec.EnvStream(arg_name))
else:
params.append(spec.Stream(arg_name))
elif isinstance(arg, cuda.CUstream):
params.append(spec.Stream(arg_name))
return tvm_ffi_cute_tensor
elif isinstance(arg, Pointer) or arg_type == Pointer:
address_space = None
if hasattr(arg, "memspace"):
address_space = _get_llvm_address_space_from_memspace(arg.memspace)
return spec.DataPointer(arg_name, address_space=address_space)
elif isinstance(arg, _FakeStream):
if arg.use_tvm_ffi_env_stream:
return spec.EnvStream(arg_name)
else:
raise DSLRuntimeError(f"Unsupported argument type: {type(arg)}")
# The following code can obtain signature of the function
# that maybe useful for future debugging and usecases.
# signature = spec.signature(function_name, params)
return spec.Stream(arg_name)
elif isinstance(arg, cuda.CUstream):
return spec.Stream(arg_name)
elif arg_type is not None and get_origin(arg_type) is tuple:
# Handle Tuple[X, Y, ...] type annotations
tuple_element_types = get_args(arg_type)
if not isinstance(arg, (tuple, list)):
raise DSLRuntimeError(f"Expected tuple for argument {arg_name}, got {type(arg)}")
if len(arg) != len(tuple_element_types):
raise DSLRuntimeError(
f"Tuple length mismatch for argument {arg_name}: "
f"expected {len(tuple_element_types)}, got {len(arg)}"
)
# Recursively convert each tuple element
tuple_params = []
for i, (elem, elem_type) in enumerate(zip(arg, tuple_element_types)):
elem_name = f"{arg_name}[{i}]"
elem_param = _convert_single_arg(elem, elem_name, elem_type, ctx)
tuple_params.append(elem_param)
return spec.TupleParam(arg_name, tuple_params)
else:
raise DSLRuntimeError(f"Unsupported argument type: {type(arg)}")
def _tvm_ffi_args_spec_converter(
function_name: str,
args_spec: inspect.FullArgSpec,
full_args: List[Any],
full_kwargs: Dict[str, Any],
):
"""Convert cute algebra args to tvm ffi spec params.
This function converts the cute arguments specs to tvm ffi spec params.
"""
exec_args = ExecutionArgs(args_spec, function_name)
rectified_args = exec_args.get_rectified_args_from_original_args(full_args, full_kwargs)
arg_names = exec_args.args_spec.args + exec_args.args_spec.kwonlyargs
params = []
ctx = ConverterContext()
for arg, arg_name in zip(rectified_args, arg_names):
arg_type = args_spec.annotations.get(arg_name, None)
param = _convert_single_arg(arg, arg_name, arg_type, ctx)
params.append(param)
return params

View File

@@ -209,6 +209,7 @@ def mbarrier_conditional_try_wait(
def mbarrier_arrive(
mbar_ptr: Pointer,
peer_cta_rank_in_cluster: Optional[Int] = None,
arrive_count: Int = 1,
*,
loc=None,
ip=None,
@@ -239,7 +240,7 @@ def mbarrier_arrive(
nvvm.mbarrier_txn(
mbar_llvm_ptr,
Int32(1).ir_value(loc=loc, ip=ip),
Int32(arrive_count).ir_value(loc=loc, ip=ip),
kind=nvvm.MBarrierTxnKind.ARRIVE,
space=space,
loc=loc,

View File

@@ -201,13 +201,13 @@ class MmaOp(Tcgen05MmaOp):
if self.cta_group == CtaGroup.ONE:
if m not in [64, 128]:
raise OpError(self, f"expects the M-mode to be 64 or 128, but got {m}")
if m == 64:
if (n < 8) or (n > 256) or (n % 8 != 0):
if self.b_dtype.width == 8 and self.b_major_mode == OperandMajorMode.MN:
if (n < 16) or (n > 256) or (n % 16 != 0):
raise OpError(
self,
f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}",
f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}",
)
elif m == 128:
else:
if (n < 8) or (n > 256) or (n % 8 != 0):
raise OpError(
self,
@@ -216,11 +216,18 @@ class MmaOp(Tcgen05MmaOp):
else:
if m not in [128, 256]:
raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}")
if (n < 16) or (n > 256) or (n % 16 != 0):
raise OpError(
self,
f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}",
)
if self.b_dtype.width == 8 and self.b_major_mode == OperandMajorMode.MN:
if (n < 32) or (n > 256) or (n % 32 != 0):
raise OpError(
self,
f"expects the N-mode to satisfy 32 <= N <= 256 and N % 32 == 0, but got {n}",
)
else:
if (n < 16) or (n > 256) or (n % 16 != 0):
raise OpError(
self,
f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}",
)
def __str__(self) -> str:
return (
@@ -302,6 +309,7 @@ class BlockScaledMmaOp(Tcgen05MmaOp):
admissible_archs = [
Arch.sm_100a,
Arch.sm_103a,
]
def __post_init__(self) -> None:

View File

@@ -10,7 +10,6 @@
# is strictly prohibited.
import ctypes
import os
import sys
from pathlib import Path
from functools import lru_cache
@@ -20,6 +19,7 @@ from typing import Union, Optional, Type, List
# MLIR modules imports
from cutlass._mlir import ir
from cutlass.base_dsl.env_manager import get_prefix_dsl_libs
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cuda as _cuda_dialect
@@ -144,6 +144,7 @@ class _Tensor(Tensor):
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor)
self._dlpack_data = self._tvm_ffi_tensor.__dlpack__()
self._dltensor_wrapper = None
self._assumed_align = assumed_align
self._is_dynamic = False
@@ -387,7 +388,16 @@ class _Tensor(Tensor):
return CoreTensor(values[0].value, self._dtype)
def __tvm_ffi_object__(self):
return self._tvm_ffi_tensor
try:
return self._tvm_ffi_tensor
except AttributeError:
raise DSLRuntimeError(
(
"runtime._Tensor is not a TVM-FFI tensor. "
"Enable TVM-FFI with `from_dlpack(..., enable_tvm_ffi=True)` "
"or `CUTE_DSL_ENABLE_TVM_FFI=1`."
)
)
def _get_cute_type_str(inp):
@@ -411,7 +421,8 @@ class _FakeCompactTensor(Tensor):
self._dtype = dtype
self._shape = shape
self._stride_order = stride_order or tuple(range(len(shape)))
self._memspace = memspace or AddressSpace.gmem
# cannot use memspace or AddressSpace.gmem because AddressSpace.generic is 0
self._memspace = memspace if memspace is not None else AddressSpace.gmem
self._assumed_align = assumed_align or -(-dtype.width // 8)
self._use_32bit_stride = use_32bit_stride
@@ -510,7 +521,8 @@ class _FakeTensor(Tensor):
self._dtype = dtype
self._shape = shape
self._stride = stride
self._memspace = memspace or AddressSpace.generic
# cannot use memspace or AddressSpace.generic because AddressSpace.generic is 0
self._memspace = memspace if memspace is not None else AddressSpace.gmem
self._assumed_align = assumed_align
if assumed_align is None:
# use the bytes width of the element dtype. The alignment is at least one byte align.
@@ -605,7 +617,7 @@ def make_fake_compact_tensor(
:param shape: Shape of the tensor.
:type shape: tuple[int, ...]
:param stride_order: Order in which strides (memory layout) are assigned to the tensor dimensions.
If None, the default layout is row-major. Otherwise, it should be a permutation of the dimension indices.
If None, the default layout is col-major. Otherwise, it should be a permutation of the dimension indices.
:type stride_order: tuple[int, ...], optional
:param memspace: Memory space where the fake tensor resides. Optional.
:type memspace: str, optional
@@ -644,6 +656,7 @@ def make_fake_compact_tensor(
use_32bit_stride=use_32bit_stride,
)
def make_fake_tensor(dtype, shape, stride, *, memspace=None, assumed_align=None):
"""
Create a fake tensor with the specified element type, shape, and stride.
@@ -859,21 +872,22 @@ def find_runtime_libraries(*, enable_tvm_ffi: bool = True) -> List[str]:
"""
def _get_cuda_dialect_runtime_path():
libs = os.environ.get("CUTE_DSL_LIBS")
if libs:
sep = ";" if sys.platform.startswith("win32") else ":"
for path in libs.split(sep):
if path.endswith("libcuda_dialect_runtime.so"):
return path
try:
# find package library from wheel package
pkg_base = Path(__file__).resolve().parent.parent
lib_path = pkg_base / "lib" / "libcuda_dialect_runtime.so"
if lib_path.is_file():
return str(lib_path)
except OSError:
libs = get_prefix_dsl_libs("CUTE_DSL")
if libs is None:
return None
# check if the separator is ; for windows
if sys.platform.startswith("win32") and ";" in libs:
libs = libs.split(";")
else:
libs = libs.split(":")
for path in libs:
if path.endswith("libcuda_dialect_runtime.so"):
return path
return None
libs = []
cuda_dialect_runtime_path = _get_cuda_dialect_runtime_path()
if cuda_dialect_runtime_path:

View File

@@ -20,6 +20,7 @@ from typing import Type, Union, Callable, Optional, Dict, List, Any
import cuda.bindings.driver as cuda_driver
import cuda.bindings.runtime as cuda_runtime
import cutlass
import cutlass.base_dsl.jit_executor
from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, dsl_user_op
@@ -233,6 +234,7 @@ def convert(src: cute.Tensor, dst: cute.Tensor):
src.shape[leading_mode] % elem_per_copy == 0
and dst.shape[leading_mode] % elem_per_copy == 0
)
_convert(src, dst, leading_mode, elem_per_copy)

View File

@@ -52,6 +52,7 @@ from ..base_dsl.compiler import (
KeepPTX,
GPUArch,
LinkLibraries,
EnableTVMFFI,
)
from ..base_dsl.runtime.jit_arg_adapters import *

View File

@@ -21,7 +21,12 @@ import cuda.bindings.runtime as cuda_runtime
import cuda.bindings.driver as cuda_driver
# Local modules imports
from ..base_dsl.jit_executor import JitExecutor, ExecutionArgs, JitFunctionArtifacts
from ..base_dsl.jit_executor import (
JitExecutor,
JitCompiledFunction,
ExecutionArgs,
JitFunctionArtifacts,
)
from ..base_dsl.utils.logger import log
from ..base_dsl.common import DSLCudaRuntimeError, DSLRuntimeError
from ..base_dsl.typing import Int32
@@ -57,7 +62,7 @@ class CudaDialectJitModule:
self.unload()
class CudaDialectJitCompiledFunction:
class CudaDialectJitCompiledFunction(JitCompiledFunction):
"""Holds a compiled function and its module."""
def __init__(
@@ -103,21 +108,6 @@ class CudaDialectJitCompiledFunction:
self._executor_lock = threading.RLock()
self._default_executor = None
@property
def __ptx__(self):
"""Returns the PTX code of the JIT-compiled function."""
return self.artifacts.PTX if self.artifacts is not None else None
@property
def __cubin__(self):
"""Returns the CUBIN data of the JIT-compiled function."""
return self.artifacts.CUBIN if self.artifacts is not None else None
@property
def __mlir__(self):
"""Returns the MLIR code of the JIT-compiled function."""
return self.artifacts.MLIR if self.artifacts is not None else None
@functools.cached_property
def num_devices(self):
"""Returns the number of CUDA devices available."""
@@ -285,6 +275,7 @@ class CudaDialectJitCompiledFunction:
:return: A callable executor function.
:rtype: JitExecutor
"""
super()._validate_engine()
with self._executor_lock:
# We need to ensure that the modules are loaded if not already
if self.jit_module is None or self.jit_module.is_unloaded():
@@ -297,37 +288,3 @@ class CudaDialectJitCompiledFunction:
)
return JitExecutor(self.jit_module, None, self.jit_time_profiling)
def generate_execution_args(self, *args, **kwargs):
return self.args_spec.generate_execution_args(args, kwargs)
def set_dynamic_args(self, dynamic_args, dynamic_kwargs):
"""Sets the dynamic argument information required for export to c code generation."""
self.dynamic_args = dynamic_args
self.dynamic_kwargs = dynamic_kwargs
def __call__(self, *args, **kwargs):
"""Executes the jit-compiled function under the currently active CUDA context.
Calling this method multiple devices is not allowed and will result in unexpected
CUDA errors. If you need to call the kernel on multiple devices use `to`
to return a per-device function.
"""
exe_args, adapted_args = self.generate_execution_args(*args, **kwargs)
return self.run_compiled_program(exe_args)
def run_compiled_program(self, exe_args):
"""Executes the jit-compiled function under the currently active CUDA context.
Calling this method multiple devices is not allowed and will result in unexpected
CUDA errors. If you need to call the kernel on multiple devices use `to`
to return a per-device function.
"""
with self._executor_lock:
if self._default_executor is None:
log().debug("Creating default executor.")
# We use a weak reference here so that this instance does not keep this
# object alive as it hold a reference to self.
proxy_self = weakref.proxy(self)
self._default_executor = proxy_self.to(None)
return self._default_executor.run_compiled_program(exe_args)

View File

@@ -42,3 +42,7 @@ class CudaDialectStreamAdapter:
def __get_mlir_types__(self):
return [cuda.StreamType.get()]
def __cuda_stream__(self):
# support cuda stream protocol
return (0, int(self._arg))

View File

@@ -32,6 +32,7 @@ import pkgutil
from dataclasses import is_dataclass, fields
from math import ceil
from itertools import chain
from pathlib import Path
from collections.abc import Sequence
import builtins
import ctypes
@@ -334,8 +335,15 @@ class CutlassBaseDSL(BaseDSL):
)
try:
# update the version hash of the cutlass shared library
giant_dso_name = str(
next(
(Path(dsl_path) / "_mlir" / "_mlir_libs").glob(
"_cutlass_ir.cpython*"
)
).name
)
with open(
os.path.join(dsl_path, "_mlir/_mlir_libs/libCutlassIRPythonCAPI.so"),
os.path.join(dsl_path, f"_mlir/_mlir_libs/{giant_dso_name}"),
"rb",
) as f:
while True:
@@ -345,7 +353,7 @@ class CutlassBaseDSL(BaseDSL):
version_hash.update(chunk)
except Exception:
raise DSLRuntimeError(
"Failed to read the shared library file libCutlassIRPythonCAPI.so."
f"Failed to read the shared library file {giant_dso_name}."
"The file may not exist or may not be readable."
"Please re-install the package."
)
@@ -386,6 +394,8 @@ class CutlassBaseDSL(BaseDSL):
args_spec,
no_cache,
*,
full_args=None,
full_kwargs=None,
dynamic_args=None,
dynamic_kwargs=None,
original_function_name=None,
@@ -399,6 +409,8 @@ class CutlassBaseDSL(BaseDSL):
:param pipeline: The pipeline to use for compilation.
:param args_spec: The args spec to use for compilation.
:param no_cache: Whether to cache the result.
:param full_args: The full arguments to use for compilation.
:param full_kwargs: The full keyword arguments to use for compilation.
:param dynamic_args: The dynamic arguments to use for compilation.
:param dynamic_kwargs: The dynamic keyword arguments to use for compilation.
:param original_function_name: The name of the original function without mangling.
@@ -415,7 +427,7 @@ class CutlassBaseDSL(BaseDSL):
assert self._tvm_ffi_args_spec_converter is not None
tvm_ffi_spec_params = self._tvm_ffi_args_spec_converter(
function_name, args_spec, dynamic_args, dynamic_kwargs
function_name, args_spec, full_args, full_kwargs
)
tvm_ffi_provider = TVMFFICuteCallProvider(function_name)
@@ -445,6 +457,8 @@ class CutlassBaseDSL(BaseDSL):
args_spec,
no_cache,
TVMFFIJitCompiledFunction,
full_args=full_args,
full_kwargs=full_kwargs,
dynamic_args=dynamic_args,
dynamic_kwargs=dynamic_kwargs,
)
@@ -457,6 +471,8 @@ class CutlassBaseDSL(BaseDSL):
args_spec,
no_cache,
CudaDialectJitCompiledFunction,
full_args=full_args,
full_kwargs=full_kwargs,
dynamic_args=dynamic_args,
dynamic_kwargs=dynamic_kwargs,
original_function_name=original_function_name,

View File

@@ -17,18 +17,24 @@ from cutlass.base_dsl.tvm_ffi_builder import (
)
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
from cutlass._mlir._mlir_libs import _execution_engine_extra
from cutlass._mlir._mlir_libs._cutlass_ir import _execution_engine_extra
from cutlass.cutlass_dsl.cuda_jit_executor import CudaDialectJitCompiledFunction
from cutlass.base_dsl.common import DSLRuntimeError
from typing import Optional
import tvm_ffi
class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
"""Cute call provider that uses cute call convention."""
cuda_device_index: Optional[ir.Value]
cuda_error_handle_block: Optional[ir.Block]
def __init__(self, target_func: str):
super().__init__(target_func, struct_call=True)
self.cuda_global_state_symbol = f"__{target_func}_cuda_state"
self.cuda_device_index = None
self.cuda_error_handle_block = None
def get_callee_struct_for_param_tensor(
self,
@@ -41,7 +47,10 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
) -> ir.Type:
"""Routine used to override the tensor passing struct convention"""
with ir.InsertionPoint(current_block):
data_type = self.gpu_ptr_type
if param.dlpack_device_type == tvm_ffi.DLDeviceType.kDLCPU:
data_type = self.ptr_type
else:
data_type = self.gpu_ptr_type
strides_type = (
self.struct_type(fields=[x.type for x in strides])
if len(strides) != 1
@@ -79,22 +88,27 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
def declare_extern_funcs(self, current_block: ir.Block, context: CallContext):
"""Append the error handling function to the current block."""
with ir.InsertionPoint(context.module.body):
self.declare_extern_func(
context.builder.find_or_declare_extern_func(
"cuda_dialect_get_error_name",
[self.i32_type],
self.ptr_type,
)
self.declare_extern_func(
context.builder.find_or_declare_extern_func(
"_cudaGetDevice",
[self.ptr_type],
self.i32_type,
)
context.builder.find_or_declare_extern_func(
"_cudaSetDevice",
[self.i32_type],
self.i32_type,
)
self.declare_extern_func(
context.builder.find_or_declare_extern_func(
"cuda_dialect_init_library_once",
[self.ptr_type, self.ptr_type, self.ptr_type, self.ptr_type],
self.i32_type,
)
self.declare_extern_func(
context.builder.find_or_declare_extern_func(
"cuda_dialect_unload_library_once",
[self.ptr_type],
self.void_type,
@@ -116,7 +130,7 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
self.cuda_global_state_symbol, self.ptr_type
)
cuda_init_ptr = self.address_of("cuda_init", self.ptr_type)
cuda_load_ptr = self.address_of("cuda_load", self.ptr_type)
cuda_load_to_device_ptr = self.address_of("cuda_load_to_device", self.ptr_type)
set_error_ptr = self.address_of(
"TVMFFIErrorSetRaisedFromCStr", self.ptr_type
)
@@ -127,7 +141,7 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
callee_operands=[
cuda_global_state_ptr,
cuda_init_ptr,
cuda_load_ptr,
cuda_load_to_device_ptr,
set_error_ptr,
],
op_bundle_sizes=[],
@@ -207,34 +221,70 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
def check_cuda_error(
self, code: ir.Value, current_block: ir.Block, context: CallContext
):
"""Check if the CUDA error is raised and return the error string if so."""
"""Check if the CUDA error is raised and return the error string if so.
Uses a shared error handling block to avoid code duplication. The error code
is passed as a block argument to the shared error handler.
"""
assert self.cuda_error_handle_block is not None
with ir.InsertionPoint(current_block):
# check if the call is successful
error_block = current_block.create_after()
success_block = error_block.create_after()
# Check if call is successful (non-zero return code)
success_block = current_block.create_after()
# Check if call is successful (zero return code means success)
self.cond_br(
cond=self.equal(code, self.i32(0)),
true_block=success_block,
false_block=error_block,
false_block=self.cuda_error_handle_block,
branch_weights=self.BRANCH_WEIGHTS_LIKELY,
false_dest_operands=[code], # Pass error code to shared error block
)
return success_block
def set_cuda_device_if_mismatch(
self,
current_block: ir.Block,
context: CallContext,
current_device: Optional[ir.Value],
target_device: Optional[ir.Value],
) -> ir.Block:
"""Set the CUDA device index if it differs from the target device.
"""
# If either device is None, no switching needed
if current_device is None:
assert target_device is None
return current_block
with ir.InsertionPoint(current_block):
# Check if devices are different
devices_differ = self.not_equal(current_device, target_device)
# Create blocks for conditional device switching
switch_device_block = current_block.create_after()
continuation_block = switch_device_block.create_after()
# For this specific case, avoid branch weights for now
# mainly to avoid too drastic reordering of the code
self.cond_br(
cond=devices_differ,
true_block=switch_device_block,
false_block=continuation_block
)
# Error block: raise error and return
with ir.InsertionPoint(error_block):
error_str = llvm.call(
result=self.ptr_type,
callee="cuda_dialect_get_error_name",
callee_operands=[code],
# Switch device block: call cudaSetDevice
with ir.InsertionPoint(switch_device_block):
result = llvm.call(
result=self.i32_type,
callee="_cudaSetDevice",
callee_operands=[target_device],
op_bundle_sizes=[],
op_bundle_operands=[],
)
# Raise error and return -1
context.builder.raise_error_and_return(
error_kind="RuntimeError",
error_message_parts=["CUDA Error: ", error_str],
)
return success_block
# Check for errors and branch to continuation
switch_device_block = self.check_cuda_error(result, switch_device_block, context)
with ir.InsertionPoint(switch_device_block):
self.br(continuation_block)
return continuation_block
def generate_llvm_call(
self,
@@ -243,6 +293,36 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
context: CallContext,
) -> ir.Block:
"""Generate the LLVM call operation and check if the call is successful."""
old_cuda_device_index: Optional[ir.Value] = None
# If we need to manage CUDA device context
if self.cuda_device_index is not None:
# Create an alloca in the entry block to store the current device index
device_index_alloca = context.builder.create_alloca(
context.entry_block, self.i32_type, array_size=1
)
# Get the current device
with ir.InsertionPoint(current_block):
get_device_result = llvm.call(
result=self.i32_type,
callee="_cudaGetDevice",
callee_operands=[device_index_alloca],
op_bundle_sizes=[],
op_bundle_operands=[],
)
current_block = self.check_cuda_error(get_device_result, current_block, context)
# Load the current device index from the alloca
with ir.InsertionPoint(current_block):
old_cuda_device_index = llvm.load(self.i32_type, device_index_alloca)
# Switch to target device if different
current_block = self.set_cuda_device_if_mismatch(
current_block, context, old_cuda_device_index, self.cuda_device_index
)
# Execute the main call
with ir.InsertionPoint(current_block):
result = llvm.call(
result=self.i32_type,
@@ -251,42 +331,72 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
op_bundle_sizes=[],
op_bundle_operands=[],
)
return self.check_cuda_error(result, current_block, context)
def insert_set_cuda_device(self, current_block: ir.Block, context: CallContext):
"""Call the _cudaSetDevice function if we can find device id from tensor parameters."""
# Restore the original device BEFORE checking for errors
# This ensures device is restored even if the main call failed
current_block = self.set_cuda_device_if_mismatch(
current_block, context, self.cuda_device_index, old_cuda_device_index
)
def find_cuda_device_index_from_params():
for param in context.params:
if (
isinstance(param, spec.Tensor)
and param.dlpack_device_type != tvm_ffi.DLDeviceType.kDLCPU
):
return context.matched_var_binding[param.device_id]
return None
# Now check for errors from the main call
current_block = self.check_cuda_error(result, current_block, context)
cuda_device_index = find_cuda_device_index_from_params()
return current_block
if cuda_device_index is None:
return current_block
with ir.InsertionPoint(current_block):
result = llvm.call(
result=self.i32_type,
callee="_cudaSetDevice",
callee_operands=[cuda_device_index],
def find_cuda_device_index_from_params(self, context: CallContext):
"""Find the CUDA device index from tensor parameters."""
for param in context.params:
if (
isinstance(param, spec.Tensor)
and param.dlpack_device_type != tvm_ffi.DLDeviceType.kDLCPU
):
return context.matched_var_binding[param.device_id]
return None
def create_shared_cuda_error_block(
self,
current_block: ir.Block,
context: CallContext
) -> ir.Block:
"""Create a shared error handling block for all CUDA errors.
"""
# Create the shared error block after the current block (setup phase)
# This block will be branched to from multiple error checking sites
# It accepts the error code as a block argument
error_block = current_block.create_after()
error_code = error_block.add_argument(self.i32_type, ir.Location.unknown())
# Populate the error block
with ir.InsertionPoint(error_block):
error_str = llvm.call(
result=self.ptr_type,
callee="cuda_dialect_get_error_name",
callee_operands=[error_code],
op_bundle_sizes=[],
op_bundle_operands=[],
)
# Raise error and return -1
context.builder.raise_error_and_return(
error_kind="RuntimeError",
error_message_parts=["CUDA Error: ", error_str],
)
return self.check_cuda_error(result, current_block, context)
return error_block
def __call__(self, current_block: ir.Block, context: CallContext) -> ir.Block:
current_block = self.declare_extern_funcs(current_block, context)
current_block = self.insert_lazy_init_cuda(current_block, context)
current_block = self.append_unload_to_global_dtors(current_block, context)
current_block = self.insert_set_cuda_device(current_block, context)
# Create shared CUDA error handling block after the setup blocks
# This reduces code duplication - all CUDA errors branch to this single block
self.cuda_error_handle_block = self.create_shared_cuda_error_block(current_block, context)
# setup device index, will be set around the call to the target function
self.cuda_device_index = self.find_cuda_device_index_from_params(context)
current_block = super().__call__(current_block, context)
self.cuda_device_index = None
self.cuda_error_handle_block = None
# reset the device index and error block
return current_block
@@ -316,12 +426,15 @@ class TVMFFIJitCompiledFunction(tvm_ffi.Function, CudaDialectJitCompiledFunction
if self.__chandle__() != 0:
raise DSLRuntimeError("TVM FFI function is already initialized")
# get the MLIR function pointer from the execution engine
tvm_ffi_function_ptr = self.engine.raw_lookup("__tvm_ffi_" + self.function_name)
tvm_ffi_function = tvm_ffi.Function.__from_mlir_packed_safe_call__(
tvm_ffi_function_ptr
)
# move the handle from the tvm_ffi.Function to the current instance
self.__move_handle_from__(tvm_ffi_function)
if self.engine is not None:
tvm_ffi_function_ptr = self.engine.raw_lookup(
"__tvm_ffi_" + self.function_name
)
tvm_ffi_function = tvm_ffi.Function.__from_mlir_packed_safe_call__(
tvm_ffi_function_ptr
)
# move the handle from the tvm_ffi.Function to the current instance
self.__move_handle_from__(tvm_ffi_function)
def to(self, device=None):
"""TVM FFI function itself is already support all devices."""

View File

@@ -165,6 +165,16 @@ def create_and_permute_torch_tensor(
return dtype_torch_tensor
def get_leading_dim(torch_tensor: torch.Tensor) -> int:
"""
Get the leading dimension of a torch tensor
"""
for i, stride in enumerate(torch_tensor.stride()):
if stride == 1:
return i
return None
def convert_cute_tensor(
f32_torch_tensor: "torch.Tensor",
cute_tensor: Tensor,
@@ -189,8 +199,10 @@ def convert_cute_tensor(
}:
fp32_cute_tensor = from_dlpack(f32_torch_tensor)
if is_dynamic_layout:
# note: dim_order to not always maps to leading dimension,
# so we need to get the leading dimension from the torch tensor strides
fp32_cute_tensor = fp32_cute_tensor.mark_layout_dynamic(
f32_torch_tensor.dim_order()[-1]
leading_dim=get_leading_dim(f32_torch_tensor)
)
# Copy and convert from f32 cute tensor to dtype cute tensor
cute.testing.convert(fp32_cute_tensor, cute_tensor)
@@ -297,11 +309,9 @@ def cute_tensor_like(
# create cute tensor using the device buffer
cute_tensor = from_dlpack(torch_tensor, assumed_align=assumed_align)
cute_tensor.element_type = cutlass_dtype
if is_dynamic_layout:
for i, stride in enumerate(torch_tensor.stride()):
if stride == 1:
leading_dim = i
break
leading_dim = get_leading_dim(torch_tensor)
cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
# initialize the cute tensor data
@@ -316,5 +326,4 @@ def cute_tensor_like(
)
else:
torch_tensor.copy_(data_ref.to(dtype=torch_dtype))
return cute_tensor, torch_tensor

View File

@@ -664,6 +664,7 @@ def make_smem_layout_a(
a_dtype: Type[Numeric],
num_stages: int,
*,
is_k_major=None,
loc=None,
ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
@@ -687,7 +688,8 @@ def make_smem_layout_a(
:rtype: Union[cute.Layout, cute.ComposedLayout]
"""
is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K
is_k_major = (tiled_mma.op.a_major_mode == OperandMajorMode.K) if is_k_major is None else is_k_major
a_major_mode = OperandMajorMode.K if is_k_major else OperandMajorMode.MN
a_smem_shape = tiled_mma.partition_shape_A(
cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip), loc=loc, ip=ip
)
@@ -696,7 +698,7 @@ def make_smem_layout_a(
cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2],
)
smem_layout_atom_kind = get_smem_layout_atom_ab(
tiled_mma.op.a_major_mode, a_dtype, a_smem_shape_mn_k, loc=loc, ip=ip
a_major_mode, a_dtype, a_smem_shape_mn_k, loc=loc, ip=ip
)
a_smem_layout_atom = make_smem_layout_atom(
smem_layout_atom_kind, a_dtype, loc=loc, ip=ip
@@ -716,6 +718,7 @@ def make_smem_layout_b(
b_dtype: Type[Numeric],
num_stages: int,
*,
is_k_major=None,
loc=None,
ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
@@ -739,7 +742,8 @@ def make_smem_layout_b(
:rtype: Union[cute.Layout, cute.ComposedLayout]
"""
is_k_major = tiled_mma.op.b_major_mode == OperandMajorMode.K
is_k_major = (tiled_mma.op.b_major_mode == OperandMajorMode.K) if is_k_major is None else is_k_major
b_major_mode = OperandMajorMode.K if is_k_major else OperandMajorMode.MN
b_smem_shape = tiled_mma.partition_shape_B(
cute.dice(mma_tiler_mnk, (None, 1, 1), loc=loc, ip=ip), loc=loc, ip=ip
)
@@ -749,7 +753,7 @@ def make_smem_layout_b(
)
smem_layout_atom_kind = get_smem_layout_atom_ab(
tiled_mma.op.b_major_mode, b_dtype, b_smem_shape_nk, loc=loc, ip=ip
b_major_mode, b_dtype, b_smem_shape_nk, loc=loc, ip=ip
)
b_smem_layout_atom = make_smem_layout_atom(
smem_layout_atom_kind, b_dtype, loc=loc, ip=ip

View File

@@ -8,12 +8,10 @@
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from cuda.bindings import driver, nvrtc, runtime
from cutlass.cutlass_dsl.cuda_jit_executor import CudaDialectJitModule
from cuda.bindings import driver, runtime
from cutlass.base_dsl.common import DSLRuntimeError
import cutlass.cute as cute
from cutlass import cute
import tempfile
"""
This class is used to get the hardware info of given GPU device.
@@ -53,8 +51,8 @@ class HardwareInfo:
f"Cluster size must be between 1 and 32, {cluster_size} is not supported"
)
self._get_device_function(self.device)
# must do get kernel after set device so runtime context is set correctly
self.kernel = self._get_device_function()
max_shared_memory_per_block = self._checkCudaErrors(
driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
@@ -152,8 +150,6 @@ class HardwareInfo:
if isinstance(error, driver.CUresult):
err, name = driver.cuGetErrorName(error)
return name if err == driver.CUresult.CUDA_SUCCESS else "<unknown>"
elif isinstance(error, nvrtc.nvrtcResult):
return nvrtc.nvrtcGetErrorString(error)[1]
else:
raise RuntimeError("Unknown error type: {}".format(error))
@@ -175,14 +171,20 @@ class HardwareInfo:
)
# get a empty kernel to compute occupancy
def _get_device_function(self, device) -> driver.CUfunction:
self.compiled_kernel = cute.compile(self._host_function).to(device)
assert isinstance(self.compiled_kernel.jit_module, CudaDialectJitModule)
err, kernels = runtime.cudaLibraryEnumerateKernels(
1, self.compiled_kernel.jit_module.cuda_library[0]
)
if err is not runtime.cudaError_t.cudaSuccess:
raise DSLRuntimeError(f"Failed to enumerate kernels: {err}")
self.kernel = kernels[0]
self.kernel = self._checkCudaErrors(driver.cuKernelGetFunction(self.kernel))
return self.kernel
def _get_device_function(self) -> driver.CUfunction:
"""
Get a device function by compiling a dummy kernel using cuteDSL pipeline.
"""
# Create a temporary directory for dumping artifacts
with tempfile.TemporaryDirectory() as temp_dir:
# keep-cubin will keep the cubin in the artifacts
compiled_func = cute.compile(self._host_function, options=f"--dump-dir={temp_dir} --keep-cubin")
# Get the CUBIN from artifacts
cubin_data = compiled_func.artifacts.CUBIN
cuda_library = self._checkCudaErrors(
driver.cuLibraryLoadData(cubin_data, None, None, 0, None, None, 0)
)
# Enumerate kernels from the library
kernels = self._checkCudaErrors(driver.cuLibraryEnumerateKernels(1, cuda_library))
# Get the function from the kernel
return self._checkCudaErrors(driver.cuKernelGetFunction(kernels[0]))