v4.3.3 update. (#2868)

This commit is contained in:
Junkai-Wu
2025-12-11 13:26:58 +08:00
committed by GitHub
parent 49bd6bf1ba
commit d3a5492381
24 changed files with 789 additions and 211 deletions

View File

@@ -14,6 +14,8 @@ This module provides jit cache load/dump helper functions
"""
import os
import io
import sys
import uuid
import random
import tempfile
@@ -22,7 +24,7 @@ import time
from pathlib import Path
import hashlib
from functools import lru_cache
import tempfile
import zlib
from .utils.logger import log
from .jit_executor import JitCompiledFunction
@@ -74,13 +76,58 @@ def get_default_file_dump_root():
return dump_root
def load_ir(file, asBytecode=False):
def write_bytecode_with_crc32(f, module):
    """Serialize the module to bytecode and write it to ``f`` with a trailing CRC32.

    The 4-byte CRC32 checksum of the bytecode payload is appended in native
    byte order (``sys.byteorder``), the format expected by
    ``read_bytecode_and_check_crc32``.

    :param f: Binary file-like object to write the bytecode to.
    :type f: file
    :param module: The IR module whose bytecode is written.
    :type module: object
    """
    buf = io.BytesIO()
    module.operation.write_bytecode(buf)
    content = buf.getvalue()
    crc = zlib.crc32(content)
    # Write payload and checksum directly to the target file instead of
    # appending the CRC to the scratch buffer and copying it out a second
    # time via another getvalue() call.
    f.write(content)
    f.write(crc.to_bytes(4, sys.byteorder))
def read_bytecode_and_check_crc32(f):
    """Read CRC32-protected bytecode from ``f`` and parse it into an IR module.

    The last four bytes of the file are interpreted as the CRC32 checksum
    (native byte order) of the preceding bytecode payload, as produced by
    ``write_bytecode_with_crc32``.

    :param f: The file to read the bytecode with appended CRC32 from.
    :type f: file
    :return: The parsed IR module (not the raw bytes).
    :rtype: ir.Module
    :raises DSLRuntimeError: If the file is too short or the checksum differs.
    """
    data = f.read()
    if len(data) < 4:
        raise DSLRuntimeError(
            f"File {f.name} does not contain enough data for CRC32 checksum."
        )
    payload = data[:-4]
    stored_crc = int.from_bytes(data[-4:], sys.byteorder)
    computed_crc = zlib.crc32(payload)
    if stored_crc != computed_crc:
        raise DSLRuntimeError(
            f"CRC32 checksum mismatch! Expected {computed_crc}, got {stored_crc}"
        )
    return ir.Module.parse(payload)
def load_ir(file, asBytecode=False, bytecode_reader=None):
"""Load generated IR from a file.
:param file: The path to the file to load.
:type file: str
:param asBytecode: Whether to load the IR as bytecode, defaults to False
:type asBytecode: bool, optional
:param bytecode_reader: The bytecode reader to use, defaults to None
:type bytecode_reader: callable, optional
:return: The function name and the IR module
:rtype: tuple[str, ir.Module]
"""
@@ -88,8 +135,10 @@ def load_ir(file, asBytecode=False):
func_name = file.split(".mlir")[0].split("dsl_")[-1]
with ir.Context() as ctx:
with open(file, "rb" if asBytecode else "r") as f:
module = ir.Module.parse(f.read())
if bytecode_reader:
module = bytecode_reader(f)
else:
module = ir.Module.parse(f.read())
return func_name, module
@@ -128,6 +177,14 @@ def save_ir(
:type module: object
:param fname: The name of the file to save.
:type fname: str
:param output_dir: The path to the output directory, defaults to None
:type output_dir: str, optional
:param as_bytecode: Whether to save the IR as bytecode, defaults to False
:type as_bytecode: bool, optional
:param bytecode_writer: The bytecode writer to use, defaults to None
:type bytecode_writer: callable, optional
:return: The path to the saved file
:rtype: str
"""
initial_name = f"{dsl_name.lower()}_{fname}.mlir"
save_path = Path(output_dir if output_dir else tempfile.gettempdir())
@@ -158,63 +215,45 @@ def save_ir(
return save_fname
def check_func_name(jit_cache, func_name):
    """Ensure the function name has an entry in the JIT cache.

    If the name is missing, an empty placeholder ``JitCompiledFunction`` is
    inserted so later code can populate its fields.

    :param jit_cache: The cache to check.
    :type jit_cache: dict
    :param func_name: The name of the function to check.
    :type func_name: str
    :return: The (possibly updated) cache.
    :rtype: dict
    """
    # "x not in y" is the idiomatic membership test (vs. "not x in y").
    if func_name not in jit_cache:
        jit_cache[func_name] = JitCompiledFunction(
            None, None, None, None, None, [], False, None
        )
    return jit_cache
def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path):
def load_cache_from_path(
dsl_name, file, path=default_generated_ir_path, bytecode_reader=None
):
"""Load cache from a directory path.
:param dsl_name: The name of the DSL.
:type dsl_name: str
:param cache_limit: The limit of the cache.
:type cache_limit: int
:param file: The name of the file to load.
:type file: str
:param path: The path to the cache directory, defaults to default_generated_ir_path
:type path: str, optional
:param bytecode_reader: The bytecode reader to use, defaults to None
:type bytecode_reader: callable, optional
:return: The cache
:rtype: dict
"""
if not os.path.exists(path):
return dict()
files = os.listdir(path)
jit_cache = dict()
return None
ret = None
try:
for idx, file in enumerate(files):
if idx >= int(cache_limit):
break
# identify dsl prefix
if not file.startswith(f"{dsl_name.lower()}"):
continue
if ".mlir" in file:
func_name, ir_module = load_ir(
os.path.join(path, file), asBytecode=True
)
jit_cache = check_func_name(jit_cache, func_name)
jit_cache[func_name].ir_module = ir_module
file = f"{dsl_name.lower()}_{file}.mlir"
if os.path.exists(os.path.join(path, file)):
_, module = load_ir(
os.path.join(path, file),
asBytecode=True,
bytecode_reader=bytecode_reader,
)
ret = JitCompiledFunction(module, None, None, None, None, [], False, None)
except Exception as e:
print(f"{dsl_name} failed with loading generated IR cache.", e)
jit_cache = dict()
return jit_cache
log().warning(
f"{dsl_name} failed with loading generated IR cache for {file}.", e
)
return ret
def dump_cache_to_path(
dsl_name,
jit_cache,
cache_limit,
jit_function,
file,
path=default_generated_ir_path,
bytecode_writer=None,
):
@@ -222,30 +261,29 @@ def dump_cache_to_path(
:param dsl_name: The name of the DSL.
:type dsl_name: str
:param jit_cache: The cache to dump.
:type jit_cache: dict
:param cache_limit: The limit of the cache.
:type cache_limit: int
:param jit_function: The JitCompiledFunction to dump.
:type jit_function: JitCompiledFunction
:param file: The name of the file to dump.
:type file: str
:param path: The path to the cache directory, defaults to default_generated_ir_path
:type path: str, optional
:param bytecode_writer: The bytecode writer to use, defaults to None
:type bytecode_writer: callable, optional
"""
log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
log().info("JIT cache : dumping [%s] file=[%s]", dsl_name, file)
if not path:
path = default_generated_ir_path
os.makedirs(path, exist_ok=True)
try:
for idx, [key, value] in enumerate(jit_cache.items()):
if idx >= int(cache_limit):
break
save_ir(
dsl_name,
value.ir_module,
key,
output_dir=path,
as_bytecode=True,
bytecode_writer=bytecode_writer,
)
save_ir(
dsl_name,
jit_function.ir_module,
file,
output_dir=path,
as_bytecode=True,
bytecode_writer=bytecode_writer,
)
except Exception as e:
print(f"{dsl_name} failed with caching generated IR", e)
log().warning(
f"{dsl_name} failed with dumping generated IR cache for {file}: {e}"
)

View File

@@ -195,6 +195,10 @@ def _get_friendly_cuda_error_message(error_code, error_name):
f"2. SM ARCH setting",
f"3. Steps to reproduce",
),
"cudaErrorInsufficientDriver": (
f"1. Run nvidia-smi to confirm CUDA driver version",
f"2. Ensure the CUDA driver version meets the requirement of the installed cuda-python package",
),
}
message = f"{error_name} (error code: {error_code}) \n" \

View File

@@ -318,11 +318,8 @@ class BaseDSL:
self.envar = self._env_class(self.name)
self.enable_preprocessor = preprocess
# This cache uses hash of original ir and env as key, allows dump/load to/from file. Enabled by default
self.jit_cache = (
dict()
if self.envar.disable_file_caching
else load_cache_from_path(self.name, self.envar.file_caching_capacity)
)
self.jit_cache = dict()
self.host_jit_decorator_name = f"@{BaseDSL.jit.__name__}"
self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}"
@@ -372,12 +369,6 @@ class BaseDSL:
atexit.register(restore_excepthook, origin_excepthook)
def dump_cache(self, path=None):
if not self.envar.disable_file_caching:
dump_cache_to_path(
self.name, self.jit_cache, self.envar.file_caching_capacity, path=path
)
@lru_cache(maxsize=1)
def print_warning_once(self, message):
log().warning(f"Warning: {message}")
@@ -392,9 +383,6 @@ class BaseDSL:
def _get_dsl(cls):
# Instantiate the DSL Class once
main_dsl = cls()
if not main_dsl.no_cache:
# register atexit callback
atexit.register(main_dsl.dump_cache)
return main_dsl
@staticmethod
@@ -1235,6 +1223,16 @@ class BaseDSL:
log().debug(f"Using pipeline = {pipeline}")
shared_libs = self.get_shared_libs()
profiler = timer(enable=self.envar.jit_time_profiling)
# try load the file cache
load_from_file_cache = False
if not no_cache:
fn = load_cache_from_path(
self.name, module_hash, bytecode_reader=read_bytecode_and_check_crc32
)
if fn is not None:
load_from_file_cache = True
self.jit_cache[module_hash] = fn
if (
no_cache
or module_hash not in self.jit_cache
@@ -1288,6 +1286,16 @@ class BaseDSL:
if not no_cache:
# module stored in cache is compiled.
self.jit_cache[module_hash] = fn
# write through the file cache if enabled.
if not self.envar.disable_file_caching and not load_from_file_cache:
dump_cache_to_path(
self.name,
fn,
module_hash,
bytecode_writer=lambda f: write_bytecode_with_crc32(
f, fn.ir_module
),
)
return fn

View File

@@ -311,7 +311,6 @@ class EnvironmentVarManager(LogEnvironmentManager):
- [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False)
- [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False)
- [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False)
- [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000)
- [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None)
- [DSL_NAME]_ENABLE_TVM_FFI: Enable TVM-FFI or not (default: False)
"""
@@ -350,9 +349,6 @@ class EnvironmentVarManager(LogEnvironmentManager):
self.disable_file_caching = get_bool_env_var(
f"{prefix}_DISABLE_FILE_CACHING", False
)
self.file_caching_capacity = get_int_env_var(
f"{prefix}_FILE_CACHING_CAPACITY", 1000
)
# set cuda
self.cuda_toolkit = get_cuda_toolkit_path()

View File

@@ -27,17 +27,21 @@ from typing import Union
cubin_suffix = "cubin"
def get_export_module(ir_module: ir.Module, symbol_prefix: str):
def get_export_module(ir_module: ir.Module, symbol_prefix: str, *, preserve_symbols = None):
"""Get the export module which is cloned from the original compiled ir module, and add the prefix
to avoid the symbol conflict.
@param ir_module: The original compiled ir module. Comes from the JitCompiledFunction.ir_module.
@param symbol_prefix: The prefix name of the function. This is the unique identifier name of the function to avoid symbol conflict in the generated object file.
@param preserve_symbols: Optional symbols to preserve in the export module.
@return: The export module of the function.
"""
# Add prefix for symbol names to avoid conflict with other functions
defined_symbols = set()
if preserve_symbols is None:
preserve_symbols = set()
def walk_llvm_func_op(op):
# not a declaration
if (
@@ -45,23 +49,61 @@ def get_export_module(ir_module: ir.Module, symbol_prefix: str):
and len(op.opview.operation.regions) > 0
and len(op.opview.operation.regions[0].blocks) > 0
):
defined_symbols.add(op.attributes["sym_name"].value)
func_name = op.attributes["sym_name"].value
# skip preserving symbols
if func_name in preserve_symbols:
return ir.WalkResult.ADVANCE
defined_symbols.add(func_name)
op.attributes["sym_name"] = ir.StringAttr.get(
symbol_prefix + "_" + op.attributes["sym_name"].value
symbol_prefix + "_" + func_name
)
return ir.WalkResult.ADVANCE
def walk_llvm_call_op(op):
def walk_llvm_references(op):
# Rename function calls
if op.name == "llvm.call" and op.attributes["callee"].value in defined_symbols:
op.attributes["callee"] = ir.FlatSymbolRefAttr.get(
symbol_prefix + "_" + op.attributes["callee"].value
)
# Rename addressof references
elif op.name == "llvm.mlir.addressof" and op.attributes["global_name"].value in defined_symbols:
op.attributes["global_name"] = ir.FlatSymbolRefAttr.get(
symbol_prefix + "_" + op.attributes["global_name"].value
)
# Rename global_ctors references
elif op.name == "llvm.mlir.global_ctors" and "ctors" in op.attributes:
ctors = list(op.attributes["ctors"])
renamed_ctors = []
for ctor in ctors:
if ctor.value in defined_symbols:
renamed_ctors.append(ir.FlatSymbolRefAttr.get(
symbol_prefix + "_" + ctor.value
))
else:
renamed_ctors.append(ctor)
if renamed_ctors:
op.attributes["ctors"] = ir.ArrayAttr.get(renamed_ctors)
# Rename global_dtors references
elif op.name == "llvm.mlir.global_dtors" and "dtors" in op.attributes:
dtors = list(op.attributes["dtors"])
renamed_dtors = []
for dtor in dtors:
if dtor.value in defined_symbols:
renamed_dtors.append(ir.FlatSymbolRefAttr.get(
symbol_prefix + "_" + dtor.value
))
else:
renamed_dtors.append(dtor)
if renamed_dtors:
op.attributes["dtors"] = ir.ArrayAttr.get(renamed_dtors)
return ir.WalkResult.ADVANCE
with ir.Context():
export_module = ir.Module.parse(str(ir_module))
# First pass: collect and rename function definitions
export_module.operation.walk(walk_llvm_func_op)
export_module.operation.walk(walk_llvm_call_op)
# Second pass: rename call and addressof references
export_module.operation.walk(walk_llvm_references)
return export_module

View File

@@ -16,7 +16,7 @@ This module provides jit executor related classes
import ctypes
import inspect
import io
from typing import Union, Optional
from typing import Union, Optional, NamedTuple, Any, Sequence
import weakref
import threading
import collections
@@ -132,6 +132,15 @@ def load_kernels_from_ir_module(module, kernel_info) -> list[CudaModuleAndKernel
return list(kernel_modules.values())
class KwargsWrapperSpec(NamedTuple):
    """A specification for keyword arguments wrapper.

    Captures the runtime-argument layout of a jit function's signature
    (compile-time arguments filtered out) so a keyword-aware wrapper can
    rebuild the call arguments.
    """
    # Names of the positional-or-keyword runtime arguments, in order.
    arg_names: list[str]
    # Defaults for the trailing entries of ``arg_names`` that have them.
    arg_defaults: tuple[Any, ...]
    # Names of the keyword-only runtime arguments.
    kwonly_names: list[str]
    # Mapping of keyword-only argument name to its default, where present.
    kwonly_defaults: dict[str, Any]
class ExecutionArgs:
"""Helper that wraps the function signature spec to filter execution and compile time arguments."""
@@ -216,6 +225,59 @@ class ExecutionArgs:
return exe_args, adapted_args
def get_kwargs_wrapper_spec(self, exclude_arg_names: Sequence[str] = ()) -> KwargsWrapperSpec:
    """Build a :class:`KwargsWrapperSpec` from the original args spec.

    Compile-time (constexpr) arguments and any name listed in
    ``exclude_arg_names`` are dropped; defaults of the surviving
    arguments are carried over.
    """
    excluded = set(exclude_arg_names)
    argspec = self.original_args_spec

    # Index of the first positional argument that carries a default value.
    n_defaults = len(argspec.defaults) if argspec.defaults else 0
    first_default = len(argspec.args) - n_defaults

    names = []
    defaults = []
    for idx, name in enumerate(argspec.args):
        annotation = argspec.annotations.get(name, None)
        # Compile-time arguments are not part of the runtime call.
        if is_arg_spec_constexpr(annotation, name, idx, self.function_name):
            continue
        if name in excluded:
            continue
        names.append(name)
        if idx >= first_default:
            defaults.append(argspec.defaults[idx - first_default])

    kw_names = []
    kw_defaults = {}
    for idx, name in enumerate(argspec.kwonlyargs or ()):
        annotation = argspec.annotations.get(name, None)
        # Compile-time arguments are not part of the runtime call.
        if is_arg_spec_constexpr(annotation, name, idx, self.function_name):
            continue
        if name in excluded:
            continue
        kw_names.append(name)
        if argspec.kwonlydefaults and name in argspec.kwonlydefaults:
            kw_defaults[name] = argspec.kwonlydefaults[name]

    return KwargsWrapperSpec(
        arg_names=names,
        arg_defaults=tuple(defaults),
        kwonly_names=kw_names,
        kwonly_defaults=kw_defaults,
    )
def get_rectified_args_from_original_args(self, full_args, full_kwargs):
"""
This function is used to rectify the original arguments to the runtime
@@ -233,6 +295,7 @@ class ExecutionArgs:
defaults_start_idx = len(arg_spec.args)
runtime_args = []
# Filter arguments and maintain their properties
for i, arg_name in enumerate(arg_spec.args):
arg_type = arg_spec.annotations.get(arg_name, None)
@@ -241,12 +304,24 @@ class ExecutionArgs:
if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
continue
# Keep corresponding default if it exists
if i >= defaults_start_idx:
# Check if argument was provided by user, otherwise use default
if i < len(full_args):
# User provided this argument - use it
runtime_args.append(full_args[i])
elif i >= defaults_start_idx:
# Argument not provided, but has default - use default
default_idx = i - defaults_start_idx
runtime_args.append(arg_spec.defaults[default_idx])
else:
runtime_args.append(full_args[i])
# Required argument missing
raise DSLRuntimeError(
f"Missing required argument '{arg_name}' at position {i}",
context={
"function_name": self.function_name,
"expected_args": len(arg_spec.args),
"provided_args": len(full_args),
}
)
# Filter keyword-only arguments
runtime_kwargs = {}

View File

@@ -21,6 +21,7 @@ import os
import ctypes
import cuda.bindings.driver as cuda
import cuda.bindings.runtime as cudart
import cuda.bindings.nvrtc as nvrtc
# Local module imports
@@ -44,6 +45,8 @@ def _cudaGetErrorEnum(error):
if isinstance(error, cuda.CUresult):
err, name = cuda.cuGetErrorName(error)
return name if err == cuda.CUresult.CUDA_SUCCESS else "<unknown>"
elif isinstance(error, cudart.cudaError_t):
return cudart.cudaGetErrorName(error)[1]
elif isinstance(error, nvrtc.nvrtcResult):
return nvrtc.nvrtcGetErrorString(error)[1]
else:

View File

@@ -38,6 +38,7 @@ class MLIRTypeBuilder:
self.gpu_ptr_type = llvm.PointerType.get(address_space=1)
# did not find a programmatic way to get the void type
self.void_type = ir.Type.parse("!llvm.void")
self.llvm_internal_linkage = ir.Attribute.parse("#llvm.linkage<internal>")
def ptr_type_with_address_space(
self, address_space: Optional[int] = None
@@ -374,8 +375,10 @@ class MLIRBuilder(MLIRTypeBuilder):
params_type: Sequence[ir.Type],
ret_type: ir.Type,
internal: bool = False,
llvm_func_attrs: Sequence[str] = (),
) -> tuple[list[ir.Value], ir.Block]:
"""Create a function with the given signature."""
"""Create a function with the given signature.
"""
func_op = llvm.func(
name,
function_type=self.as_attr(
@@ -383,10 +386,16 @@ class MLIRBuilder(MLIRTypeBuilder):
),
)
if internal:
func_op.attributes["llvm.linkage"] = ir.StringAttr.get("private")
func_op.attributes["linkage"] = self.llvm_internal_linkage
else:
func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
# Add LLVM function attributes via passthrough
if llvm_func_attrs:
func_op.attributes["passthrough"] = ir.ArrayAttr.get(
[ir.StringAttr.get(attr) for attr in llvm_func_attrs]
)
params = []
func_body: Any = func_op.body
if func_body is not None:

View File

@@ -668,12 +668,14 @@ class TVMFFIBuilder(MLIRBuilder):
param_types.append(self.ptr_type) # p0, p1, ..., pN-1
# Create the helper function
# Mark as noinline since error handling is a slow path and benefits from not inlining
with ir.InsertionPoint(self.module.body): # type: ignore[union-attr]
params, entry_block = self.function(
name=helper_name,
params_type=param_types,
ret_type=self.void_type,
internal=True,
llvm_func_attrs=["noinline"],
)
kind_param = params[0]
@@ -1244,12 +1246,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
else:
assert isinstance(var, int)
with ir.InsertionPoint(current_block):
if not skip_cast_and_check:
expected_value = self.i64(var)
else:
expected_value = self.downcast_i64_to_lower_bits(
self.i64(var), var.dtype
)
expected_value = self.i64(var)
error_msg_mismatch = [
error_prefix_mismatch,
@@ -1983,13 +1980,21 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
)
# decode parameters to populate the matched var binding
for arg_index, param in enumerate(params_list):
# Track the actual FFI argument index separately from parameter index
# since some parameters (like EnvStream) are not passed as FFI arguments
ffi_arg_index = 0
for param in params_list:
# Skip EnvStream parameters as they are not in the FFI args array
if isinstance(param, spec.EnvStream):
continue
arg_context = ArgContext(
param_name=param.name,
arg_index=arg_index,
arg_index=ffi_arg_index,
tuple_indices=[],
)
current_block = self.decode_param(current_block, param, args, arg_index, arg_context)
current_block = self.decode_param(current_block, param, args, ffi_arg_index, arg_context)
ffi_arg_index += 1
with ir.InsertionPoint(current_block):
env_stream = self.find_env_stream(params_list)
@@ -2047,7 +2052,9 @@ def attach_ffi_func(
builder.attach_ffi_func(symbol_name, params, call_provider, fn_display_name)
def rename_tvm_ffi_function(module: ir.Module, old_name: str, new_name: str) -> None:
def rename_tvm_ffi_function(
module: ir.Module, old_name: str, new_name: str,
) -> None:
"""Rename the TVM FFI function in the module.
Parameters

View File

@@ -19,6 +19,7 @@ from .typing import Tensor, Pointer, SymInt
from .typing import (
Numeric,
Boolean,
Integer,
Int4,
Int8,
Uint8,
@@ -42,7 +43,17 @@ from .typing import (
)
import cuda.bindings.driver as cuda
from typing import List, Dict, Any, Optional, Tuple, get_origin, get_args
from typing import (
List,
Dict,
Any,
Optional,
Tuple,
get_origin,
get_args,
get_type_hints,
)
from types import UnionType
import inspect
NumericToTVMFFIDtype = {
@@ -91,6 +102,7 @@ def _get_llvm_address_space_from_memspace(
return 1
return None
def _is_gpu_memspace(
memspace: _cute_ir.AddressSpace,
) -> bool:
@@ -108,7 +120,6 @@ class SymIntId:
return self.sym_int is other.sym_int
class ConverterContext:
"""Context for managing variable allocation during TVM FFI args conversion."""
@@ -145,7 +156,9 @@ class ConverterContext:
self.sym_int_id_mapping[sym_int_id] = var
return var
def alloc_or_reuse_device_id(self, device_type: str, vdevice_id: int) -> Optional[spec.Var]:
def alloc_or_reuse_device_id(
self, device_type: str, vdevice_id: int
) -> Optional[spec.Var]:
"""Allocate or reuse a device_id variable for a given virtual device.
This function returns None for CPU tensors.
@@ -166,10 +179,7 @@ class ConverterContext:
def _convert_single_arg(
arg,
arg_name: str,
arg_type,
ctx: ConverterContext
arg, arg_name: str, arg_type, ctx: ConverterContext
) -> spec.Param:
"""Convert a single argument to a spec.Param.
@@ -191,7 +201,7 @@ def _convert_single_arg(
"""
if arg is None:
return spec.ConstNone(arg_name)
elif (isinstance(arg, Numeric) and arg.dtype in AcceptableNumericTypesForScalar):
elif isinstance(arg, Numeric) and arg.dtype in AcceptableNumericTypesForScalar:
return spec.Var(arg_name, NumericToTVMFFIDtype[arg.dtype])
elif arg_type in AcceptableNumericTypesForScalar:
return spec.Var(arg_name, NumericToTVMFFIDtype[arg_type])
@@ -201,9 +211,13 @@ def _convert_single_arg(
if isinstance(arg[i], int):
shape.append(arg[i])
elif isinstance(arg[i], SymInt):
shape.append(ctx.alloc_or_reuse_symint_var(arg[i], ctx.alloc_shape_name))
shape.append(
ctx.alloc_or_reuse_symint_var(arg[i], ctx.alloc_shape_name)
)
else:
shape.append(spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[arg[i].dtype]))
shape.append(
spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[arg[i].dtype])
)
return spec.Shape(arg_name, shape)
elif isinstance(arg, Tensor):
shapes = []
@@ -211,16 +225,22 @@ def _convert_single_arg(
if not dyn_mask:
shapes.append(arg.shape[i])
elif isinstance(arg.shape[i], SymInt):
shapes.append(ctx.alloc_or_reuse_symint_var(arg.shape[i], ctx.alloc_shape_name))
shapes.append(
ctx.alloc_or_reuse_symint_var(arg.shape[i], ctx.alloc_shape_name)
)
else:
shapes.append(spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[Int32]))
shapes.append(
spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[Int32])
)
strides = []
for i, dyn_mask in enumerate(arg.dynamic_strides_mask):
if not dyn_mask:
strides.append(arg.stride[i])
elif isinstance(arg.stride[i], SymInt):
strides.append(ctx.alloc_or_reuse_symint_var(arg.stride[i], ctx.alloc_stride_name))
strides.append(
ctx.alloc_or_reuse_symint_var(arg.stride[i], ctx.alloc_stride_name)
)
else:
if hasattr(arg, "_use_32bit_stride") and arg._use_32bit_stride:
dtype = NumericToTVMFFIDtype[Int32]
@@ -243,7 +263,7 @@ def _convert_single_arg(
strides=strides,
data_alignment=arg._assumed_align,
device_type=device_type,
device_id=device_id
device_id=device_id,
)
else:
# for FakeTensor, strictly follow the shape and stride from the cute tensor
@@ -259,7 +279,7 @@ def _convert_single_arg(
strides=strides,
data_alignment=arg._assumed_align,
device_type=device_type,
device_id=device_id
device_id=device_id,
)
if arg.element_type == Float4E2M1FN:
tvm_ffi_cute_tensor = spec.create_map_tensor_dtype_f4x2_to_f4_spec(
@@ -278,11 +298,38 @@ def _convert_single_arg(
return spec.Stream(arg_name)
elif isinstance(arg, cuda.CUstream):
return spec.Stream(arg_name)
elif arg_type is not None and hasattr(arg_type, "_fields"):
# Handle NamedTuple - normalize to Tuple by order of fields, ignoring defaults
# Get field types from annotations
type_hints = get_type_hints(arg_type)
tuple_element_types = [type_hints[field] for field in arg_type._fields]
# NamedTuples inherit from tuple, so we can check with isinstance(arg, tuple)
if not isinstance(arg, tuple):
raise DSLRuntimeError(
f"Expected namedtuple for argument {arg_name}, got {type(arg)}"
)
if len(arg) != len(tuple_element_types):
raise DSLRuntimeError(
f"NamedTuple length mismatch for argument {arg_name}: "
f"expected {len(tuple_element_types)}, got {len(arg)}"
)
# Recursively convert each tuple element
tuple_params = []
for i, (elem, elem_type) in enumerate(zip(arg, tuple_element_types)):
elem_name = f"{arg_name}[{i}]"
elem_param = _convert_single_arg(elem, elem_name, elem_type, ctx)
tuple_params.append(elem_param)
return spec.TupleParam(arg_name, tuple_params)
elif arg_type is not None and get_origin(arg_type) is tuple:
# Handle Tuple[X, Y, ...] type annotations
tuple_element_types = get_args(arg_type)
if not isinstance(arg, (tuple, list)):
raise DSLRuntimeError(f"Expected tuple for argument {arg_name}, got {type(arg)}")
raise DSLRuntimeError(
f"Expected tuple for argument {arg_name}, got {type(arg)}"
)
if len(arg) != len(tuple_element_types):
raise DSLRuntimeError(
f"Tuple length mismatch for argument {arg_name}: "
@@ -297,8 +344,24 @@ def _convert_single_arg(
tuple_params.append(elem_param)
return spec.TupleParam(arg_name, tuple_params)
elif isinstance(arg, (tuple, list)):
# Handle plain tuple type annotation without explicit element types
# Recursively convert each tuple element with None as elem_type (un-annotated)
tuple_params = []
for i, elem in enumerate(arg):
elem_name = f"{arg_name}[{i}]"
elem_param = _convert_single_arg(elem, elem_name, None, ctx)
tuple_params.append(elem_param)
return spec.TupleParam(arg_name, tuple_params)
elif isinstance(arg, int):
# in cute.compile, unannotated const int is converted to int32
return spec.Var(arg_name, NumericToTVMFFIDtype[Int32])
elif isinstance(arg, float):
return spec.Var(arg_name, NumericToTVMFFIDtype[Float32])
else:
raise DSLRuntimeError(f"Unsupported argument type: {type(arg)}")
raise DSLRuntimeError(
f"Unsupported argument type: {type(arg)} for annotated type: {get_origin(arg_type)}"
)
def _tvm_ffi_args_spec_converter(
@@ -312,17 +375,24 @@ def _tvm_ffi_args_spec_converter(
This function converts the cute arguments specs to tvm ffi spec params.
"""
exec_args = ExecutionArgs(args_spec, function_name)
rectified_args = exec_args.get_rectified_args_from_original_args(full_args, full_kwargs)
rectified_args = exec_args.get_rectified_args_from_original_args(
full_args, full_kwargs
)
arg_names = exec_args.args_spec.args + exec_args.args_spec.kwonlyargs
params = []
ctx = ConverterContext()
wrapper_extra_exclude_arg_names = []
for arg, arg_name in zip(rectified_args, arg_names):
arg_type = args_spec.annotations.get(arg_name, None)
param = _convert_single_arg(arg, arg_name, arg_type, ctx)
params.append(param)
return params
if isinstance(param, spec.EnvStream):
wrapper_extra_exclude_arg_names.append(arg_name)
kwargs_wrapper_spec = exec_args.get_kwargs_wrapper_spec(
wrapper_extra_exclude_arg_names
)
return params, kwargs_wrapper_spec
def attach_args_spec_converter():

View File

@@ -28,8 +28,9 @@ from ..base_dsl.jit_executor import (
JitFunctionArtifacts,
)
from ..base_dsl.utils.logger import log
from ..base_dsl.common import DSLCudaRuntimeError, DSLRuntimeError
from ..base_dsl.common import DSLRuntimeError
from ..base_dsl.typing import Int32
from ..base_dsl.runtime.cuda import checkCudaErrors
class CudaDialectJitModule:
"""Holds the execution engine and cuda libraries."""
@@ -113,10 +114,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction):
@functools.cached_property
def num_devices(self):
"""Returns the number of CUDA devices available."""
dev_err, devs = cuda_runtime.cudaGetDeviceCount()
if dev_err != cuda_runtime.cudaError_t.cudaSuccess:
raise DSLCudaRuntimeError(dev_err, cuda_runtime.cudaGetErrorName(dev_err))
return devs
return checkCudaErrors(cuda_runtime.cudaGetDeviceCount())
def _deserializer(self):
"""Load the cuda library from the binary execution engine.
@@ -148,12 +146,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction):
packed_args[i] = ctypes.cast(cuda_init_args[i], ctypes.c_void_p)
cuda_init(packed_args)
if err.value != 0:
error_code = err.value
error_name = cuda_runtime.cudaGetErrorName(
cuda_runtime.cudaError_t(error_code)
)
raise DSLCudaRuntimeError(error_code, error_name)
checkCudaErrors((cuda_runtime.cudaError_t(err.value),))
cuda_load_args = [pointer_to_library, pointer_to_err]
packed_args = (ctypes.c_void_p * len(cuda_load_args))()
@@ -161,12 +154,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction):
packed_args[i] = ctypes.cast(cuda_load_args[i], ctypes.c_void_p)
cuda_load(packed_args)
if err.value != 0:
error_code = err.value
error_name = cuda_runtime.cudaGetErrorName(
cuda_runtime.cudaError_t(error_code)
)
raise DSLCudaRuntimeError(error_code, error_name)
checkCudaErrors((cuda_runtime.cudaError_t(err.value),))
return [cuda_runtime.cudaLibrary_t(library.value)]
@@ -229,12 +217,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction):
packed_args[i] = ctypes.cast(cuda_init_args[i], ctypes.c_void_p)
cuda_init(packed_args)
if err.value != 0:
error_code = err.value
error_name = cuda_runtime.cudaGetErrorName(
cuda_runtime.cudaError_t(error_code)
)
raise DSLCudaRuntimeError(error_code, error_name)
checkCudaErrors((cuda_runtime.cudaError_t(err.value),))
device_id = ctypes.c_int32(0)
pointer_to_device_id = ctypes.pointer(device_id)
@@ -247,18 +230,9 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction):
for dev in range(self.num_devices):
device_id.value = dev
cuda_load_to_device(packed_args)
if err.value != 0:
raise DSLCudaRuntimeError(
err.value,
cuda_runtime.cudaGetErrorName(cuda_runtime.cudaError_t(err.value)),
)
checkCudaErrors((cuda_runtime.cudaError_t(err.value),))
if err.value != 0:
error_code = err.value
error_name = cuda_runtime.cudaGetErrorName(
cuda_runtime.cudaError_t(error_code)
)
raise DSLCudaRuntimeError(error_code, error_name)
checkCudaErrors((cuda_runtime.cudaError_t(err.value),))
return [cuda_runtime.cudaLibrary_t(library.value)]

View File

@@ -43,9 +43,9 @@ from ..base_dsl.dsl import is_dynamic_expression, extract_mlir_values
from ..base_dsl.typing import *
from ..base_dsl.typing import DynamicExpression, get_mlir_types
from ..base_dsl.runtime.jit_arg_adapters import is_arg_spec_constexpr
from ..base_dsl.jit_executor import ExecutionArgs
from ..base_dsl.runtime import cuda as cuda_helpers
from .cuda_stream_adapter import CudaDialectStreamAdapter
from .cuda_jit_executor import CudaDialectJitCompiledFunction
# MLIR Imports
@@ -421,12 +421,13 @@ class CutlassBaseDSL(BaseDSL):
# attach extra ABI function to the MLIR module
from .tvm_ffi_provider import (
TVMFFIJitCompiledFunction,
TVMFFIJitCompiledFunctionWithKwargs,
TVMFFICuteCallProvider,
)
from cutlass.base_dsl.tvm_ffi_builder import attach_ffi_func
assert self._tvm_ffi_args_spec_converter is not None
tvm_ffi_spec_params = self._tvm_ffi_args_spec_converter(
tvm_ffi_spec_params, kwargs_wrapper_spec = self._tvm_ffi_args_spec_converter(
function_name, args_spec, full_args, full_kwargs
)
tvm_ffi_provider = TVMFFICuteCallProvider(function_name)
@@ -444,6 +445,15 @@ class CutlassBaseDSL(BaseDSL):
)
module.operation.verify()
def _make_compiled_func(*args, **kwargs):
if kwargs_wrapper_spec.kwonly_names or kwargs_wrapper_spec.arg_defaults:
return TVMFFIJitCompiledFunctionWithKwargs(
*args, **kwargs,
kwargs_wrapper_spec=kwargs_wrapper_spec
)
else:
return TVMFFIJitCompiledFunction(*args, **kwargs)
# ensure the compiler can run post-compile hook after its passes
# the context will restore the previous post-compile hook after it exits
with compiler.PostCompileHookContext(
@@ -456,7 +466,7 @@ class CutlassBaseDSL(BaseDSL):
pipeline,
args_spec,
no_cache,
TVMFFIJitCompiledFunction,
_make_compiled_func,
full_args=full_args,
full_kwargs=full_kwargs,
dynamic_args=dynamic_args,

View File

@@ -15,12 +15,14 @@ from cutlass.base_dsl.tvm_ffi_builder import (
rename_tvm_ffi_function,
spec,
)
from cutlass.base_dsl.export import get_export_module
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
from cutlass._mlir._mlir_libs._cutlass_ir import _execution_engine_extra
from cutlass._mlir._mlir_libs._cutlass_ir import _aot_support
from cutlass.cutlass_dsl.cuda_jit_executor import CudaDialectJitCompiledFunction
from cutlass.base_dsl.common import DSLRuntimeError
from typing import Optional
from cutlass.base_dsl.jit_executor import ExecutionArgs
from typing import Optional, Callable
import tvm_ffi
@@ -400,41 +402,52 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
return current_block
class TVMFFIJitCompiledFunction(tvm_ffi.Function, CudaDialectJitCompiledFunction):
"""TVM FFI Function that contains metadata of the compiled function and interface to the FFI layer.
def _inplace_hide_symbols(ir_module: ir.Module, hide_check: Callable[[str], bool]):
    """Hide selected function definitions in *ir_module* by forcing internal linkage.

    Two walks over the module: the first records the names of all ``llvm.func``
    operations that have a body (definitions, not external declarations); the
    second sets ``linkage = internal`` on every recorded definition for which
    ``hide_check`` returns True. The module is modified in place.

    @param ir_module: The ir module to hide the symbols.
    @param hide_check: The callback to check if the symbol should be hidden.
    @return: None — the module is mutated in place.
    """
    defined_symbols = set()

    def walk_llvm_func_op(op):
        # A definition (not a declaration) carries at least one region
        # containing at least one block.
        if (
            op.name == "llvm.func"
            and len(op.opview.operation.regions) > 0
            and len(op.opview.operation.regions[0].blocks) > 0
        ):
            func_name = op.attributes["sym_name"].value
            defined_symbols.add(func_name)
        return ir.WalkResult.ADVANCE

    def walk_and_hide_symbols(op):
        # Handle llvm.func operations
        if op.name == "llvm.func":
            func_name = op.attributes["sym_name"].value
            # NOTE(review): this overwrites any "linkage" attribute already
            # present on a matching definition — confirm that is intended.
            if func_name in defined_symbols and hide_check(func_name):
                # Set to internal linkage to hide the symbol
                op.attributes["linkage"] = ir.Attribute.parse("#llvm.linkage<internal>")
        return ir.WalkResult.ADVANCE

    with ir_module.context:
        # Pass 1: collect defined symbol names; pass 2: rewrite their linkage.
        ir_module.operation.walk(walk_llvm_func_op)
        ir_module.operation.walk(walk_and_hide_symbols)
def _get_format_from_object_file_path(object_file_path: str) -> str:
format = object_file_path.split(".")[-1]
if format not in ("o", "ll", "bc"):
return "o"
return format
class TVMFFIJitCompiledFunctionBase(CudaDialectJitCompiledFunction):
"""Base class for TVM FFI compiled function."""
    def __init__(self, *args, **kwargs):
        """Construct the compiled function, then bind its FFI entry point."""
        super().__init__(*args, **kwargs)
        # initialize the tvm_ffi.Function from the current execution engine
        self._init_ffi_function()
# use direct call to the tvm_ffi.Function.__call__
# to avoid most of python overhead
__call__ = tvm_ffi.Function.__call__
    def _init_ffi_function(self):
        """Initialize the tvm_ffi.Function handle from the current execution engine.

        This function must be called exactly once, at compilation time. It is
        not done in __init__ because the original flow may already have
        created an execution engine, and the function is not guaranteed to be
        resolvable at that time.

        :raises DSLRuntimeError: if the underlying FFI handle is already set.
        """
        # NOTE(review): __chandle__/__move_handle_from__ come from
        # tvm_ffi.Function — confirm the concrete subclass inherits it.
        if self.__chandle__() != 0:
            raise DSLRuntimeError("TVM FFI function is already initialized")
        # get the MLIR function pointer from the execution engine
        if self.engine is not None:
            # The packed safe-call entry point is emitted with a "__tvm_ffi_" prefix.
            tvm_ffi_function_ptr = self.engine.raw_lookup(
                "__tvm_ffi_" + self.function_name
            )
            tvm_ffi_function = tvm_ffi.Function.__from_mlir_packed_safe_call__(
                tvm_ffi_function_ptr
            )
            # move the handle from the tvm_ffi.Function to the current instance
            self.__move_handle_from__(tvm_ffi_function)
def to(self, device=None):
"""TVM FFI function itself is already support all devices."""
@@ -444,18 +457,111 @@ class TVMFFIJitCompiledFunction(tvm_ffi.Function, CudaDialectJitCompiledFunction
"""Run the compiled program. This override is needed for implicit compile and execution."""
return self.__call__(*exe_args)
def export_to_c(self, object_file_path: str, function_name: str = None):
def export_to_c(
self, object_file_path: str, function_name: str = None,
*,
enable_pic: bool = True,
export_only_tvm_ffi_symbols: bool = False
):
"""Export the TVM FFI function to an object file.
:param object_file_path: The path to the object file.
:param function_name: The name of the function to export.
:param enable_pic: Whether to enable PIC relocation needed for shared library loading.
:param export_only_tvm_ffi_symbols: Only export TVM FFI symbols (hide all others).
:param host_target_triple: If not provided, the current host target is used.
"""
if function_name is not None and function_name != self.function_name:
mod = self.ir_module
rename_tvm_ffi_function(mod, self.function_name, function_name)
else:
mod = self.ir_module
_execution_engine_extra.dump_object_file_pic(
mod, object_file_path, "__tvm_ffi_" + function_name, 2
# prefix internal function by function name
internal_symbol_prefix = "__cute_internal_" + function_name
mod = self.ir_module
mod = get_export_module(
self.ir_module, internal_symbol_prefix,
preserve_symbols=[f"__tvm_ffi_{self.function_name}"]
)
rename_tvm_ffi_function(mod, self.function_name, function_name)
if export_only_tvm_ffi_symbols:
_inplace_hide_symbols(mod, lambda x: not x.startswith("__tvm_ffi"))
format = _get_format_from_object_file_path(object_file_path)
out_bytes = _aot_support.export_module_to_bytes(
mod, format=format, opt_level=3, enable_pic=enable_pic
)
with open(object_file_path, "wb") as f:
f.write(out_bytes)
def _create_tvm_ffi_function(self):
"""Create the tvm_ffi.Function from the current execution engine.
"""
if self.engine is not None:
tvm_ffi_function_ptr = self.engine.raw_lookup(
"__tvm_ffi_" + self.function_name
)
tvm_ffi_function = tvm_ffi.Function.__from_mlir_packed_safe_call__(
tvm_ffi_function_ptr, keep_alive_object=self.engine)
return tvm_ffi_function
return None
class TVMFFIJitCompiledFunction(tvm_ffi.Function, TVMFFIJitCompiledFunctionBase):
    """TVM FFI function that directly subclasses ``tvm_ffi.Function``.

    Used for positional-only argument signatures: the call path is the raw
    ``tvm_ffi.Function.__call__`` with no keyword-argument wrapping.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # initialize the tvm_ffi.Function from the current execution engine
        if self.__chandle__() != 0:
            raise DSLRuntimeError("TVM FFI function is already initialized")
        tvm_ffi_function = self._create_tvm_ffi_function()
        # _create_tvm_ffi_function returns None when no engine exists yet.
        if tvm_ffi_function is not None:
            # move the handle from the tvm_ffi.Function to the current instance
            self.__move_handle_from__(tvm_ffi_function)

    # use direct call to the tvm_ffi.Function.__call__
    # to avoid most of python overhead
    __call__ = tvm_ffi.Function.__call__
class TVMFFIJitCompiledFunctionWithKwargs(TVMFFIJitCompiledFunctionBase):
    """TVM FFI function with keyword-argument / default-value support.

    Requires a ``kwargs_wrapper_spec`` keyword argument at construction time
    describing argument names, defaults, and keyword-only parameters; calls
    are routed through a wrapper built by ``tvm_ffi.utils.kwargs_wrapper``.
    """

    def __init__(self, *args, **kwargs):
        # NOTE(review): assert is stripped under `python -O`; consider raising
        # DSLRuntimeError for this input validation instead.
        assert "kwargs_wrapper_spec" in kwargs, "kwargs_wrapper_spec is required"
        kwargs_wrapper_spec = kwargs.pop("kwargs_wrapper_spec")
        super().__init__(*args, **kwargs)
        # initialize the tvm_ffi.Function from the current execution engine
        self._tvm_ffi_function = self._create_tvm_ffi_function()
        if kwargs_wrapper_spec.kwonly_names or kwargs_wrapper_spec.arg_defaults:
            # Only build the wrapper when the signature actually needs it.
            try:
                from tvm_ffi.utils import kwargs_wrapper  # type: ignore

                self._kwargs_wrapper = kwargs_wrapper.make_kwargs_wrapper(
                    self._tvm_ffi_function,
                    arg_names=kwargs_wrapper_spec.arg_names,
                    arg_defaults=kwargs_wrapper_spec.arg_defaults,
                    kwonly_names=kwargs_wrapper_spec.kwonly_names,
                    kwonly_defaults=kwargs_wrapper_spec.kwonly_defaults,
                )
            except ImportError:
                # kwargs_wrapper is only available in newer tvm_ffi releases.
                raise DSLRuntimeError("install apache-tvm-ffi>=0.1.5 to enable kwargs/defaults")
        else:
            # positional only is probably fine
            self._kwargs_wrapper = self._tvm_ffi_function

    def __call__(self, *args, **kwargs):
        """Call the TVM FFI function through the kwargs wrapper."""
        return self._kwargs_wrapper(*args, **kwargs)

    def __tvm_ffi_object__(self):
        """Expose the underlying tvm_ffi.Function for FFI interop."""
        return self._tvm_ffi_function
def supports_kwargs_wrapper() -> bool:
    """Report whether the installed tvm_ffi provides the kwargs wrapper utility."""
    try:
        from tvm_ffi.utils import kwargs_wrapper  # type: ignore  # noqa: F401
    except ImportError:
        return False
    return True

View File

@@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with this file to install a
# wheel consistent with the current state of the GitHub repository
nvidia-cutlass-dsl==4.3.0
nvidia-cutlass-dsl==4.3.3