Files
cutlass/python/CuTeDSL/cutlass/base_dsl/jit_executor.py
2026-04-07 12:16:05 -04:00

1184 lines
46 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
"""
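This module provides the JIT executor classes: loading CUDA kernels from a
compiled IR module, marshalling Python call arguments into C pointers, and the
callable wrappers used to run a compiled function.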
"""
import array
import ctypes
import inspect
import io
from typing import Union, Optional, NamedTuple, Any, Sequence
import weakref
import threading
import collections
import os
import tempfile
from dataclasses import dataclass
# MLIR modules imports
from .._mlir import ir
from .._mlir.dialects import llvm
# Local modules imports
from . import typing as t
from .common import DSLRuntimeError, DSLCudaRuntimeError
from .runtime import cuda as cuda_helpers
from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr
from .typing import get_c_pointers
from .utils.logger import log
from .utils.timer import timer
class CudaModuleAndKernel:
    """Record tying one loaded CUDA kernel to its library and attributes.

    The four constructor arguments are stored unchanged under the same names:
    ``sym``, ``cuda_module``, ``kernel`` and ``attrs``.
    """

    def __init__(self, sym, cuda_module, kernel, attrs):
        self.sym, self.cuda_module = sym, cuda_module
        self.kernel, self.attrs = kernel, attrs
def get_escaped_cubin_bytes(cubin_data):
    """Decode escaped cubin text taken from MLIR output into raw bytes.

    Two escape forms are recognised:

    * a backslash followed by two hex digits becomes the byte with that value;
    * two consecutive backslashes become a single backslash.

    Every other byte is copied through unchanged. A backslash that does not
    start a complete escape (including one at the very end of the input) is
    also copied through literally, so malformed input can neither raise
    ``IndexError`` nor stall the loop.

    :param cubin_data: The escaped data (``bytes`` or ``bytearray``).
    :return: The decoded data.
    :rtype: bytes
    """
    backslash = 0x5C

    def ishex(inp):
        # ASCII '0'-'9', 'A'-'F', 'a'-'f'
        return (0x30 <= inp < 0x3A) or (0x41 <= inp < 0x47) or (0x61 <= inp < 0x67)

    size = len(cubin_data)
    converted = bytearray()
    idx = 0
    while idx < size:
        current = cubin_data[idx]
        if current == backslash:
            if (
                idx + 2 < size
                and ishex(cubin_data[idx + 1])
                and ishex(cubin_data[idx + 2])
            ):
                converted += bytearray.fromhex(cubin_data[idx + 1 : idx + 3].decode())
                idx += 3
                continue
            if idx + 1 < size and cubin_data[idx + 1] == backslash:
                converted.append(backslash)
                idx += 2
                continue
        # Plain byte, or a backslash that does not begin a valid escape.
        converted.append(current)
        idx += 1
    return bytes(converted)
def walk_module_and_get_cubin_data(module, sym, callback):
    """Find the ``gpu.binary`` ops that belong to ``sym`` and hand their cubin to ``callback``.

    Each ``gpu.binary`` op is serialised with ``write_bytecode``. It is
    processed only when the serialised bytes contain ``sym`` and the op's
    symbol name is either ``"kernels"`` or ``sym`` itself. The text between
    ``bin = "`` and ``">`` is then unescaped and passed on.

    :param module: The IR module whose operations are walked.
    :param sym: The kernel symbol to look for.
    :param callback: Called as ``callback(sym, func_sym, cubin_bytes)``.
    """

    def visit(op):
        if op.name == "gpu.binary":
            stream = io.BytesIO()
            op.write_bytecode(stream)
            serialized = stream.getvalue()
            if sym.encode() in serialized:
                binary_name = op.opview.sym_name.value
                if binary_name == "kernels" or binary_name == sym:
                    # function symbol of kernel(gpu.launch_func) is equal to sym name in mlir
                    if binary_name == sym and not sym.endswith("_kernel"):
                        func_sym = sym.rsplit("_", 1)[0]
                    else:
                        func_sym = sym
                    escaped = serialized.split(b'bin = "')[1].split(b'">')[0]
                    callback(sym, func_sym, get_escaped_cubin_bytes(escaped))
        return ir.WalkResult.ADVANCE

    module.operation.walk(visit)
def load_kernels_from_ir_module(module, kernel_info) -> list[CudaModuleAndKernel]:
    """Load every kernel named in ``kernel_info`` from the IR module.

    :param module: The IR module containing the ``gpu.binary`` ops.
    :param kernel_info: Mapping of kernel symbol to the function attributes
        that should be applied to it. Iteration order is preserved.
    :return: One :class:`CudaModuleAndKernel` per symbol that was found, in
        ``kernel_info`` order. Empty when ``kernel_info`` is empty or ``None``.
    """
    if not kernel_info:
        return []  # no modules

    # don't sort because the external kernel pointers are recorded in the order called in ir module.
    kernel_symbols = tuple(kernel_info)
    # load cuda module/get function pointer from module and cache
    kernel_modules = collections.OrderedDict()

    def walk_callback(sym, func_sym, cubin_data):
        if sym in kernel_modules:
            # Really skip: loading again would create a second library for the
            # same symbol and drop the reference to the first one.
            log().debug(f"Skipping already loaded symbol: {sym}")
            return

        cubin_module = cuda_helpers.load_library_data(cubin_data)
        kernel = cuda_helpers.get_library_kernel(cubin_module, func_sym)
        # Setup attributes we want applied to the loaded kernel functions.
        # A copy is made so we can update one of the attributes.
        attrs = dict(kernel_info[sym])
        if cuda_helpers.get_driver_version() >= 11080:
            attrs[
                cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED
            ] = 1
        kernel_modules[sym] = CudaModuleAndKernel(sym, cubin_module, kernel, attrs)

    for sym in kernel_symbols:
        log().debug(f"Loading CUDA module for symbol: {sym}")
        walk_module_and_get_cubin_data(module, sym, walk_callback)

    return list(kernel_modules.values())
class KwargsWrapperSpec(NamedTuple):
    """A specification for keyword arguments wrapper.

    Produced by ``ExecutionArgs.get_kwargs_wrapper_spec`` from the original
    function signature after compile-time (constexpr) arguments and explicitly
    excluded names have been dropped.
    """

    # Names of the remaining positional arguments, in signature order.
    arg_names: list[str]
    # Default values of the remaining positional arguments that have one.
    arg_defaults: tuple[Any, ...]
    # Names of the remaining keyword-only arguments, in signature order.
    kwonly_names: list[str]
    # Defaults of the remaining keyword-only arguments, keyed by name.
    kwonly_defaults: dict[str, Any]
@dataclass
class ArgMeta:
    """Metadata for function arguments.

    Built once per signature by ``ExecutionArgs._build_meta`` so that argument
    marshalling does not have to re-inspect the signature on every call.
    """

    # Positional argument names of the runtime signature.
    pos_names: list[str]
    # Keyword-only argument names of the runtime signature.
    kwonly_names: list[str]
    # pos_names followed by kwonly_names; defines the slot order.
    all_names: list[str]
    # Annotation for each entry of all_names (None when not annotated).
    annotated_types: list[object]
    # True for slots whose annotation is an instance of t.NumericMeta.
    numeric_flags: list[bool]
    # Maps an argument name to its slot index in all_names.
    name_to_index: dict[str, int]
    # Per positional slot: its default, or the owner's "missing" sentinel.
    pos_defaults: list[object]
    # Per keyword-only slot: its default, or the owner's "missing" sentinel.
    kwonly_defaults: list[object]
    # Total number of slots (len(all_names)).
    arg_count: int
class ExecutionArgs:
    """Helper that wraps the function signature spec to filter execution and compile time arguments.

    The original ``inspect.FullArgSpec`` is kept in ``original_args_spec``;
    ``args_spec`` holds the same signature with compile-time (constexpr)
    arguments removed. Both are ``None`` when no spec is supplied.
    """

    def __init__(self, spec, function_name):
        """
        :param spec: The function's ``inspect.FullArgSpec`` or ``None``.
        :param function_name: Name forwarded to ``is_arg_spec_constexpr``.
        """
        # Must be set first: filter_runtime_arg_spec reads it.
        self.function_name = function_name
        self.args_spec = (
            self.filter_runtime_arg_spec(spec) if spec is not None else None
        )
        self.original_args_spec = spec
        # Unique sentinel marking "no value / no default" in slot lists.
        self._missing = object()
        self._meta = self._build_meta()
        # Holds the per-thread adapter caches used while marshalling.
        self._tls = threading.local()

    @staticmethod
    def _defaults_start_index(arg_spec):
        """Return the index of the first positional argument that has a default."""
        if arg_spec.defaults:
            return len(arg_spec.args) - len(arg_spec.defaults)
        return len(arg_spec.args)

    def _build_meta(self):
        """
        Precompute metadata for the fast-path execution.
        This metadata is static per function signature.

        :return: An :class:`ArgMeta` describing the runtime signature (all
            fields empty when there is no spec).
        """
        spec = self.args_spec
        if spec is None:
            return ArgMeta(
                pos_names=[],
                kwonly_names=[],
                all_names=[],
                annotated_types=[],
                numeric_flags=[],
                name_to_index={},
                pos_defaults=[],
                kwonly_defaults=[],
                arg_count=0,
            )

        pos_names = list(spec.args)
        kwonly_names = list(spec.kwonlyargs)
        all_names = pos_names + kwonly_names
        annotated_types = [spec.annotations.get(n) for n in all_names]
        numeric_flags = [isinstance(typ, t.NumericMeta) for typ in annotated_types]
        name_to_index = {n: i for i, n in enumerate(all_names)}

        # Defaults always belong to the trailing positional arguments.
        pos_defaults = [self._missing] * len(pos_names)
        if spec.defaults:
            start = len(pos_names) - len(spec.defaults)
            for i, d in enumerate(spec.defaults):
                pos_defaults[start + i] = d

        kwonly_defaults = [self._missing] * len(kwonly_names)
        if spec.kwonlydefaults:
            for i, n in enumerate(kwonly_names):
                if n in spec.kwonlydefaults:
                    kwonly_defaults[i] = spec.kwonlydefaults[n]

        return ArgMeta(
            pos_names=pos_names,
            kwonly_names=kwonly_names,
            all_names=all_names,
            annotated_types=annotated_types,
            numeric_flags=numeric_flags,
            name_to_index=name_to_index,
            pos_defaults=pos_defaults,
            kwonly_defaults=kwonly_defaults,
            arg_count=len(all_names),
        )

    def generate_execution_args_positional(self, *args):
        """Fast execution for positional-only arguments with exact count match.

        This method is optimized for the common case where:
        - All arguments are positional (no kwargs)
        - Number of arguments matches the function signature exactly

        For each slot, in order:
        - an argument exposing ``__c_pointers__`` supplies its own pointers;
        - a slot annotated with a numeric type is cast to that type first;
        - otherwise the adapter registered for the argument's type (if any)
          is applied, and the adapted object is recorded in ``adapted_args``.

        :param args: The positional arguments tuple
        :return: (exe_args, adapted_args) tuple
        """
        meta = self._meta
        exe_arg_chunks = [None] * meta.arg_count
        adapted_args = []

        # One {type: adapter} cache per slot, private to the calling thread.
        tls = self._tls
        adapter_caches = getattr(tls, "adapter_caches", None)
        if adapter_caches is None:
            adapter_caches = [dict() for _ in range(meta.arg_count)]
            tls.adapter_caches = adapter_caches

        annotated_types = meta.annotated_types
        numeric_flags = meta.numeric_flags
        for index in range(meta.arg_count):
            arg = args[index]
            cptr_method = getattr(arg, "__c_pointers__", None)
            if cptr_method is not None:
                exe_arg_chunks[index] = cptr_method()
                continue

            if numeric_flags[index]:
                arg = t.cast(arg, annotated_types[index])
                exe_arg_chunks[index] = get_c_pointers(arg)
            else:
                arg_type = type(arg)
                cache = adapter_caches[index]
                adapter = cache.get(arg_type)
                if adapter is None:
                    adapter = JitArgAdapterRegistry.get_registered_adapter(arg_type)
                    if adapter is not None:
                        cache[arg_type] = adapter
                if adapter is not None:
                    arg = adapter(arg)
                    adapted_args.append(arg)
                exe_arg_chunks[index] = get_c_pointers(arg)

        exe_args = [p for chunk in exe_arg_chunks for p in chunk]
        return exe_args, adapted_args

    def get_rectified_args(self, args, kwargs):
        """
        This function is used to rectify the args and kwargs to a final runtime argument list according to the args_spec.

        :param args: Positional arguments supplied by the caller.
        :param kwargs: Keyword arguments supplied by the caller.
        :return: One value per runtime slot (positional slots first, then keyword-only).
        :raises DSLRuntimeError: On too many positionals, an unknown keyword,
            a keyword that duplicates a positional, or a slot left without a value.
        """
        meta = self._meta
        pos_count = len(meta.pos_names)
        if len(args) > pos_count:
            raise DSLRuntimeError(
                "input args/kwargs length does not match runtime function signature!",
                context={
                    "input args length": len(args),
                    "input kwargs length": len(kwargs),
                    "function signature args length": len(meta.pos_names),
                    "function signature kwonlyargs length": len(meta.kwonly_names),
                },
            )

        # Start with every slot marked missing, we overwrite as values/defaults bind
        rectified = [self._missing] * meta.arg_count
        pos_len = len(args)
        # Fill positional slots with the values from the caller
        rectified[:pos_len] = args
        # Fill positional slots the caller skipped with the defaults
        for i in range(pos_len, pos_count):
            default = meta.pos_defaults[i]
            if default is not self._missing:
                rectified[i] = default
        # Fill keyword-only slots with the defaults before user kwargs
        for j, default in enumerate(meta.kwonly_defaults):
            if default is not self._missing:
                rectified[pos_count + j] = default
        # Fill keyword slots with the values from the caller
        for name, value in kwargs.items():
            idx = meta.name_to_index.get(name)
            if idx is None:
                raise DSLRuntimeError(
                    "unexpected keyword argument",
                    context={"argument_name": name},
                )
            if idx < pos_len:
                raise DSLRuntimeError(
                    "multiple values for argument",
                    context={"argument_name": name},
                )
            rectified[idx] = value

        # Identity comparison only: a membership test (`in`) would call `==`
        # on caller-supplied objects, which can raise or return a non-boolean
        # (for example array-like arguments).
        missing_args = [
            name
            for name, value in zip(meta.all_names, rectified)
            if value is self._missing
        ]
        if missing_args:
            raise DSLRuntimeError(
                "input args/kwargs length does not match runtime function signature!",
                context={
                    "missing": missing_args,
                    "function signature args length": len(meta.pos_names),
                    "function signature kwonlyargs length": len(meta.kwonly_names),
                },
            )
        return rectified

    def generate_execution_args(self, args, kwargs):
        """
        This function is the prune version of `generate_mlir_function_types` which only generates execution args
        to get rid of mlir context.

        :param args: Positional arguments supplied by the caller.
        :param kwargs: Keyword arguments supplied by the caller.
        :return: (exe_args, adapted_args) tuple
        """
        if kwargs or len(args) != self._meta.arg_count:
            # Slow path: bind defaults and keywords to slots first. The result
            # has exactly one value per slot, so the positional path applies.
            args = self.get_rectified_args(args, kwargs)
        return self.generate_execution_args_positional(*args)

    def get_kwargs_wrapper_spec(
        self, exclude_arg_names: Sequence[str] = ()
    ) -> KwargsWrapperSpec:
        """
        This function is used to get the kwargs wrapper spec from the original args_spec.

        :param exclude_arg_names: Argument names to leave out in addition to
            the compile-time arguments.
        :return: The names and defaults of the remaining arguments.
        """
        excluded_arg_names = set(exclude_arg_names)
        arg_spec = self.original_args_spec
        defaults_start_idx = self._defaults_start_index(arg_spec)

        arg_names = []
        arg_defaults = []
        kwonly_names = []
        kwonly_defaults = {}

        # Filter arguments and maintain their properties
        for i, arg_name in enumerate(arg_spec.args):
            arg_type = arg_spec.annotations.get(arg_name, None)
            # Skip compile-time arguments
            if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
                continue
            if arg_name in excluded_arg_names:
                continue
            arg_names.append(arg_name)
            if i >= defaults_start_idx:
                arg_defaults.append(arg_spec.defaults[i - defaults_start_idx])

        if arg_spec.kwonlyargs:
            for i, kwarg in enumerate(arg_spec.kwonlyargs):
                arg_type = arg_spec.annotations.get(kwarg, None)
                # Skip compile-time arguments
                if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name):
                    continue
                if kwarg in excluded_arg_names:
                    continue
                kwonly_names.append(kwarg)
                if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults:
                    kwonly_defaults[kwarg] = arg_spec.kwonlydefaults[kwarg]

        return KwargsWrapperSpec(
            arg_names=arg_names,
            arg_defaults=tuple(arg_defaults),
            kwonly_names=kwonly_names,
            kwonly_defaults=kwonly_defaults,
        )

    def get_rectified_args_from_original_args(self, full_args, full_kwargs):
        """
        This function is used to rectify the original arguments to the runtime
        arguments that matched the original args_spec.

        :param full_args: The original full arguments to filter.
        :param full_kwargs: The original full keyword arguments to filter.
        :return: The runtime positional values followed by the runtime
            keyword-only values, as one list.
        :raises DSLRuntimeError: If a required argument is missing or the
            result does not match the runtime signature.
        """
        arg_spec = self.original_args_spec
        defaults_start_idx = self._defaults_start_index(arg_spec)

        runtime_args = []
        # Filter arguments and maintain their properties
        for i, arg_name in enumerate(arg_spec.args):
            arg_type = arg_spec.annotations.get(arg_name, None)
            # Skip compile-time arguments
            if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
                continue
            # Check if argument was provided by user, otherwise use default
            if i < len(full_args):
                # User provided this argument - use it
                runtime_args.append(full_args[i])
            elif i >= defaults_start_idx:
                # Argument not provided, but has default - use default
                runtime_args.append(arg_spec.defaults[i - defaults_start_idx])
            else:
                # Required argument missing
                raise DSLRuntimeError(
                    f"Missing required argument '{arg_name}' at position {i}",
                    context={
                        "function_name": self.function_name,
                        "expected_args": len(arg_spec.args),
                        "provided_args": len(full_args),
                    },
                )

        # Filter keyword-only arguments
        runtime_kwargs = {}
        if arg_spec.kwonlyargs:
            for i, kwarg in enumerate(arg_spec.kwonlyargs):
                arg_type = arg_spec.annotations.get(kwarg, None)
                # Skip compile-time arguments
                if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name):
                    continue
                # Keep runtime keyword-only arguments
                if kwarg in full_kwargs:
                    runtime_kwargs[kwarg] = full_kwargs[kwarg]
                elif arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults:
                    runtime_kwargs[kwarg] = arg_spec.kwonlydefaults[kwarg]

        if len(runtime_args) != len(self.args_spec.args) or len(runtime_kwargs) != len(
            self.args_spec.kwonlyargs
        ):
            raise DSLRuntimeError(
                "input args/kwargs length does not match runtime function signature!",
                context={
                    "input args length": len(runtime_args),
                    "input kwargs length": len(runtime_kwargs),
                    "function signature args length": len(self.args_spec.args),
                    "function signature kwonlyargs length": len(
                        self.args_spec.kwonlyargs
                    ),
                },
            )
        return runtime_args + list(runtime_kwargs.values())

    def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec):
        """Return a copy of ``arg_spec`` with compile-time arguments removed.

        Names, annotations and defaults of the remaining arguments are kept;
        ``varargs`` and ``varkw`` are passed through unchanged.

        :param arg_spec: The full signature of the function.
        :return: An ``inspect.FullArgSpec`` holding only runtime arguments.
        """
        runtime_args = []
        runtime_annotations = {}
        runtime_defaults = []
        # Calculate the offset where defaults start in the original args
        defaults_start_idx = self._defaults_start_index(arg_spec)

        # Filter arguments and maintain their properties
        for i, arg_name in enumerate(arg_spec.args):
            arg_type = arg_spec.annotations.get(arg_name, None)
            # Skip compile-time arguments
            if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
                continue
            # Keep runtime arguments
            runtime_args.append(arg_name)
            if arg_name in arg_spec.annotations:
                runtime_annotations[arg_name] = arg_type
            # Keep corresponding default if it exists
            if i >= defaults_start_idx:
                runtime_defaults.append(arg_spec.defaults[i - defaults_start_idx])

        # Filter kwonlyargs and their defaults
        runtime_kwonlyargs = []
        runtime_kwonlydefaults = {}
        if arg_spec.kwonlyargs:
            for i, kwarg in enumerate(arg_spec.kwonlyargs):
                arg_type = arg_spec.annotations.get(kwarg, None)
                # Apply same filtering logic
                if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name):
                    continue
                runtime_kwonlyargs.append(kwarg)
                if kwarg in arg_spec.annotations:
                    runtime_annotations[kwarg] = arg_type
                if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults:
                    runtime_kwonlydefaults[kwarg] = arg_spec.kwonlydefaults[kwarg]

        # Convert runtime_defaults to tuple if not empty (as expected by FullArgSpec)
        runtime_defaults = tuple(runtime_defaults) if runtime_defaults else None
        return inspect.FullArgSpec(
            args=runtime_args,
            varargs=arg_spec.varargs,  # Keep original varargs
            varkw=arg_spec.varkw,  # Keep original varkw
            defaults=runtime_defaults,
            kwonlyargs=runtime_kwonlyargs,
            kwonlydefaults=runtime_kwonlydefaults if runtime_kwonlydefaults else None,
            annotations=runtime_annotations,
        )

    def get_constexpr_args(self) -> list[dict[str, Union[int, str]]]:
        """
        This function returns the constexpr args that have been pruned from the original function signature.
        The return type is a list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name).
        Keyword-only arguments are reported with ``argument_index`` set to ``None``.

        :return: list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name).
        :rtype: list[dict[str, Union[int, str]]]
        """
        if self.original_args_spec is None:
            return []

        constexpr_args = [
            {"argument_index": i, "argument_name": arg_name}
            for i, arg_name in enumerate(self.original_args_spec.args)
            if arg_name not in self.args_spec.args
        ]
        if self.original_args_spec.kwonlyargs:
            for kwarg in self.original_args_spec.kwonlyargs:
                if kwarg not in self.args_spec.kwonlyargs:
                    constexpr_args.append(
                        {"argument_index": None, "argument_name": kwarg}
                    )
        return constexpr_args
class JitExecuteContext:
    """Device-specific state needed to run a loaded module.

    Stores the owning module, the kernel function handles, a ``c_void_p`` for
    each handle (built from its ``getPtr()``), and the device context.
    """

    def __init__(
        self,
        module: "JitModule",
        kernel_fns=None,
        context: Optional[cuda_helpers.DevicePrimaryContext] = None,
    ):
        functions = [] if kernel_fns is None else kernel_fns
        self.module = module
        self.kernel_functions = functions
        # Raw pointers are wrapped once here so they can be reused per call.
        self.kernel_functions_ptrs = [
            ctypes.c_void_p(function.getPtr()) for function in functions
        ]
        self.context = context
class JitModule:
    """Holds the execution engine and cuda modules."""

    def __init__(
        self,
        engine,
        capi_func,
        args_spec: ExecutionArgs,
        modules: list[CudaModuleAndKernel],
    ):
        """
        :param engine: The execution engine backing ``capi_func``.
        :param capi_func: The callable entry point of the compiled function.
        :param args_spec: Argument marshalling helper for the function.
        :param modules: Loaded kernels; empty for host-only code.
        """
        self.engine = engine
        self.capi_func = capi_func
        self.args_spec = args_spec
        self.cuda_modules = modules
        self._unloaded = False

    def get_device_execute_context(self, device=None) -> JitExecuteContext:
        """Create the execution context for ``device``.

        :param device: A device object, a device ordinal, or ``None`` for the
            current device. Ignored for host-only modules.
        :return: A context holding the kernel functions for that device.
        :raises RuntimeError: If the module has already been unloaded.
        """
        if self._unloaded:
            raise RuntimeError("Can not get executor for unloaded module.")

        # Host only code no need to setup kernels
        if not self.cuda_modules:
            return JitExecuteContext(self)

        # We need a device at this point so get one if not provided.
        if device is None:
            device = cuda_helpers.get_current_device()
        elif isinstance(device, int):
            device = cuda_helpers.get_device(device)

        # Activate a primary context for the device:
        context = cuda_helpers.DevicePrimaryContext(device)

        # Get kernel functions from the kernels
        kernel_fns = []
        for m in self.cuda_modules:
            fn = cuda_helpers.get_function_from_kernel(m.kernel)
            kernel_fns.append(fn)
            # Set attributes for the kernel function
            for attr, val in m.attrs.items():
                cuda_helpers.set_kernel_attribute(fn, attr, val)

        # This instance will "own" a reference to the primary device context.
        # It will release the reference once it's no longer alive or
        # an explicit call to unload is made.
        #
        # The default module loading mode is CU_MODULE_LAZY_LOADING so
        # the module will not be loaded to the device until the first call
        # to execute it. This can be modified using CUDA_MODULE_LOADING
        # environment variable.
        return JitExecuteContext(self, kernel_fns, context)

    def unload(self):
        """Unload every CUDA library held by this module (best effort).

        Never raises: this also runs from ``__del__``. Each distinct library
        is unloaded independently, so one failure does not prevent the others
        from being released. The module is marked unloaded in all cases.
        """
        try:
            libraries = {m.cuda_module for m in self.cuda_modules}
            # Forget the records before unloading so that a later call (for
            # example from __del__) can never unload the same library twice.
            self.cuda_modules.clear()
            for library in libraries:
                try:
                    cuda_helpers.unload_library(library)
                except Exception as e:
                    log().debug(f"Failed to unload CUDA library: {e}")
        except Exception:
            # Deliberately ignored: the object may be partially initialised or
            # the interpreter may be shutting down.
            pass
        finally:
            self._unloaded = True

    def __del__(self):
        self.unload()
class JitExecutor:
    """An executable function that can be called to launch a device kernel.

    JitExecutor is tied to a specific device context and should only be called
    in a context on that device.
    """

    def __init__(
        self,
        jit_module: Union[JitModule, "CudaDialectJitModule"],
        exec_context: Optional[JitExecuteContext],
        jit_time_profiling: bool,
    ):
        # JitExecutor will keep JitCompiledFunction alive so that the underlying
        # ExecutionEngine and module data is not discarded until runtime callables
        # are garbage collected.
        self.jit_module = jit_module
        self.exec_context = exec_context
        self.profiler = timer(enable=jit_time_profiling) if jit_time_profiling else None
        self._has_profiler = self.profiler is not None

        # Get the cuda result type from the capi function.
        # This is only set to i32 if CudaDialectJitModule is used.
        result_type = self.jit_module.capi_func.restype
        self.cuda_result = None if result_type is None else result_type()
        self._has_cuda_result = self.cuda_result is not None
        if self._has_cuda_result:
            self._cuda_result_addr = ctypes.addressof(self.cuda_result)
        else:
            self._cuda_result_addr = None

        # With profiling enabled both the packing step and the call are wrapped.
        if jit_time_profiling:
            self._get_invoke_packed_args_func = self.profiler(
                self._get_invoke_packed_args
            )
            self.capi_func = self.profiler(self.jit_module.capi_func)
        else:
            self._get_invoke_packed_args_func = self._get_invoke_packed_args
            self.capi_func = self.jit_module.capi_func

        # Slots appended after the user arguments: the optional result
        # address, then one slot per kernel function pointer.
        if self.exec_context is None:
            self._kernel_ptrs = None
            kernel_ptr_count = 0
        else:
            self._kernel_ptrs = self.exec_context.kernel_functions_ptrs
            kernel_ptr_count = len(self._kernel_ptrs)
        self._num_extra_args = int(self._has_cuda_result) + kernel_ptr_count

        # Per-thread storage for the reusable packed-argument buffer.
        self._tls = threading.local()

    # Assume each execution args has type `c_void_p` to reduce the overhead of `ctypes.cast`.
    def _get_invoke_packed_args(self, exe_args):
        """Pack ``exe_args`` plus the extra slots into a ``c_void_p`` array.

        The array is cached per thread and only reallocated when more slots
        are needed than the cached one holds.

        :param exe_args: Execution arguments (``c_void_p`` or address values).
        :return: The packed ``ctypes`` array.
        """
        base_count = len(exe_args)
        needed = base_count + self._num_extra_args

        # Re-use packed args buffer if possible
        storage = self._tls
        buffer = getattr(storage, "packed_args", None)
        if buffer is None or getattr(storage, "capacity", 0) < needed:
            buffer = (ctypes.c_void_p * needed)()
            storage.packed_args = buffer
            storage.capacity = needed

        for position, value in enumerate(exe_args):
            if type(value) is ctypes.c_void_p:
                buffer[position] = value
            else:
                buffer[position] = ctypes.c_void_p(value).value

        cursor = base_count
        if self._has_cuda_result:
            buffer[cursor] = self._cuda_result_addr
            cursor += 1
        if self._kernel_ptrs is not None:
            for kernel_ptr in self._kernel_ptrs:
                buffer[cursor] = kernel_ptr
                cursor += 1
        return buffer

    def generate_execution_args(self, *args, **kwargs):
        """Marshal call arguments via the module's ``args_spec``."""
        return self.jit_module.args_spec.generate_execution_args(args, kwargs)

    def run_compiled_program(self, exe_args):
        """Invoke the compiled function with already-marshalled arguments.

        :param exe_args: Execution arguments from ``generate_execution_args``.
        :return: ``None`` when the function reports no CUDA result, else ``0``.
        :raises DSLCudaRuntimeError: If the reported CUDA result is non-zero.
        :raises DSLRuntimeError: If anything else fails during the call.
        """
        try:
            self.capi_func(self._get_invoke_packed_args_func(exe_args))
            if not self._has_cuda_result:
                return None
            error_code = self.cuda_result.value
            if error_code != 0:
                error_name = cuda_helpers._cudaGetErrorEnum(
                    cuda_helpers.cuda.CUresult(error_code)
                )
                raise DSLCudaRuntimeError(error_code, error_name)
            return error_code
        except DSLCudaRuntimeError:
            raise
        except Exception as e:
            raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e)

    def __call__(self, *args, **kwargs):
        # adapted_args stays bound to a local for the duration of the call.
        exe_args, adapted_args = self.generate_execution_args(*args, **kwargs)
        return self.run_compiled_program(exe_args)
@dataclass
class JitFunctionArtifacts:
    """Holds artifacts for a JIT-compiled function.

    Each field may be ``None``, the artifact content itself, or the path of an
    existing file. Paths are replaced by the file's content on construction:
    PTX and MLIR are read as text, CUBIN as bytes.
    """

    PTX: Optional[str]
    CUBIN: Optional[Union[str, bytes]]
    MLIR: Optional[str]

    def __post_init__(self):
        self.PTX = self._load_if_path(self.PTX, "PTX", "r")
        self.CUBIN = self._load_if_path(self.CUBIN, "CUBIN", "rb")
        self.MLIR = self._load_if_path(self.MLIR, "MLIR", "r")

    @staticmethod
    def _load_if_path(value, label, mode):
        """Return the content of ``value`` if it names an existing path, else ``value``.

        :param value: ``None``, artifact content, or a filesystem path.
        :param label: Artifact name used in the error message.
        :param mode: Mode passed to ``open`` (``"r"`` or ``"rb"``).
        :raises DSLRuntimeError: If the path exists but cannot be read.
        """
        if value is None or not os.path.exists(value):
            return value
        try:
            with open(value, mode) as f:
                return f.read()
        except OSError as e:
            # Chain the cause so the underlying OS error stays visible.
            raise DSLRuntimeError(f"Failed to read {label} file '{value}': {e}") from e
class ExportProvider:
    """DSL-specific settings used when exporting a JIT-compiled function.

    The class-level attributes default to ``None``; an instance overrides
    them with the values given to the constructor and additionally stores
    ``mlirExecutionEngine``.
    """

    dsl: "Type[BaseDSL]" = None
    arg_spec_processor: "ArgsSpecProcessor" = None
    c_header_generator: "CHeaderGenerator" = None
    object_file_version: str = None

    def __init__(
        self,
        dsl: "Type[BaseDSL]",
        arg_spec_processor: "ArgsSpecProcessor",
        c_header_generator: "CHeaderGenerator",
        object_file_version: str,
        mlirExecutionEngine: "MlirExecutionEngine",
    ):
        self.dsl, self.arg_spec_processor = dsl, arg_spec_processor
        self.c_header_generator, self.object_file_version = (
            c_header_generator,
            object_file_version,
        )
        self.mlirExecutionEngine = mlirExecutionEngine
class JitCompiledFunction:
"""Holds a compiled function."""
export_provider: ExportProvider = None
def __init__(
self,
ir_module,
engine,
capi_func,
args_spec,
function_name,
kernel_info,
jit_time_profiling,
jit_function_artifacts,
prefix=None,
load_from_binary=False,
dynamic_args=None,
dynamic_kwargs=None,
):
self.ir_module = ir_module
self.engine = engine
self.capi_func = capi_func
self.function_name = function_name
self.kernel_info = kernel_info
if args_spec is not None:
self.args_spec = ExecutionArgs(args_spec, self.function_name)
self.jit_time_profiling = jit_time_profiling
assert (
isinstance(jit_function_artifacts, JitFunctionArtifacts)
or jit_function_artifacts is None
)
self.artifacts = jit_function_artifacts
self.prefix = prefix
self.load_from_binary = load_from_binary
# This runtime state is stored here so that we can preserve the module
# in the compiler cache. Callers can extend the lifetime of the module
# by creating and retaining the executor.
self.jit_module = None
self._executor_lock = threading.RLock()
self._default_executor = None
# This is used to do early generation of the c header arguments to release the reference to the dynamic arguments.
self._generate_c_header_arguments(dynamic_args, dynamic_kwargs)
@property
def __ptx__(self):
"""Returns the PTX code of the JIT-compiled function."""
return self.artifacts.PTX if self.artifacts is not None else None
@property
def __cubin__(self):
"""Returns the CUBIN data of the JIT-compiled function."""
return self.artifacts.CUBIN if self.artifacts is not None else None
@property
def __mlir__(self):
"""Returns the MLIR code of the JIT-compiled function."""
return self.artifacts.MLIR if self.artifacts is not None else None
def _deserializer(self):
"""Load the cuda module from the binary execution engine. This function will be injected as the
JitCompiledFunction method which will be called by the jit executor to load the cuda module by AOT flow.
@param self: The JitCompiledFunction object. This is the JitCompiledFunction object to load the cuda module.
@param name: The name of the function. This is the unique identifier name of the function to avoid symbol conflict in the generated object file.
@param execution_engine: The binary execution engine. This is the execution engine to load the cuda module.
@param kernel_info: The kernel info. This is the kernel info to load the cuda module.
@return: The list of cuda modules.
"""
cubin_suffix = "cubin"
if self.prefix is None:
raise DSLRuntimeError("prefix is required to be set for binary loading")
cubin_data = self.engine.lookup("_".join([self.prefix, cubin_suffix]))
if not cubin_data:
raise RuntimeError(
"Unknown function " + "_".join([self.prefix, cubin_suffix])
)
cubin_module = cuda_helpers.load_library_data(cubin_data)
# load cuda module/get function pointer from module and cache
kernel_modules = collections.OrderedDict()
for sym, attrs in self.kernel_info.items():
kernel = cuda_helpers.get_library_kernel(cubin_module, sym)
if cuda_helpers.get_driver_version() >= 11080:
attrs[
cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED
] = 1
kernel_modules[sym] = CudaModuleAndKernel(sym, cubin_module, kernel, attrs)
return list(kernel_modules.values())
def _validate_engine(self):
if self.engine is None:
raise DSLRuntimeError(
"The compiled function does not have a valid execution engine.",
suggestion="For cross-compilation, please use `JitCompiledFunction.export_to_c` to serialize the compiled function and load/execute it on target device.",
)
def to(self, device=None) -> JitExecutor:
"""Returns an executable function bound to the given device.
For multi-device execution this method can be called for each device where
the kernel will run.
:param device: Specifies the device for the executor. If None the current device is used.
:type device: Optional[Union[int, CUdevice]]
:return: A callable executor function.
:rtype: JitExecutor
"""
self._validate_engine()
with self._executor_lock:
# We need to ensure that the modules are loaded if not already
if self.jit_module is None:
if self.ir_module is not None:
cuda_modules = load_kernels_from_ir_module(
self.ir_module, self.kernel_info
)
self.jit_module = JitModule(
self.engine, self.capi_func, self.args_spec, cuda_modules
)
# Create a new executor that will be tied to a device context
# n.b. host only moduels do not load device specific modules or context.
context = self.jit_module.get_device_execute_context(device)
return JitExecutor(self.jit_module, context, self.jit_time_profiling)
def generate_execution_args(self, *args, **kwargs):
return self.args_spec.generate_execution_args(args, kwargs)
def __call__(self, *args, **kwargs):
"""Executes the jit-compiled function under the currently active CUDA context.
Calling this method multiple devices is not allowed and will result in unexpected
CUDA errors. If you need to call the kernel on multiple devices use `to`
to return a per-device function.
"""
exe_args, adapted_args = self.args_spec.generate_execution_args(args, kwargs)
executor = self._default_executor
if executor is not None: # Only lock on first call
return executor.run_compiled_program(exe_args)
return self.run_compiled_program(exe_args)
def run_compiled_program(self, exe_args):
"""Executes the jit-compiled function under the currently active CUDA context.
Calling this method multiple devices is not allowed and will result in unexpected
CUDA errors. If you need to call the kernel on multiple devices use `to`
to return a per-device function.
"""
with self._executor_lock:
if self._default_executor is None:
log().debug("Creating default executor.")
# We use a weak reference here so that this instance does not keep this
# object alive as it hold a reference to self.
proxy_self = weakref.proxy(self)
self._default_executor = proxy_self.to(None)
return self._default_executor.run_compiled_program(exe_args)
def _generate_c_header_arguments(self, dynamic_args, dynamic_kwargs):
"""Generates the c header arguments for the AOT C header generation."""
self.c_header_arguments = None
from .export import CHeaderArguments
if dynamic_args is not None or dynamic_kwargs is not None:
self.dummy_prefix_name = "dummy_prefix_name"
try:
# This arguments may be generated failure due to not all the arguments (e.g. custom types) are supported by the AOT C header generator.
c_header_arguments, packed_args, declarations = (
self.export_provider.c_header_generator._generate_arguments(
self.dummy_prefix_name,
self.args_spec,
dynamic_args,
dynamic_kwargs,
)
)
self.c_header_arguments = CHeaderArguments(
self.dummy_prefix_name,
c_header_arguments,
packed_args,
declarations,
None,
)
except Exception as e:
self.c_header_arguments = CHeaderArguments(
self.dummy_prefix_name, [], [], [], str(e)
)
def dump_to_object(
    self,
    function_prefix: str,
) -> bytes:
    """Serialize the compiled ir function into an in-memory object file (ELF format).

    The returned bytes contain the host launch entry function together with
    the embedded cubin.

    @param function_prefix: The user provided unique identifier, used as a prefix of the generated symbols to avoid symbol conflict in the generated object file.
    @return: The content of the generated object file.
    @raise DSLRuntimeError: If generating or reading back the object file fails.
    """
    # lazy import to avoid circular dependency
    from .export import encode_metadata_into_ir_module, get_export_module

    assert self.export_provider is not None, (
        "Export provider is not set for JitCompiledFunction."
    )

    export_module = encode_metadata_into_ir_module(
        function_prefix,
        get_export_module(self.ir_module, function_prefix),
        self.args_spec.args_spec,
        self.function_name,
        self.kernel_info,
        self.export_provider.arg_spec_processor,
        self.export_provider.object_file_version,
    )

    cubin_data = None

    def strip_gpu_binary_op(op):
        # Walk callback: capture the payload of a `gpu.binary` op and erase
        # the op; every other op is left untouched.
        nonlocal cubin_data
        if op.name == "gpu.binary":
            buffer = io.BytesIO()
            op.operation.write_bytecode(buffer)
            serialized = buffer.getvalue()
            # Keep only the bytes found between `bin = "` and `">`.
            payload = serialized.split(b'bin = "')[1].split(b'">')[0]
            cubin_data = get_escaped_cubin_bytes(payload)
            op.erase()
        return ir.WalkResult.ADVANCE

    # Strip gpu related ops to avoid the object file generating builtin
    # module load/unload functions
    export_module.operation.walk(strip_gpu_binary_op)

    if cubin_data is not None:
        # Re-embed the cubin as a constant i8 array global named
        # `<function_prefix>_cubin`.
        cubin_array = array.array("b", cubin_data)
        with (
            export_module.context,
            ir.Location.unknown(),
            ir.InsertionPoint(export_module.body),
        ):
            llvm.GlobalOp(
                sym_name="_".join([function_prefix, "cubin"]),
                global_type=ir.Type.parse(f"!llvm.array<{len(cubin_array)} x i8>"),
                linkage=ir.Attribute.parse("#llvm.linkage<external>"),
                value=ir.DenseIntElementsAttr.get(cubin_array),
                constant=True,
            )

    module_attributes = export_module.operation.attributes
    if "gpu.container_module" in module_attributes:
        del module_attributes["gpu.container_module"]

    # Generate the object file
    try:
        with tempfile.NamedTemporaryFile() as tmp_object_file:
            object_path = tmp_object_file.name
            self.export_provider.mlirExecutionEngine.dump_object_file_pic(
                export_module,
                object_path,
                "_".join([function_prefix, self.function_name]),
            )
            # Read the result back through the path while the temporary
            # file still exists.
            with open(object_path, "rb") as object_file:
                return object_file.read()
    except Exception as e:
        raise DSLRuntimeError(f"Error dumping object file: {e}") from e
def export_to_c(
    self,
    file_path: str,
    file_name: str,
    function_prefix: str = "",
):
    """Export the jit-compiled function as C compatible files (header/library).

    This is used for c/cpp AOT support. Two files are written into the
    `file_path` directory: `<file_name>.h` and `<file_name>.o`. An existing
    file with the same name is always overwritten.

    The `function_prefix` is used as the symbol prefix of the generated
    functions; it is guaranteed by the caller that the generated functions
    are unique.

    The c header file is generated with following components:
    1. The host launch entry function. And the structure definitions of the arguments.
    2. The device metadata load/unload functions.
    3. The cubin data array and len.

    The library contains the binary of the underlying host launch entry function.

    @param file_path: The path to the directory where the header and object files will be saved.
    @param file_name: The name of the header and object files.
    @param function_prefix: The prefix of the function. This is the unique identifier name of the function to avoid symbol conflict in the generated object file. Default to the `file_name`.
    @raise DSLRuntimeError: If writing the header file or the object file fails.
    """
    if function_prefix in (None, ""):
        function_prefix = file_name

    assert self.export_provider is not None, (
        "Export provider is not set for JitCompiledFunction."
    )

    # lazy import to avoid circular dependency
    from .export import get_export_module

    # Generate the c header file
    header_text = self.export_provider.c_header_generator(
        function_prefix,
        get_export_module(self.ir_module, function_prefix),
        self.args_spec,
        self.function_name,
        self.kernel_info,
        self.c_header_arguments,
        self.export_provider.dsl._get_dsl().name,
    )
    try:
        with open(os.path.join(file_path, file_name + ".h"), "w") as header_file:
            header_file.write(header_text)
    except Exception as e:
        raise DSLRuntimeError(f"Error writing header file: {e}") from e

    # Generate the object file
    object_bytes = self.dump_to_object(function_prefix)
    try:
        with open(os.path.join(file_path, file_name + ".o"), "wb") as object_file:
            object_file.write(object_bytes)
    except Exception as e:
        raise DSLRuntimeError(f"Error writing object file: {e}") from e