mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
v4.3.3 update. (#2868)
This commit is contained in:
@@ -14,6 +14,8 @@ This module provides jit cache load/dump helper functions
|
||||
"""
|
||||
|
||||
import os
|
||||
import io
|
||||
import sys
|
||||
import uuid
|
||||
import random
|
||||
import tempfile
|
||||
@@ -22,7 +24,7 @@ import time
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
from functools import lru_cache
|
||||
import tempfile
|
||||
import zlib
|
||||
|
||||
from .utils.logger import log
|
||||
from .jit_executor import JitCompiledFunction
|
||||
@@ -74,13 +76,58 @@ def get_default_file_dump_root():
|
||||
return dump_root
|
||||
|
||||
|
||||
def load_ir(file, asBytecode=False):
|
||||
def write_bytecode_with_crc32(f, module):
|
||||
"""Write the bytecode to the file and calculate the crc32 checksum.
|
||||
|
||||
:param f: The file to write the bytecode to.
|
||||
:type f: file
|
||||
:param module: The IR module to write the bytecode to.
|
||||
:type module: object
|
||||
"""
|
||||
s = io.BytesIO()
|
||||
module.operation.write_bytecode(s)
|
||||
content = s.getvalue()
|
||||
crc = zlib.crc32(content)
|
||||
s.write(crc.to_bytes(4, sys.byteorder))
|
||||
f.write(s.getvalue())
|
||||
return
|
||||
|
||||
|
||||
def read_bytecode_and_check_crc32(f):
|
||||
"""
|
||||
Read the bytecode from the file and check the crc32 checksum.
|
||||
|
||||
:param f: The file to read the bytecode with appended CRC32 from.
|
||||
:type f: file
|
||||
:return: The bytecode content if checksum matches.
|
||||
:rtype: bytes
|
||||
:raises DSLRuntimeError: If checksum does not match.
|
||||
"""
|
||||
content = f.read()
|
||||
if len(content) < 4:
|
||||
raise DSLRuntimeError(
|
||||
f"File {f.name} does not contain enough data for CRC32 checksum."
|
||||
)
|
||||
bytecode = content[:-4]
|
||||
crc_appended = content[-4:]
|
||||
crc_appended_int = int.from_bytes(crc_appended, sys.byteorder)
|
||||
crc_computed = zlib.crc32(bytecode)
|
||||
if crc_appended_int != crc_computed:
|
||||
raise DSLRuntimeError(
|
||||
f"CRC32 checksum mismatch! Expected {crc_computed}, got {crc_appended_int}"
|
||||
)
|
||||
return ir.Module.parse(bytecode)
|
||||
|
||||
|
||||
def load_ir(file, asBytecode=False, bytecode_reader=None):
|
||||
"""Load generated IR from a file.
|
||||
|
||||
:param file: The path to the file to load.
|
||||
:type file: str
|
||||
:param asBytecode: Whether to load the IR as bytecode, defaults to False
|
||||
:type asBytecode: bool, optional
|
||||
:param bytecode_reader: The bytecode reader to use, defaults to None
|
||||
:type bytecode_reader: callable, optional
|
||||
:return: The function name and the IR module
|
||||
:rtype: tuple[str, ir.Module]
|
||||
"""
|
||||
@@ -88,8 +135,10 @@ def load_ir(file, asBytecode=False):
|
||||
func_name = file.split(".mlir")[0].split("dsl_")[-1]
|
||||
with ir.Context() as ctx:
|
||||
with open(file, "rb" if asBytecode else "r") as f:
|
||||
module = ir.Module.parse(f.read())
|
||||
|
||||
if bytecode_reader:
|
||||
module = bytecode_reader(f)
|
||||
else:
|
||||
module = ir.Module.parse(f.read())
|
||||
return func_name, module
|
||||
|
||||
|
||||
@@ -128,6 +177,14 @@ def save_ir(
|
||||
:type module: object
|
||||
:param fname: The name of the file to save.
|
||||
:type fname: str
|
||||
:param output_dir: The path to the output directory, defaults to None
|
||||
:type output_dir: str, optional
|
||||
:param as_bytecode: Whether to save the IR as bytecode, defaults to False
|
||||
:type as_bytecode: bool, optional
|
||||
:param bytecode_writer: The bytecode writer to use, defaults to None
|
||||
:type bytecode_writer: callable, optional
|
||||
:return: The path to the saved file
|
||||
:rtype: str
|
||||
"""
|
||||
initial_name = f"{dsl_name.lower()}_{fname}.mlir"
|
||||
save_path = Path(output_dir if output_dir else tempfile.gettempdir())
|
||||
@@ -158,63 +215,45 @@ def save_ir(
|
||||
return save_fname
|
||||
|
||||
|
||||
def check_func_name(jit_cache, func_name):
|
||||
"""Check if the function name is in the cache.
|
||||
If not, create a new JitCompiledFunction object and add it to the cache.
|
||||
|
||||
:param jit_cache: The cache to check.
|
||||
:type jit_cache: dict
|
||||
:param func_name: The name of the function to check.
|
||||
:type func_name: str
|
||||
:return: The cache
|
||||
:rtype: dict
|
||||
"""
|
||||
if not func_name in jit_cache:
|
||||
jit_cache[func_name] = JitCompiledFunction(
|
||||
None, None, None, None, None, [], False, None
|
||||
)
|
||||
return jit_cache
|
||||
|
||||
|
||||
def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path):
|
||||
def load_cache_from_path(
|
||||
dsl_name, file, path=default_generated_ir_path, bytecode_reader=None
|
||||
):
|
||||
"""Load cache from a directory path.
|
||||
|
||||
:param dsl_name: The name of the DSL.
|
||||
:type dsl_name: str
|
||||
:param cache_limit: The limit of the cache.
|
||||
:type cache_limit: int
|
||||
:param file: The name of the file to load.
|
||||
:type file: str
|
||||
:param path: The path to the cache directory, defaults to default_generated_ir_path
|
||||
:type path: str, optional
|
||||
:param bytecode_reader: The bytecode reader to use, defaults to None
|
||||
:type bytecode_reader: callable, optional
|
||||
:return: The cache
|
||||
:rtype: dict
|
||||
"""
|
||||
if not os.path.exists(path):
|
||||
return dict()
|
||||
files = os.listdir(path)
|
||||
jit_cache = dict()
|
||||
return None
|
||||
ret = None
|
||||
try:
|
||||
for idx, file in enumerate(files):
|
||||
if idx >= int(cache_limit):
|
||||
break
|
||||
# identify dsl prefix
|
||||
if not file.startswith(f"{dsl_name.lower()}"):
|
||||
continue
|
||||
if ".mlir" in file:
|
||||
func_name, ir_module = load_ir(
|
||||
os.path.join(path, file), asBytecode=True
|
||||
)
|
||||
jit_cache = check_func_name(jit_cache, func_name)
|
||||
jit_cache[func_name].ir_module = ir_module
|
||||
file = f"{dsl_name.lower()}_{file}.mlir"
|
||||
if os.path.exists(os.path.join(path, file)):
|
||||
_, module = load_ir(
|
||||
os.path.join(path, file),
|
||||
asBytecode=True,
|
||||
bytecode_reader=bytecode_reader,
|
||||
)
|
||||
ret = JitCompiledFunction(module, None, None, None, None, [], False, None)
|
||||
except Exception as e:
|
||||
print(f"{dsl_name} failed with loading generated IR cache.", e)
|
||||
jit_cache = dict()
|
||||
return jit_cache
|
||||
log().warning(
|
||||
f"{dsl_name} failed with loading generated IR cache for {file}.", e
|
||||
)
|
||||
return ret
|
||||
|
||||
|
||||
def dump_cache_to_path(
|
||||
dsl_name,
|
||||
jit_cache,
|
||||
cache_limit,
|
||||
jit_function,
|
||||
file,
|
||||
path=default_generated_ir_path,
|
||||
bytecode_writer=None,
|
||||
):
|
||||
@@ -222,30 +261,29 @@ def dump_cache_to_path(
|
||||
|
||||
:param dsl_name: The name of the DSL.
|
||||
:type dsl_name: str
|
||||
:param jit_cache: The cache to dump.
|
||||
:type jit_cache: dict
|
||||
:param cache_limit: The limit of the cache.
|
||||
:type cache_limit: int
|
||||
:param jit_function: The JitCompiledFunction to dump.
|
||||
:type jit_function: JitCompiledFunction
|
||||
:param file: The name of the file to dump.
|
||||
:type file: str
|
||||
:param path: The path to the cache directory, defaults to default_generated_ir_path
|
||||
:type path: str, optional
|
||||
:param bytecode_writer: The bytecode writer to use, defaults to None
|
||||
:type bytecode_writer: callable, optional
|
||||
"""
|
||||
log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
|
||||
log().info("JIT cache : dumping [%s] file=[%s]", dsl_name, file)
|
||||
if not path:
|
||||
path = default_generated_ir_path
|
||||
os.makedirs(path, exist_ok=True)
|
||||
try:
|
||||
for idx, [key, value] in enumerate(jit_cache.items()):
|
||||
if idx >= int(cache_limit):
|
||||
break
|
||||
save_ir(
|
||||
dsl_name,
|
||||
value.ir_module,
|
||||
key,
|
||||
output_dir=path,
|
||||
as_bytecode=True,
|
||||
bytecode_writer=bytecode_writer,
|
||||
)
|
||||
save_ir(
|
||||
dsl_name,
|
||||
jit_function.ir_module,
|
||||
file,
|
||||
output_dir=path,
|
||||
as_bytecode=True,
|
||||
bytecode_writer=bytecode_writer,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"{dsl_name} failed with caching generated IR", e)
|
||||
log().warning(
|
||||
f"{dsl_name} failed with dumping generated IR cache for {file}: {e}"
|
||||
)
|
||||
|
||||
@@ -195,6 +195,10 @@ def _get_friendly_cuda_error_message(error_code, error_name):
|
||||
f"2. SM ARCH setting",
|
||||
f"3. Steps to reproduce",
|
||||
),
|
||||
"cudaErrorInsufficientDriver": (
|
||||
f"1. Run nvidia-smi to confirm CUDA driver version",
|
||||
f"2. Ensure the CUDA driver version meets the requirement of the installed cuda-python package",
|
||||
),
|
||||
}
|
||||
|
||||
message = f"{error_name} (error code: {error_code}) \n" \
|
||||
|
||||
@@ -318,11 +318,8 @@ class BaseDSL:
|
||||
self.envar = self._env_class(self.name)
|
||||
self.enable_preprocessor = preprocess
|
||||
# This cache uses hash of original ir and env as key, allows dump/load to/from file. Enabled by default
|
||||
self.jit_cache = (
|
||||
dict()
|
||||
if self.envar.disable_file_caching
|
||||
else load_cache_from_path(self.name, self.envar.file_caching_capacity)
|
||||
)
|
||||
self.jit_cache = dict()
|
||||
|
||||
self.host_jit_decorator_name = f"@{BaseDSL.jit.__name__}"
|
||||
self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}"
|
||||
|
||||
@@ -372,12 +369,6 @@ class BaseDSL:
|
||||
|
||||
atexit.register(restore_excepthook, origin_excepthook)
|
||||
|
||||
def dump_cache(self, path=None):
|
||||
if not self.envar.disable_file_caching:
|
||||
dump_cache_to_path(
|
||||
self.name, self.jit_cache, self.envar.file_caching_capacity, path=path
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def print_warning_once(self, message):
|
||||
log().warning(f"Warning: {message}")
|
||||
@@ -392,9 +383,6 @@ class BaseDSL:
|
||||
def _get_dsl(cls):
|
||||
# Instantiate the DSL Class once
|
||||
main_dsl = cls()
|
||||
if not main_dsl.no_cache:
|
||||
# register atexit callback
|
||||
atexit.register(main_dsl.dump_cache)
|
||||
return main_dsl
|
||||
|
||||
@staticmethod
|
||||
@@ -1235,6 +1223,16 @@ class BaseDSL:
|
||||
log().debug(f"Using pipeline = {pipeline}")
|
||||
shared_libs = self.get_shared_libs()
|
||||
profiler = timer(enable=self.envar.jit_time_profiling)
|
||||
# try load the file cache
|
||||
load_from_file_cache = False
|
||||
if not no_cache:
|
||||
fn = load_cache_from_path(
|
||||
self.name, module_hash, bytecode_reader=read_bytecode_and_check_crc32
|
||||
)
|
||||
if fn is not None:
|
||||
load_from_file_cache = True
|
||||
self.jit_cache[module_hash] = fn
|
||||
|
||||
if (
|
||||
no_cache
|
||||
or module_hash not in self.jit_cache
|
||||
@@ -1288,6 +1286,16 @@ class BaseDSL:
|
||||
if not no_cache:
|
||||
# module stored in cache is compiled.
|
||||
self.jit_cache[module_hash] = fn
|
||||
# write through the file cache if enabled.
|
||||
if not self.envar.disable_file_caching and not load_from_file_cache:
|
||||
dump_cache_to_path(
|
||||
self.name,
|
||||
fn,
|
||||
module_hash,
|
||||
bytecode_writer=lambda f: write_bytecode_with_crc32(
|
||||
f, fn.ir_module
|
||||
),
|
||||
)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
@@ -311,7 +311,6 @@ class EnvironmentVarManager(LogEnvironmentManager):
|
||||
- [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False)
|
||||
- [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False)
|
||||
- [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False)
|
||||
- [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000)
|
||||
- [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None)
|
||||
- [DSL_NAME]_ENABLE_TVM_FFI: Enable TVM-FFI or not (default: False)
|
||||
"""
|
||||
@@ -350,9 +349,6 @@ class EnvironmentVarManager(LogEnvironmentManager):
|
||||
self.disable_file_caching = get_bool_env_var(
|
||||
f"{prefix}_DISABLE_FILE_CACHING", False
|
||||
)
|
||||
self.file_caching_capacity = get_int_env_var(
|
||||
f"{prefix}_FILE_CACHING_CAPACITY", 1000
|
||||
)
|
||||
# set cuda
|
||||
self.cuda_toolkit = get_cuda_toolkit_path()
|
||||
|
||||
|
||||
@@ -27,17 +27,21 @@ from typing import Union
|
||||
|
||||
cubin_suffix = "cubin"
|
||||
|
||||
def get_export_module(ir_module: ir.Module, symbol_prefix: str):
|
||||
def get_export_module(ir_module: ir.Module, symbol_prefix: str, *, preserve_symbols = None):
|
||||
"""Get the export module which is cloned from the original compiled ir module, and add the prefix
|
||||
to avoid the symbol conflict.
|
||||
|
||||
@param ir_module: The original compiled ir module. Comes from the JitCompiledFunction.ir_module.
|
||||
@param symbol_prefix: The prefix name of the function. This is the unique identifier name of the function to avoid symbol conflict in the generated object file.
|
||||
@param preserve_symbols: Optional symbols to preserve in the export module.
|
||||
@return: The export module of the function.
|
||||
"""
|
||||
# Add prefix for symbol names to avoid conflict with other functions
|
||||
defined_symbols = set()
|
||||
|
||||
if preserve_symbols is None:
|
||||
preserve_symbols = set()
|
||||
|
||||
def walk_llvm_func_op(op):
|
||||
# not a declaration
|
||||
if (
|
||||
@@ -45,23 +49,61 @@ def get_export_module(ir_module: ir.Module, symbol_prefix: str):
|
||||
and len(op.opview.operation.regions) > 0
|
||||
and len(op.opview.operation.regions[0].blocks) > 0
|
||||
):
|
||||
defined_symbols.add(op.attributes["sym_name"].value)
|
||||
func_name = op.attributes["sym_name"].value
|
||||
# skip preserving symbols
|
||||
if func_name in preserve_symbols:
|
||||
return ir.WalkResult.ADVANCE
|
||||
defined_symbols.add(func_name)
|
||||
op.attributes["sym_name"] = ir.StringAttr.get(
|
||||
symbol_prefix + "_" + op.attributes["sym_name"].value
|
||||
symbol_prefix + "_" + func_name
|
||||
)
|
||||
return ir.WalkResult.ADVANCE
|
||||
|
||||
def walk_llvm_call_op(op):
|
||||
def walk_llvm_references(op):
|
||||
# Rename function calls
|
||||
if op.name == "llvm.call" and op.attributes["callee"].value in defined_symbols:
|
||||
op.attributes["callee"] = ir.FlatSymbolRefAttr.get(
|
||||
symbol_prefix + "_" + op.attributes["callee"].value
|
||||
)
|
||||
# Rename addressof references
|
||||
elif op.name == "llvm.mlir.addressof" and op.attributes["global_name"].value in defined_symbols:
|
||||
op.attributes["global_name"] = ir.FlatSymbolRefAttr.get(
|
||||
symbol_prefix + "_" + op.attributes["global_name"].value
|
||||
)
|
||||
# Rename global_ctors references
|
||||
elif op.name == "llvm.mlir.global_ctors" and "ctors" in op.attributes:
|
||||
ctors = list(op.attributes["ctors"])
|
||||
renamed_ctors = []
|
||||
for ctor in ctors:
|
||||
if ctor.value in defined_symbols:
|
||||
renamed_ctors.append(ir.FlatSymbolRefAttr.get(
|
||||
symbol_prefix + "_" + ctor.value
|
||||
))
|
||||
else:
|
||||
renamed_ctors.append(ctor)
|
||||
if renamed_ctors:
|
||||
op.attributes["ctors"] = ir.ArrayAttr.get(renamed_ctors)
|
||||
# Rename global_dtors references
|
||||
elif op.name == "llvm.mlir.global_dtors" and "dtors" in op.attributes:
|
||||
dtors = list(op.attributes["dtors"])
|
||||
renamed_dtors = []
|
||||
for dtor in dtors:
|
||||
if dtor.value in defined_symbols:
|
||||
renamed_dtors.append(ir.FlatSymbolRefAttr.get(
|
||||
symbol_prefix + "_" + dtor.value
|
||||
))
|
||||
else:
|
||||
renamed_dtors.append(dtor)
|
||||
if renamed_dtors:
|
||||
op.attributes["dtors"] = ir.ArrayAttr.get(renamed_dtors)
|
||||
return ir.WalkResult.ADVANCE
|
||||
|
||||
with ir.Context():
|
||||
export_module = ir.Module.parse(str(ir_module))
|
||||
# First pass: collect and rename function definitions
|
||||
export_module.operation.walk(walk_llvm_func_op)
|
||||
export_module.operation.walk(walk_llvm_call_op)
|
||||
# Second pass: rename call and addressof references
|
||||
export_module.operation.walk(walk_llvm_references)
|
||||
return export_module
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ This module provides jit executor related classes
|
||||
import ctypes
|
||||
import inspect
|
||||
import io
|
||||
from typing import Union, Optional
|
||||
from typing import Union, Optional, NamedTuple, Any, Sequence
|
||||
import weakref
|
||||
import threading
|
||||
import collections
|
||||
@@ -132,6 +132,15 @@ def load_kernels_from_ir_module(module, kernel_info) -> list[CudaModuleAndKernel
|
||||
return list(kernel_modules.values())
|
||||
|
||||
|
||||
|
||||
class KwargsWrapperSpec(NamedTuple):
|
||||
"""A specification for keyword arguments wrapper."""
|
||||
arg_names: list[str]
|
||||
arg_defaults: tuple[Any, ...]
|
||||
kwonly_names: list[str]
|
||||
kwonly_defaults: dict[str, Any]
|
||||
|
||||
|
||||
class ExecutionArgs:
|
||||
"""Helper that wraps the function signature spec to filter exeuction and compile time arguments."""
|
||||
|
||||
@@ -216,6 +225,59 @@ class ExecutionArgs:
|
||||
|
||||
return exe_args, adapted_args
|
||||
|
||||
def get_kwargs_wrapper_spec(self, exclude_arg_names: Sequence[str] = ()) -> KwargsWrapperSpec:
|
||||
"""
|
||||
This function is used to get the kwargs wrapper spec from the original args_spec.
|
||||
"""
|
||||
excluded_arg_names = set(exclude_arg_names)
|
||||
arg_spec = self.original_args_spec
|
||||
|
||||
if arg_spec.defaults:
|
||||
defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults)
|
||||
else:
|
||||
defaults_start_idx = len(arg_spec.args)
|
||||
|
||||
arg_names = []
|
||||
arg_defaults = []
|
||||
kwonly_names = []
|
||||
kwonly_defaults = {}
|
||||
|
||||
# Filter arguments and maintain their properties
|
||||
for i, arg_name in enumerate(arg_spec.args):
|
||||
arg_type = arg_spec.annotations.get(arg_name, None)
|
||||
|
||||
# Skip compile-time arguments
|
||||
if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
|
||||
continue
|
||||
if arg_name in excluded_arg_names:
|
||||
continue
|
||||
arg_names.append(arg_name)
|
||||
|
||||
if i >= defaults_start_idx:
|
||||
arg_defaults.append(arg_spec.defaults[i - defaults_start_idx])
|
||||
|
||||
if arg_spec.kwonlyargs:
|
||||
for i, kwarg in enumerate(arg_spec.kwonlyargs):
|
||||
arg_type = arg_spec.annotations.get(kwarg, None)
|
||||
|
||||
# Skip compile-time arguments
|
||||
if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name):
|
||||
continue
|
||||
|
||||
if kwarg in excluded_arg_names:
|
||||
continue
|
||||
|
||||
kwonly_names.append(kwarg)
|
||||
if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults:
|
||||
kwonly_defaults[kwarg] = arg_spec.kwonlydefaults[kwarg]
|
||||
|
||||
return KwargsWrapperSpec(
|
||||
arg_names=arg_names,
|
||||
arg_defaults=tuple(arg_defaults),
|
||||
kwonly_names=kwonly_names,
|
||||
kwonly_defaults=kwonly_defaults,
|
||||
)
|
||||
|
||||
def get_rectified_args_from_original_args(self, full_args, full_kwargs):
|
||||
"""
|
||||
This function is used to rectify the original arguments to the runtime
|
||||
@@ -233,6 +295,7 @@ class ExecutionArgs:
|
||||
defaults_start_idx = len(arg_spec.args)
|
||||
|
||||
runtime_args = []
|
||||
|
||||
# Filter arguments and maintain their properties
|
||||
for i, arg_name in enumerate(arg_spec.args):
|
||||
arg_type = arg_spec.annotations.get(arg_name, None)
|
||||
@@ -241,12 +304,24 @@ class ExecutionArgs:
|
||||
if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
|
||||
continue
|
||||
|
||||
# Keep corresponding default if it exists
|
||||
if i >= defaults_start_idx:
|
||||
# Check if argument was provided by user, otherwise use default
|
||||
if i < len(full_args):
|
||||
# User provided this argument - use it
|
||||
runtime_args.append(full_args[i])
|
||||
elif i >= defaults_start_idx:
|
||||
# Argument not provided, but has default - use default
|
||||
default_idx = i - defaults_start_idx
|
||||
runtime_args.append(arg_spec.defaults[default_idx])
|
||||
else:
|
||||
runtime_args.append(full_args[i])
|
||||
# Required argument missing
|
||||
raise DSLRuntimeError(
|
||||
f"Missing required argument '{arg_name}' at position {i}",
|
||||
context={
|
||||
"function_name": self.function_name,
|
||||
"expected_args": len(arg_spec.args),
|
||||
"provided_args": len(full_args),
|
||||
}
|
||||
)
|
||||
|
||||
# Filter keyword-only arguments
|
||||
runtime_kwargs = {}
|
||||
|
||||
@@ -21,6 +21,7 @@ import os
|
||||
import ctypes
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import cuda.bindings.runtime as cudart
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
|
||||
# Local module imports
|
||||
@@ -44,6 +45,8 @@ def _cudaGetErrorEnum(error):
|
||||
if isinstance(error, cuda.CUresult):
|
||||
err, name = cuda.cuGetErrorName(error)
|
||||
return name if err == cuda.CUresult.CUDA_SUCCESS else "<unknown>"
|
||||
elif isinstance(error, cudart.cudaError_t):
|
||||
return cudart.cudaGetErrorName(error)[1]
|
||||
elif isinstance(error, nvrtc.nvrtcResult):
|
||||
return nvrtc.nvrtcGetErrorString(error)[1]
|
||||
else:
|
||||
|
||||
@@ -38,6 +38,7 @@ class MLIRTypeBuilder:
|
||||
self.gpu_ptr_type = llvm.PointerType.get(address_space=1)
|
||||
# did not find a programmatic way to get the void type
|
||||
self.void_type = ir.Type.parse("!llvm.void")
|
||||
self.llvm_internal_linkage = ir.Attribute.parse("#llvm.linkage<internal>")
|
||||
|
||||
def ptr_type_with_address_space(
|
||||
self, address_space: Optional[int] = None
|
||||
@@ -374,8 +375,10 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
params_type: Sequence[ir.Type],
|
||||
ret_type: ir.Type,
|
||||
internal: bool = False,
|
||||
llvm_func_attrs: Sequence[str] = (),
|
||||
) -> tuple[list[ir.Value], ir.Block]:
|
||||
"""Create a function with the given signature."""
|
||||
"""Create a function with the given signature.
|
||||
"""
|
||||
func_op = llvm.func(
|
||||
name,
|
||||
function_type=self.as_attr(
|
||||
@@ -383,10 +386,16 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
),
|
||||
)
|
||||
if internal:
|
||||
func_op.attributes["llvm.linkage"] = ir.StringAttr.get("private")
|
||||
func_op.attributes["linkage"] = self.llvm_internal_linkage
|
||||
else:
|
||||
func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
|
||||
|
||||
# Add LLVM function attributes via passthrough
|
||||
if llvm_func_attrs:
|
||||
func_op.attributes["passthrough"] = ir.ArrayAttr.get(
|
||||
[ir.StringAttr.get(attr) for attr in llvm_func_attrs]
|
||||
)
|
||||
|
||||
params = []
|
||||
func_body: Any = func_op.body
|
||||
if func_body is not None:
|
||||
|
||||
@@ -668,12 +668,14 @@ class TVMFFIBuilder(MLIRBuilder):
|
||||
param_types.append(self.ptr_type) # p0, p1, ..., pN-1
|
||||
|
||||
# Create the helper function
|
||||
# Mark as noinline since error handling is a slow path and benefits from not inlining
|
||||
with ir.InsertionPoint(self.module.body): # type: ignore[union-attr]
|
||||
params, entry_block = self.function(
|
||||
name=helper_name,
|
||||
params_type=param_types,
|
||||
ret_type=self.void_type,
|
||||
internal=True,
|
||||
llvm_func_attrs=["noinline"],
|
||||
)
|
||||
|
||||
kind_param = params[0]
|
||||
@@ -1244,12 +1246,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
else:
|
||||
assert isinstance(var, int)
|
||||
with ir.InsertionPoint(current_block):
|
||||
if not skip_cast_and_check:
|
||||
expected_value = self.i64(var)
|
||||
else:
|
||||
expected_value = self.downcast_i64_to_lower_bits(
|
||||
self.i64(var), var.dtype
|
||||
)
|
||||
expected_value = self.i64(var)
|
||||
|
||||
error_msg_mismatch = [
|
||||
error_prefix_mismatch,
|
||||
@@ -1983,13 +1980,21 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
)
|
||||
|
||||
# decode parameters to populate the matched var binding
|
||||
for arg_index, param in enumerate(params_list):
|
||||
# Track the actual FFI argument index separately from parameter index
|
||||
# since some parameters (like EnvStream) are not passed as FFI arguments
|
||||
ffi_arg_index = 0
|
||||
for param in params_list:
|
||||
# Skip EnvStream parameters as they are not in the FFI args array
|
||||
if isinstance(param, spec.EnvStream):
|
||||
continue
|
||||
|
||||
arg_context = ArgContext(
|
||||
param_name=param.name,
|
||||
arg_index=arg_index,
|
||||
arg_index=ffi_arg_index,
|
||||
tuple_indices=[],
|
||||
)
|
||||
current_block = self.decode_param(current_block, param, args, arg_index, arg_context)
|
||||
current_block = self.decode_param(current_block, param, args, ffi_arg_index, arg_context)
|
||||
ffi_arg_index += 1
|
||||
|
||||
with ir.InsertionPoint(current_block):
|
||||
env_stream = self.find_env_stream(params_list)
|
||||
@@ -2047,7 +2052,9 @@ def attach_ffi_func(
|
||||
builder.attach_ffi_func(symbol_name, params, call_provider, fn_display_name)
|
||||
|
||||
|
||||
def rename_tvm_ffi_function(module: ir.Module, old_name: str, new_name: str) -> None:
|
||||
def rename_tvm_ffi_function(
|
||||
module: ir.Module, old_name: str, new_name: str,
|
||||
) -> None:
|
||||
"""Rename the TVM FFI function in the module.
|
||||
|
||||
Parameters
|
||||
|
||||
@@ -19,6 +19,7 @@ from .typing import Tensor, Pointer, SymInt
|
||||
from .typing import (
|
||||
Numeric,
|
||||
Boolean,
|
||||
Integer,
|
||||
Int4,
|
||||
Int8,
|
||||
Uint8,
|
||||
@@ -42,7 +43,17 @@ from .typing import (
|
||||
)
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from typing import List, Dict, Any, Optional, Tuple, get_origin, get_args
|
||||
from typing import (
|
||||
List,
|
||||
Dict,
|
||||
Any,
|
||||
Optional,
|
||||
Tuple,
|
||||
get_origin,
|
||||
get_args,
|
||||
get_type_hints,
|
||||
)
|
||||
from types import UnionType
|
||||
import inspect
|
||||
|
||||
NumericToTVMFFIDtype = {
|
||||
@@ -91,6 +102,7 @@ def _get_llvm_address_space_from_memspace(
|
||||
return 1
|
||||
return None
|
||||
|
||||
|
||||
def _is_gpu_memspace(
|
||||
memspace: _cute_ir.AddressSpace,
|
||||
) -> bool:
|
||||
@@ -108,7 +120,6 @@ class SymIntId:
|
||||
return self.sym_int is other.sym_int
|
||||
|
||||
|
||||
|
||||
class ConverterContext:
|
||||
"""Context for managing variable allocation during TVM FFI args conversion."""
|
||||
|
||||
@@ -145,7 +156,9 @@ class ConverterContext:
|
||||
self.sym_int_id_mapping[sym_int_id] = var
|
||||
return var
|
||||
|
||||
def alloc_or_reuse_device_id(self, device_type: str, vdevice_id: int) -> Optional[spec.Var]:
|
||||
def alloc_or_reuse_device_id(
|
||||
self, device_type: str, vdevice_id: int
|
||||
) -> Optional[spec.Var]:
|
||||
"""Allocate or reuse a device_id variable for a given virtual device.
|
||||
|
||||
This function returns None for CPU tensors.
|
||||
@@ -166,10 +179,7 @@ class ConverterContext:
|
||||
|
||||
|
||||
def _convert_single_arg(
|
||||
arg,
|
||||
arg_name: str,
|
||||
arg_type,
|
||||
ctx: ConverterContext
|
||||
arg, arg_name: str, arg_type, ctx: ConverterContext
|
||||
) -> spec.Param:
|
||||
"""Convert a single argument to a spec.Param.
|
||||
|
||||
@@ -191,7 +201,7 @@ def _convert_single_arg(
|
||||
"""
|
||||
if arg is None:
|
||||
return spec.ConstNone(arg_name)
|
||||
elif (isinstance(arg, Numeric) and arg.dtype in AcceptableNumericTypesForScalar):
|
||||
elif isinstance(arg, Numeric) and arg.dtype in AcceptableNumericTypesForScalar:
|
||||
return spec.Var(arg_name, NumericToTVMFFIDtype[arg.dtype])
|
||||
elif arg_type in AcceptableNumericTypesForScalar:
|
||||
return spec.Var(arg_name, NumericToTVMFFIDtype[arg_type])
|
||||
@@ -201,9 +211,13 @@ def _convert_single_arg(
|
||||
if isinstance(arg[i], int):
|
||||
shape.append(arg[i])
|
||||
elif isinstance(arg[i], SymInt):
|
||||
shape.append(ctx.alloc_or_reuse_symint_var(arg[i], ctx.alloc_shape_name))
|
||||
shape.append(
|
||||
ctx.alloc_or_reuse_symint_var(arg[i], ctx.alloc_shape_name)
|
||||
)
|
||||
else:
|
||||
shape.append(spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[arg[i].dtype]))
|
||||
shape.append(
|
||||
spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[arg[i].dtype])
|
||||
)
|
||||
return spec.Shape(arg_name, shape)
|
||||
elif isinstance(arg, Tensor):
|
||||
shapes = []
|
||||
@@ -211,16 +225,22 @@ def _convert_single_arg(
|
||||
if not dyn_mask:
|
||||
shapes.append(arg.shape[i])
|
||||
elif isinstance(arg.shape[i], SymInt):
|
||||
shapes.append(ctx.alloc_or_reuse_symint_var(arg.shape[i], ctx.alloc_shape_name))
|
||||
shapes.append(
|
||||
ctx.alloc_or_reuse_symint_var(arg.shape[i], ctx.alloc_shape_name)
|
||||
)
|
||||
else:
|
||||
shapes.append(spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[Int32]))
|
||||
shapes.append(
|
||||
spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[Int32])
|
||||
)
|
||||
strides = []
|
||||
|
||||
for i, dyn_mask in enumerate(arg.dynamic_strides_mask):
|
||||
if not dyn_mask:
|
||||
strides.append(arg.stride[i])
|
||||
elif isinstance(arg.stride[i], SymInt):
|
||||
strides.append(ctx.alloc_or_reuse_symint_var(arg.stride[i], ctx.alloc_stride_name))
|
||||
strides.append(
|
||||
ctx.alloc_or_reuse_symint_var(arg.stride[i], ctx.alloc_stride_name)
|
||||
)
|
||||
else:
|
||||
if hasattr(arg, "_use_32bit_stride") and arg._use_32bit_stride:
|
||||
dtype = NumericToTVMFFIDtype[Int32]
|
||||
@@ -243,7 +263,7 @@ def _convert_single_arg(
|
||||
strides=strides,
|
||||
data_alignment=arg._assumed_align,
|
||||
device_type=device_type,
|
||||
device_id=device_id
|
||||
device_id=device_id,
|
||||
)
|
||||
else:
|
||||
# for FakeTensor, strictly follow the shape and stride from the cute tensor
|
||||
@@ -259,7 +279,7 @@ def _convert_single_arg(
|
||||
strides=strides,
|
||||
data_alignment=arg._assumed_align,
|
||||
device_type=device_type,
|
||||
device_id=device_id
|
||||
device_id=device_id,
|
||||
)
|
||||
if arg.element_type == Float4E2M1FN:
|
||||
tvm_ffi_cute_tensor = spec.create_map_tensor_dtype_f4x2_to_f4_spec(
|
||||
@@ -278,11 +298,38 @@ def _convert_single_arg(
|
||||
return spec.Stream(arg_name)
|
||||
elif isinstance(arg, cuda.CUstream):
|
||||
return spec.Stream(arg_name)
|
||||
elif arg_type is not None and hasattr(arg_type, "_fields"):
|
||||
# Handle NamedTuple - normalize to Tuple by order of fields, ignoring defaults
|
||||
# Get field types from annotations
|
||||
type_hints = get_type_hints(arg_type)
|
||||
tuple_element_types = [type_hints[field] for field in arg_type._fields]
|
||||
|
||||
# NamedTuples inherit from tuple, so we can check with isinstance(arg, tuple)
|
||||
if not isinstance(arg, tuple):
|
||||
raise DSLRuntimeError(
|
||||
f"Expected namedtuple for argument {arg_name}, got {type(arg)}"
|
||||
)
|
||||
if len(arg) != len(tuple_element_types):
|
||||
raise DSLRuntimeError(
|
||||
f"NamedTuple length mismatch for argument {arg_name}: "
|
||||
f"expected {len(tuple_element_types)}, got {len(arg)}"
|
||||
)
|
||||
|
||||
# Recursively convert each tuple element
|
||||
tuple_params = []
|
||||
for i, (elem, elem_type) in enumerate(zip(arg, tuple_element_types)):
|
||||
elem_name = f"{arg_name}[{i}]"
|
||||
elem_param = _convert_single_arg(elem, elem_name, elem_type, ctx)
|
||||
tuple_params.append(elem_param)
|
||||
|
||||
return spec.TupleParam(arg_name, tuple_params)
|
||||
elif arg_type is not None and get_origin(arg_type) is tuple:
|
||||
# Handle Tuple[X, Y, ...] type annotations
|
||||
tuple_element_types = get_args(arg_type)
|
||||
if not isinstance(arg, (tuple, list)):
|
||||
raise DSLRuntimeError(f"Expected tuple for argument {arg_name}, got {type(arg)}")
|
||||
raise DSLRuntimeError(
|
||||
f"Expected tuple for argument {arg_name}, got {type(arg)}"
|
||||
)
|
||||
if len(arg) != len(tuple_element_types):
|
||||
raise DSLRuntimeError(
|
||||
f"Tuple length mismatch for argument {arg_name}: "
|
||||
@@ -297,8 +344,24 @@ def _convert_single_arg(
|
||||
tuple_params.append(elem_param)
|
||||
|
||||
return spec.TupleParam(arg_name, tuple_params)
|
||||
elif isinstance(arg, (tuple, list)):
|
||||
# Handle plain tuple type annotation without explicit element types
|
||||
# Recursively convert each tuple element with None as elem_type (un-annotated)
|
||||
tuple_params = []
|
||||
for i, elem in enumerate(arg):
|
||||
elem_name = f"{arg_name}[{i}]"
|
||||
elem_param = _convert_single_arg(elem, elem_name, None, ctx)
|
||||
tuple_params.append(elem_param)
|
||||
return spec.TupleParam(arg_name, tuple_params)
|
||||
elif isinstance(arg, int):
|
||||
# in cute.compile, unannotated const int is converted to int32
|
||||
return spec.Var(arg_name, NumericToTVMFFIDtype[Int32])
|
||||
elif isinstance(arg, float):
|
||||
return spec.Var(arg_name, NumericToTVMFFIDtype[Float32])
|
||||
else:
|
||||
raise DSLRuntimeError(f"Unsupported argument type: {type(arg)}")
|
||||
raise DSLRuntimeError(
|
||||
f"Unsupported argument type: {type(arg)} for annotated type: {get_origin(arg_type)}"
|
||||
)
|
||||
|
||||
|
||||
def _tvm_ffi_args_spec_converter(
|
||||
@@ -312,17 +375,24 @@ def _tvm_ffi_args_spec_converter(
|
||||
This function converts the cute arguments specs to tvm ffi spec params.
|
||||
"""
|
||||
exec_args = ExecutionArgs(args_spec, function_name)
|
||||
rectified_args = exec_args.get_rectified_args_from_original_args(full_args, full_kwargs)
|
||||
rectified_args = exec_args.get_rectified_args_from_original_args(
|
||||
full_args, full_kwargs
|
||||
)
|
||||
arg_names = exec_args.args_spec.args + exec_args.args_spec.kwonlyargs
|
||||
params = []
|
||||
ctx = ConverterContext()
|
||||
wrapper_extra_exclude_arg_names = []
|
||||
|
||||
for arg, arg_name in zip(rectified_args, arg_names):
|
||||
arg_type = args_spec.annotations.get(arg_name, None)
|
||||
param = _convert_single_arg(arg, arg_name, arg_type, ctx)
|
||||
params.append(param)
|
||||
|
||||
return params
|
||||
if isinstance(param, spec.EnvStream):
|
||||
wrapper_extra_exclude_arg_names.append(arg_name)
|
||||
kwargs_wrapper_spec = exec_args.get_kwargs_wrapper_spec(
|
||||
wrapper_extra_exclude_arg_names
|
||||
)
|
||||
return params, kwargs_wrapper_spec
|
||||
|
||||
|
||||
def attach_args_spec_converter():
|
||||
|
||||
@@ -28,8 +28,9 @@ from ..base_dsl.jit_executor import (
|
||||
JitFunctionArtifacts,
|
||||
)
|
||||
from ..base_dsl.utils.logger import log
|
||||
from ..base_dsl.common import DSLCudaRuntimeError, DSLRuntimeError
|
||||
from ..base_dsl.common import DSLRuntimeError
|
||||
from ..base_dsl.typing import Int32
|
||||
from ..base_dsl.runtime.cuda import checkCudaErrors
|
||||
|
||||
class CudaDialectJitModule:
|
||||
"""Holds the execution engine and cuda libraries."""
|
||||
@@ -113,10 +114,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction):
|
||||
@functools.cached_property
|
||||
def num_devices(self):
|
||||
"""Returns the number of CUDA devices available."""
|
||||
dev_err, devs = cuda_runtime.cudaGetDeviceCount()
|
||||
if dev_err != cuda_runtime.cudaError_t.cudaSuccess:
|
||||
raise DSLCudaRuntimeError(dev_err, cuda_runtime.cudaGetErrorName(dev_err))
|
||||
return devs
|
||||
return checkCudaErrors(cuda_runtime.cudaGetDeviceCount())
|
||||
|
||||
def _deserializer(self):
|
||||
"""Load the cuda library from the binary execution engine.
|
||||
@@ -148,12 +146,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction):
|
||||
packed_args[i] = ctypes.cast(cuda_init_args[i], ctypes.c_void_p)
|
||||
cuda_init(packed_args)
|
||||
|
||||
if err.value != 0:
|
||||
error_code = err.value
|
||||
error_name = cuda_runtime.cudaGetErrorName(
|
||||
cuda_runtime.cudaError_t(error_code)
|
||||
)
|
||||
raise DSLCudaRuntimeError(error_code, error_name)
|
||||
checkCudaErrors((cuda_runtime.cudaError_t(err.value),))
|
||||
|
||||
cuda_load_args = [pointer_to_library, pointer_to_err]
|
||||
packed_args = (ctypes.c_void_p * len(cuda_load_args))()
|
||||
@@ -161,12 +154,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction):
|
||||
packed_args[i] = ctypes.cast(cuda_load_args[i], ctypes.c_void_p)
|
||||
cuda_load(packed_args)
|
||||
|
||||
if err.value != 0:
|
||||
error_code = err.value
|
||||
error_name = cuda_runtime.cudaGetErrorName(
|
||||
cuda_runtime.cudaError_t(error_code)
|
||||
)
|
||||
raise DSLCudaRuntimeError(error_code, error_name)
|
||||
checkCudaErrors((cuda_runtime.cudaError_t(err.value),))
|
||||
|
||||
return [cuda_runtime.cudaLibrary_t(library.value)]
|
||||
|
||||
@@ -229,12 +217,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction):
|
||||
packed_args[i] = ctypes.cast(cuda_init_args[i], ctypes.c_void_p)
|
||||
cuda_init(packed_args)
|
||||
|
||||
if err.value != 0:
|
||||
error_code = err.value
|
||||
error_name = cuda_runtime.cudaGetErrorName(
|
||||
cuda_runtime.cudaError_t(error_code)
|
||||
)
|
||||
raise DSLCudaRuntimeError(error_code, error_name)
|
||||
checkCudaErrors((cuda_runtime.cudaError_t(err.value),))
|
||||
|
||||
device_id = ctypes.c_int32(0)
|
||||
pointer_to_device_id = ctypes.pointer(device_id)
|
||||
@@ -247,18 +230,9 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction):
|
||||
for dev in range(self.num_devices):
|
||||
device_id.value = dev
|
||||
cuda_load_to_device(packed_args)
|
||||
if err.value != 0:
|
||||
raise DSLCudaRuntimeError(
|
||||
err.value,
|
||||
cuda_runtime.cudaGetErrorName(cuda_runtime.cudaError_t(err.value)),
|
||||
)
|
||||
checkCudaErrors((cuda_runtime.cudaError_t(err.value),))
|
||||
|
||||
if err.value != 0:
|
||||
error_code = err.value
|
||||
error_name = cuda_runtime.cudaGetErrorName(
|
||||
cuda_runtime.cudaError_t(error_code)
|
||||
)
|
||||
raise DSLCudaRuntimeError(error_code, error_name)
|
||||
checkCudaErrors((cuda_runtime.cudaError_t(err.value),))
|
||||
|
||||
return [cuda_runtime.cudaLibrary_t(library.value)]
|
||||
|
||||
|
||||
@@ -43,9 +43,9 @@ from ..base_dsl.dsl import is_dynamic_expression, extract_mlir_values
|
||||
from ..base_dsl.typing import *
|
||||
from ..base_dsl.typing import DynamicExpression, get_mlir_types
|
||||
from ..base_dsl.runtime.jit_arg_adapters import is_arg_spec_constexpr
|
||||
from ..base_dsl.jit_executor import ExecutionArgs
|
||||
from ..base_dsl.runtime import cuda as cuda_helpers
|
||||
from .cuda_stream_adapter import CudaDialectStreamAdapter
|
||||
|
||||
from .cuda_jit_executor import CudaDialectJitCompiledFunction
|
||||
|
||||
# MLIR Imports
|
||||
@@ -421,12 +421,13 @@ class CutlassBaseDSL(BaseDSL):
|
||||
# attach extra ABI function to the MLIR module
|
||||
from .tvm_ffi_provider import (
|
||||
TVMFFIJitCompiledFunction,
|
||||
TVMFFIJitCompiledFunctionWithKwargs,
|
||||
TVMFFICuteCallProvider,
|
||||
)
|
||||
from cutlass.base_dsl.tvm_ffi_builder import attach_ffi_func
|
||||
|
||||
assert self._tvm_ffi_args_spec_converter is not None
|
||||
tvm_ffi_spec_params = self._tvm_ffi_args_spec_converter(
|
||||
tvm_ffi_spec_params, kwargs_wrapper_spec = self._tvm_ffi_args_spec_converter(
|
||||
function_name, args_spec, full_args, full_kwargs
|
||||
)
|
||||
tvm_ffi_provider = TVMFFICuteCallProvider(function_name)
|
||||
@@ -444,6 +445,15 @@ class CutlassBaseDSL(BaseDSL):
|
||||
)
|
||||
module.operation.verify()
|
||||
|
||||
def _make_compiled_func(*args, **kwargs):
|
||||
if kwargs_wrapper_spec.kwonly_names or kwargs_wrapper_spec.arg_defaults:
|
||||
return TVMFFIJitCompiledFunctionWithKwargs(
|
||||
*args, **kwargs,
|
||||
kwargs_wrapper_spec=kwargs_wrapper_spec
|
||||
)
|
||||
else:
|
||||
return TVMFFIJitCompiledFunction(*args, **kwargs)
|
||||
|
||||
# ensure the compiler can run post-compile hook after its passes
|
||||
# the context will restore the previous post-compile hook after it exits
|
||||
with compiler.PostCompileHookContext(
|
||||
@@ -456,7 +466,7 @@ class CutlassBaseDSL(BaseDSL):
|
||||
pipeline,
|
||||
args_spec,
|
||||
no_cache,
|
||||
TVMFFIJitCompiledFunction,
|
||||
_make_compiled_func,
|
||||
full_args=full_args,
|
||||
full_kwargs=full_kwargs,
|
||||
dynamic_args=dynamic_args,
|
||||
|
||||
@@ -15,12 +15,14 @@ from cutlass.base_dsl.tvm_ffi_builder import (
|
||||
rename_tvm_ffi_function,
|
||||
spec,
|
||||
)
|
||||
from cutlass.base_dsl.export import get_export_module
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import llvm
|
||||
from cutlass._mlir._mlir_libs._cutlass_ir import _execution_engine_extra
|
||||
from cutlass._mlir._mlir_libs._cutlass_ir import _aot_support
|
||||
from cutlass.cutlass_dsl.cuda_jit_executor import CudaDialectJitCompiledFunction
|
||||
from cutlass.base_dsl.common import DSLRuntimeError
|
||||
from typing import Optional
|
||||
from cutlass.base_dsl.jit_executor import ExecutionArgs
|
||||
from typing import Optional, Callable
|
||||
import tvm_ffi
|
||||
|
||||
|
||||
@@ -400,41 +402,52 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
return current_block
|
||||
|
||||
|
||||
class TVMFFIJitCompiledFunction(tvm_ffi.Function, CudaDialectJitCompiledFunction):
|
||||
"""TVM FFI Function that contains metadata of the compiled function and interface to the FFI layer.
|
||||
def _inplace_hide_symbols(ir_module: ir.Module, hide_check: Callable[[str], bool]):
|
||||
"""Walk through the IRModule, hide functions that do not yet have linkage set.
|
||||
|
||||
This function should not be directly used after
|
||||
@param ir_module: The ir module to hide the symbols.
|
||||
@param hide_check: The callback to check if the symbol should be hidden.
|
||||
@return: The ir module with the symbols hidden.
|
||||
"""
|
||||
defined_symbols = set()
|
||||
def walk_llvm_func_op(op):
|
||||
# not a declaration
|
||||
if (
|
||||
op.name == "llvm.func"
|
||||
and len(op.opview.operation.regions) > 0
|
||||
and len(op.opview.operation.regions[0].blocks) > 0
|
||||
):
|
||||
func_name = op.attributes["sym_name"].value
|
||||
defined_symbols.add(func_name)
|
||||
|
||||
return ir.WalkResult.ADVANCE
|
||||
|
||||
def walk_and_hide_symbols(op):
|
||||
# Handle llvm.func operations
|
||||
if op.name == "llvm.func":
|
||||
func_name = op.attributes["sym_name"].value
|
||||
# Only set linkage if it doesn't already have one
|
||||
if func_name in defined_symbols and hide_check(func_name):
|
||||
# Set to internal linkage to hide the symbol
|
||||
op.attributes["linkage"] = ir.Attribute.parse("#llvm.linkage<internal>")
|
||||
return ir.WalkResult.ADVANCE
|
||||
|
||||
with ir_module.context:
|
||||
ir_module.operation.walk(walk_llvm_func_op)
|
||||
ir_module.operation.walk(walk_and_hide_symbols)
|
||||
|
||||
|
||||
def _get_format_from_object_file_path(object_file_path: str) -> str:
|
||||
format = object_file_path.split(".")[-1]
|
||||
if format not in ("o", "ll", "bc"):
|
||||
return "o"
|
||||
return format
|
||||
|
||||
|
||||
class TVMFFIJitCompiledFunctionBase(CudaDialectJitCompiledFunction):
|
||||
"""Base class for TVM FFI compiled function."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# initialize the tvm_ffi.Function from the current execution engine
|
||||
self._init_ffi_function()
|
||||
|
||||
# use direct call to the tvm_ffi.Function.__call__
|
||||
# to avoid most of python overhead
|
||||
__call__ = tvm_ffi.Function.__call__
|
||||
|
||||
def _init_ffi_function(self):
|
||||
"""Initialize the tvm_ffi.Function from the current execution engine.
|
||||
|
||||
This function must be called at once during compilation time.
|
||||
The reason why it is not called during init is because the original
|
||||
flow may already created an execution engine and the function is not
|
||||
guaranteed to be initialized at that time.
|
||||
"""
|
||||
if self.__chandle__() != 0:
|
||||
raise DSLRuntimeError("TVM FFI function is already initialized")
|
||||
# get the MLIR function pointer from the execution engine
|
||||
if self.engine is not None:
|
||||
tvm_ffi_function_ptr = self.engine.raw_lookup(
|
||||
"__tvm_ffi_" + self.function_name
|
||||
)
|
||||
tvm_ffi_function = tvm_ffi.Function.__from_mlir_packed_safe_call__(
|
||||
tvm_ffi_function_ptr
|
||||
)
|
||||
# move the handle from the tvm_ffi.Function to the current instance
|
||||
self.__move_handle_from__(tvm_ffi_function)
|
||||
|
||||
def to(self, device=None):
|
||||
"""TVM FFI function itself is already support all devices."""
|
||||
@@ -444,18 +457,111 @@ class TVMFFIJitCompiledFunction(tvm_ffi.Function, CudaDialectJitCompiledFunction
|
||||
"""Run the compiled program. This override is needed for implicit compile and execution."""
|
||||
return self.__call__(*exe_args)
|
||||
|
||||
def export_to_c(self, object_file_path: str, function_name: str = None):
|
||||
def export_to_c(
|
||||
self, object_file_path: str, function_name: str = None,
|
||||
*,
|
||||
enable_pic: bool = True,
|
||||
export_only_tvm_ffi_symbols: bool = False
|
||||
):
|
||||
"""Export the TVM FFI function to an object file.
|
||||
|
||||
:param object_file_path: The path to the object file.
|
||||
:param function_name: The name of the function to export.
|
||||
:param enable_pic: Whether to enable PIC relocation needed for shared library loading.
|
||||
:param export_only_tvm_ffi_symbols: Only export TVM FFI symbols (hide all others).
|
||||
:param host_target_triple: If not provided, the current host target is used.
|
||||
"""
|
||||
if function_name is not None and function_name != self.function_name:
|
||||
mod = self.ir_module
|
||||
rename_tvm_ffi_function(mod, self.function_name, function_name)
|
||||
else:
|
||||
mod = self.ir_module
|
||||
|
||||
_execution_engine_extra.dump_object_file_pic(
|
||||
mod, object_file_path, "__tvm_ffi_" + function_name, 2
|
||||
# prefix internal function by function name
|
||||
internal_symbol_prefix = "__cute_internal_" + function_name
|
||||
mod = self.ir_module
|
||||
mod = get_export_module(
|
||||
self.ir_module, internal_symbol_prefix,
|
||||
preserve_symbols=[f"__tvm_ffi_{self.function_name}"]
|
||||
)
|
||||
|
||||
rename_tvm_ffi_function(mod, self.function_name, function_name)
|
||||
if export_only_tvm_ffi_symbols:
|
||||
_inplace_hide_symbols(mod, lambda x: not x.startswith("__tvm_ffi"))
|
||||
|
||||
format = _get_format_from_object_file_path(object_file_path)
|
||||
out_bytes = _aot_support.export_module_to_bytes(
|
||||
mod, format=format, opt_level=3, enable_pic=enable_pic
|
||||
)
|
||||
|
||||
with open(object_file_path, "wb") as f:
|
||||
f.write(out_bytes)
|
||||
|
||||
def _create_tvm_ffi_function(self):
|
||||
"""Create the tvm_ffi.Function from the current execution engine.
|
||||
"""
|
||||
if self.engine is not None:
|
||||
tvm_ffi_function_ptr = self.engine.raw_lookup(
|
||||
"__tvm_ffi_" + self.function_name
|
||||
)
|
||||
tvm_ffi_function = tvm_ffi.Function.__from_mlir_packed_safe_call__(
|
||||
tvm_ffi_function_ptr, keep_alive_object=self.engine)
|
||||
return tvm_ffi_function
|
||||
return None
|
||||
|
||||
|
||||
class TVMFFIJitCompiledFunction(tvm_ffi.Function, TVMFFIJitCompiledFunctionBase):
|
||||
"""TVM FFI Function that directly subclasses the tvm_ffi.Function for pos only arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# initialize the tvm_ffi.Function from the current execution engine
|
||||
if self.__chandle__() != 0:
|
||||
raise DSLRuntimeError("TVM FFI function is already initialized")
|
||||
tvm_ffi_function = self._create_tvm_ffi_function()
|
||||
if tvm_ffi_function is not None:
|
||||
# move the handle from the tvm_ffi.Function to the current instance
|
||||
self.__move_handle_from__(tvm_ffi_function)
|
||||
|
||||
# use direct call to the tvm_ffi.Function.__call__
|
||||
# to avoid most of python overhead
|
||||
__call__ = tvm_ffi.Function.__call__
|
||||
|
||||
|
||||
class TVMFFIJitCompiledFunctionWithKwargs(TVMFFIJitCompiledFunctionBase):
|
||||
"""TVM FFI Function with kwargs wrapper support
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
assert "kwargs_wrapper_spec" in kwargs, "kwargs_wrapper_spec is required"
|
||||
kwargs_wrapper_spec = kwargs.pop("kwargs_wrapper_spec")
|
||||
super().__init__(*args, **kwargs)
|
||||
# initialize the tvm_ffi.Function from the current execution engine
|
||||
self._tvm_ffi_function = self._create_tvm_ffi_function()
|
||||
if kwargs_wrapper_spec.kwonly_names or kwargs_wrapper_spec.arg_defaults:
|
||||
try:
|
||||
from tvm_ffi.utils import kwargs_wrapper # type: ignore
|
||||
self._kwargs_wrapper = kwargs_wrapper.make_kwargs_wrapper(
|
||||
self._tvm_ffi_function,
|
||||
arg_names=kwargs_wrapper_spec.arg_names,
|
||||
arg_defaults=kwargs_wrapper_spec.arg_defaults,
|
||||
kwonly_names=kwargs_wrapper_spec.kwonly_names,
|
||||
kwonly_defaults=kwargs_wrapper_spec.kwonly_defaults,
|
||||
)
|
||||
except ImportError:
|
||||
raise DSLRuntimeError("install apache-tvm-ffi>=0.1.5 to enable kwargs/defaults")
|
||||
else:
|
||||
# positional only is probably fine
|
||||
self._kwargs_wrapper = self._tvm_ffi_function
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Call the TVM FFI function with kwargs wrapper.
|
||||
"""
|
||||
return self._kwargs_wrapper(*args, **kwargs)
|
||||
|
||||
def __tvm_ffi_object__(self):
|
||||
return self._tvm_ffi_function
|
||||
|
||||
|
||||
def supports_kwargs_wrapper() -> bool:
|
||||
"""Check if the kwargs wrapper is supported."""
|
||||
try:
|
||||
from tvm_ffi.utils import kwargs_wrapper # type: ignore
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl==4.3.0
|
||||
nvidia-cutlass-dsl==4.3.3
|
||||
|
||||
Reference in New Issue
Block a user