mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
v4.3.1 update. (#2817)
This commit is contained in:
@@ -9,6 +9,10 @@
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from ._mlir._mlir_libs import _cutlass_ir
|
||||
|
||||
_cutlass_ir.populate(_cutlass_ir)
|
||||
|
||||
from .cutlass_dsl import (
|
||||
Constexpr,
|
||||
dsl_user_op,
|
||||
|
||||
@@ -32,6 +32,9 @@ class Arch(Enum):
|
||||
sm_101 = (10, 1, "")
|
||||
sm_101a = (10, 1, "a")
|
||||
sm_101f = (10, 1, "f")
|
||||
sm_103 = (10, 3, "")
|
||||
sm_103a = (10, 3, "a")
|
||||
sm_103f = (10, 3, "f")
|
||||
sm_110 = (11, 0, "")
|
||||
sm_110a = (11, 0, "a")
|
||||
sm_110f = (11, 0, "f")
|
||||
@@ -80,6 +83,9 @@ class Arch(Enum):
|
||||
Arch.sm_101,
|
||||
Arch.sm_101a,
|
||||
Arch.sm_101f,
|
||||
Arch.sm_103,
|
||||
Arch.sm_103a,
|
||||
Arch.sm_103f,
|
||||
Arch.sm_110,
|
||||
Arch.sm_110a,
|
||||
Arch.sm_110f,
|
||||
|
||||
@@ -529,6 +529,7 @@ def _get_self_module():
|
||||
return inspect.getmodule(_get_self_module)
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def cf_symbol_check(symbol):
|
||||
"""
|
||||
Check if the symbol is control flow symbol from current module.
|
||||
|
||||
@@ -1344,6 +1344,25 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
),
|
||||
node,
|
||||
)
|
||||
elif func.id == "super" and node.args == [] and node.keywords == []:
|
||||
# If it's a Python3 argument free super(), rewrite to old style super with args
|
||||
# So if this call is under dynamic control flow, it still works.
|
||||
return ast.copy_location(
|
||||
ast.Call(
|
||||
func=func,
|
||||
args=node.args
|
||||
+ [
|
||||
ast.Attribute(
|
||||
value=ast.Name(id="self", ctx=ast.Load()),
|
||||
attr="__class__",
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
ast.Name(id="self", ctx=ast.Load()),
|
||||
],
|
||||
keywords=node.keywords,
|
||||
),
|
||||
node,
|
||||
)
|
||||
elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
|
||||
|
||||
def create_downcast_call(arg):
|
||||
@@ -1853,21 +1872,41 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
# Handle elif case
|
||||
elif_node = node.orelse[0]
|
||||
nested_if_name = elif_region_name
|
||||
# Recursion for nested elif
|
||||
nested_if = self.create_if_function(
|
||||
nested_if_name, elif_node, write_args, full_write_args_count
|
||||
)
|
||||
else_block = ast.FunctionDef(
|
||||
name=else_block_name,
|
||||
args=func_then_else_arguments,
|
||||
body=[
|
||||
nested_if,
|
||||
ast.Return(
|
||||
value=ast.Name(id=nested_if_name, ctx=ast.Load())
|
||||
),
|
||||
],
|
||||
decorator_list=[],
|
||||
)
|
||||
# AST cannot distinguish between the following two cases:
|
||||
# elif pred:
|
||||
# and
|
||||
# else:
|
||||
# if pred:
|
||||
# And under both cases, the `pred` can be a const_expr, so we need to handle it here.
|
||||
if self.is_node_constexpr(elif_node):
|
||||
self.generic_visit(elif_node)
|
||||
check = self._insert_cf_symbol_check(elif_node.test.func)
|
||||
else_block = ast.FunctionDef(
|
||||
name=else_block_name,
|
||||
args=func_then_else_arguments,
|
||||
body=[
|
||||
check,
|
||||
elif_node,
|
||||
ast.Return(value=return_list),
|
||||
],
|
||||
decorator_list=[],
|
||||
)
|
||||
else:
|
||||
# Recursion for nested elif
|
||||
nested_if = self.create_if_function(
|
||||
nested_if_name, elif_node, write_args, full_write_args_count
|
||||
)
|
||||
else_block = ast.FunctionDef(
|
||||
name=else_block_name,
|
||||
args=func_then_else_arguments,
|
||||
body=[
|
||||
nested_if,
|
||||
ast.Return(
|
||||
value=ast.Name(id=nested_if_name, ctx=ast.Load())
|
||||
),
|
||||
],
|
||||
decorator_list=[],
|
||||
)
|
||||
else:
|
||||
else_body = []
|
||||
for stmt in node.orelse:
|
||||
|
||||
@@ -356,7 +356,11 @@ class GPUArch(StringCompileOption):
|
||||
|
||||
|
||||
class EnableTVMFFI(EmptyCompileOption):
|
||||
option_name = "enable-tvm-ffi"
|
||||
pass
|
||||
|
||||
|
||||
class DumpDir(EmptyCompileOption):
|
||||
option_name = "dump-dir"
|
||||
|
||||
|
||||
class CompileOptions:
|
||||
@@ -380,6 +384,7 @@ class CompileOptions:
|
||||
GPUArch: GPUArch(""),
|
||||
LinkLibraries: LinkLibraries(""),
|
||||
EnableTVMFFI: EnableTVMFFI(False),
|
||||
DumpDir: DumpDir(""),
|
||||
}
|
||||
|
||||
if options is not None:
|
||||
@@ -416,19 +421,24 @@ class CompileOptions:
|
||||
if self.options[GPUArch].value == ""
|
||||
else self.options[GPUArch].value
|
||||
)
|
||||
dump_dir = (
|
||||
envar.dump_dir
|
||||
if self.options[DumpDir].value == ""
|
||||
else self.options[DumpDir].value
|
||||
)
|
||||
if self.options[KeepPTX].value:
|
||||
self.options[KeepPTX].dump_path = os.path.join(
|
||||
envar.dump_dir, f"{function_name}"
|
||||
dump_dir, f"{function_name}"
|
||||
)
|
||||
self.options[KeepPTX].full_ptx_path = os.path.join(
|
||||
envar.dump_dir, f"{function_name}.{arch}.ptx"
|
||||
dump_dir, f"{function_name}.{arch}.ptx"
|
||||
)
|
||||
if self.options[KeepCUBIN].value:
|
||||
self.options[KeepCUBIN].dump_path = os.path.join(
|
||||
envar.dump_dir, f"{function_name}"
|
||||
dump_dir, f"{function_name}"
|
||||
)
|
||||
self.options[KeepCUBIN].full_cubin_path = os.path.join(
|
||||
envar.dump_dir, f"{function_name}.{arch}.cubin"
|
||||
dump_dir, f"{function_name}.{arch}.cubin"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -504,6 +514,7 @@ def _parse_compile_options_from_str(options: str) -> CompileOptions:
|
||||
"keep_ptx": KeepPTX,
|
||||
"gpu_arch": GPUArch,
|
||||
"enable_tvm_ffi": EnableTVMFFI,
|
||||
"dump_dir": DumpDir,
|
||||
}
|
||||
return mapping[option_str]
|
||||
|
||||
@@ -520,6 +531,7 @@ def _parse_compile_options_from_str(options: str) -> CompileOptions:
|
||||
parser.add_argument("--ptxas-options", type=str, default="")
|
||||
parser.add_argument("--gpu-arch", type=str, default="")
|
||||
parser.add_argument("--enable-tvm-ffi", action="store_true", default=False)
|
||||
parser.add_argument("--dump-dir", type=str, default="")
|
||||
compile_options = CompileOptions()
|
||||
try:
|
||||
# Use shlex to properly handle options with spaces
|
||||
@@ -545,7 +557,7 @@ class CompileCallable:
|
||||
def __init__(self, options=None):
|
||||
def preprocess_options(option):
|
||||
if type(option) is type and issubclass(
|
||||
option, (BooleanCompileOption, BooleanBasedFileDumpOption)
|
||||
option, (BooleanCompileOption, BooleanBasedFileDumpOption, EnableTVMFFI)
|
||||
):
|
||||
# Automatically creates a True instance of the option
|
||||
return option(True)
|
||||
|
||||
@@ -53,7 +53,7 @@ from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegist
|
||||
|
||||
from .ast_preprocessor import DSLPreprocessor
|
||||
from .common import *
|
||||
from .typing import get_c_pointers, get_mlir_types, Integer, arg_compatible_with_tvm_ffi
|
||||
from .typing import get_c_pointers, get_mlir_types, Integer
|
||||
from .arch import Arch
|
||||
|
||||
# =============================================================================
|
||||
@@ -810,10 +810,6 @@ class BaseDSL:
|
||||
|
||||
if is_host:
|
||||
if self.envar.enable_tvm_ffi:
|
||||
if not arg_compatible_with_tvm_ffi(arg):
|
||||
raise DSLRuntimeError(
|
||||
f"Argument #{i + 1} ({arg_name}) is not a TVM FFI argument."
|
||||
)
|
||||
jit_exec_arg.extend([arg])
|
||||
else:
|
||||
jit_exec_arg.extend(get_c_pointers(arg))
|
||||
@@ -1218,6 +1214,8 @@ class BaseDSL:
|
||||
no_cache,
|
||||
func_type=JitCompiledFunction,
|
||||
*,
|
||||
full_args=None,
|
||||
full_kwargs=None,
|
||||
dynamic_args=None,
|
||||
dynamic_kwargs=None,
|
||||
original_function_name=None,
|
||||
@@ -1388,6 +1386,8 @@ class BaseDSL:
|
||||
pipeline,
|
||||
args_spec,
|
||||
no_cache,
|
||||
full_args=args,
|
||||
full_kwargs=kwargs,
|
||||
dynamic_args=dynamic_args,
|
||||
dynamic_kwargs=dynamic_kwargs,
|
||||
original_function_name=original_function_name,
|
||||
@@ -1400,7 +1400,7 @@ class BaseDSL:
|
||||
module_hash,
|
||||
)
|
||||
jit_function = self.jit_cache[module_hash]
|
||||
|
||||
|
||||
finally:
|
||||
self.post_compilation_cleanup()
|
||||
|
||||
|
||||
@@ -235,27 +235,20 @@ def get_prefix_dsl_libs(prefix: str):
|
||||
return prefix_libs_existing
|
||||
|
||||
def get_libs_cand(start):
|
||||
target_libs = {
|
||||
"mlir_c_runner_utils",
|
||||
"mlir_runner_utils",
|
||||
"mlir_cuda_runtime",
|
||||
target_cuda_dialect_libs = {
|
||||
"cuda_dialect_runtime",
|
||||
}
|
||||
lib_folder_guesses = [
|
||||
"lib",
|
||||
]
|
||||
|
||||
libs_cand = find_libs_in_ancestors(start, target_libs, lib_folder_guesses)
|
||||
|
||||
optional_libs_cand = find_libs_in_ancestors(
|
||||
start, {"cuda_dialect_runtime"}, lib_folder_guesses
|
||||
)
|
||||
|
||||
if libs_cand:
|
||||
dsl_libs = ":".join(libs_cand)
|
||||
if optional_libs_cand:
|
||||
dsl_libs += ":" + ":".join(optional_libs_cand)
|
||||
return dsl_libs
|
||||
|
||||
for target_libs in [
|
||||
target_cuda_dialect_libs,
|
||||
]:
|
||||
libs_cand = find_libs_in_ancestors(start, target_libs, lib_folder_guesses)
|
||||
if libs_cand:
|
||||
dsl_libs = ":".join(libs_cand)
|
||||
return dsl_libs
|
||||
return None
|
||||
|
||||
# find from install folder
|
||||
|
||||
@@ -216,6 +216,69 @@ class ExecutionArgs:
|
||||
|
||||
return exe_args, adapted_args
|
||||
|
||||
def get_rectified_args_from_original_args(self, full_args, full_kwargs):
|
||||
"""
|
||||
This function is used to rectify the original arguments to the runtime
|
||||
arguments that matched the original args_spec.
|
||||
|
||||
:param full_args: The original full arguments to filter.
|
||||
:param full_kwargs: The original full keyword arguments to filter.
|
||||
:return: The filtered arguments and keyword arguments.
|
||||
"""
|
||||
arg_spec = self.original_args_spec
|
||||
|
||||
if arg_spec.defaults:
|
||||
defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults)
|
||||
else:
|
||||
defaults_start_idx = len(arg_spec.args)
|
||||
|
||||
runtime_args = []
|
||||
# Filter arguments and maintain their properties
|
||||
for i, arg_name in enumerate(arg_spec.args):
|
||||
arg_type = arg_spec.annotations.get(arg_name, None)
|
||||
|
||||
# Skip compile-time arguments
|
||||
if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name):
|
||||
continue
|
||||
|
||||
# Keep corresponding default if it exists
|
||||
if i >= defaults_start_idx:
|
||||
default_idx = i - defaults_start_idx
|
||||
runtime_args.append(arg_spec.defaults[default_idx])
|
||||
else:
|
||||
runtime_args.append(full_args[i])
|
||||
|
||||
# Filter keyword-only arguments
|
||||
runtime_kwargs = {}
|
||||
if arg_spec.kwonlyargs:
|
||||
for i, kwarg in enumerate(arg_spec.kwonlyargs):
|
||||
arg_type = arg_spec.annotations.get(kwarg, None)
|
||||
|
||||
# Skip compile-time arguments
|
||||
if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name):
|
||||
continue
|
||||
|
||||
# Keep runtime keyword-only arguments
|
||||
if kwarg in full_kwargs:
|
||||
runtime_kwargs[kwarg] = full_kwargs[kwarg]
|
||||
elif arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults:
|
||||
runtime_kwargs[kwarg] = arg_spec.kwonlydefaults[kwarg]
|
||||
|
||||
if (len(runtime_args) != len(self.args_spec.args) or
|
||||
len(runtime_kwargs) != len(self.args_spec.kwonlyargs)):
|
||||
raise DSLRuntimeError(
|
||||
"input args/kwargs length does not match runtime function signature!",
|
||||
context={
|
||||
"input args length": len(runtime_args),
|
||||
"input kwargs length": len(runtime_kwargs),
|
||||
"function signature args length": len(self.args_spec.args),
|
||||
"function signature kwonlyargs length": len(self.args_spec.kwonlyargs),
|
||||
},
|
||||
)
|
||||
|
||||
return runtime_args + list(runtime_kwargs.values())
|
||||
|
||||
|
||||
def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec):
|
||||
runtime_args = []
|
||||
runtime_annotations = {}
|
||||
@@ -562,6 +625,13 @@ class JitCompiledFunction:
|
||||
kernel_modules[sym] = CudaModuleAndKernel(sym, cubin_module, kernel, attrs)
|
||||
return list(kernel_modules.values())
|
||||
|
||||
def _validate_engine(self):
|
||||
if self.engine is None:
|
||||
raise DSLRuntimeError(
|
||||
"The compiled function does not have a valid execution engine.",
|
||||
suggestion="For cross-compilation, please use `cute.export.export_to_c` to serialize the compiled function and load/execute it on target device.",
|
||||
)
|
||||
|
||||
def to(self, device=None) -> JitExecutor:
|
||||
"""Returns an executable function bound to the given device.
|
||||
|
||||
@@ -573,6 +643,7 @@ class JitCompiledFunction:
|
||||
:return: A callable executor function.
|
||||
:rtype: JitExecutor
|
||||
"""
|
||||
self._validate_engine()
|
||||
with self._executor_lock:
|
||||
# We need to ensure that the modules are loaded if not already
|
||||
if self.jit_module is None:
|
||||
@@ -621,4 +692,4 @@ class JitCompiledFunction:
|
||||
# object alive as it hold a reference to self.
|
||||
proxy_self = weakref.proxy(self)
|
||||
self._default_executor = proxy_self.to(None)
|
||||
self._default_executor.run_compiled_program(exe_args)
|
||||
return self._default_executor.run_compiled_program(exe_args)
|
||||
|
||||
@@ -20,6 +20,18 @@ from ..._mlir.dialects import llvm
|
||||
from .tvm_ffi_builder import CallContext, CallProvider, TVMFFIBuilder
|
||||
|
||||
|
||||
def _flatten_tuple_params(params: list[spec.Param]) -> list[spec.Param]:
|
||||
"""Recursively flatten TupleParam into list of params."""
|
||||
flattened = []
|
||||
for param in params:
|
||||
if isinstance(param, spec.TupleParam):
|
||||
# Recursively flatten nested tuples
|
||||
flattened.extend(_flatten_tuple_params(param.params))
|
||||
else:
|
||||
flattened.append(param)
|
||||
return flattened
|
||||
|
||||
|
||||
class NopCallProvider(CallProvider):
|
||||
"""No-op call provider for testing purposes."""
|
||||
|
||||
@@ -53,7 +65,11 @@ class DynamicParamPackCallProvider(CallProvider, TVMFFIBuilder):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, target_func: str, include_num_args: bool = False, struct_call: bool = False,
|
||||
self,
|
||||
target_func: str,
|
||||
include_num_args: bool = False,
|
||||
struct_call: bool = False,
|
||||
flatten_tuple_params: bool = True,
|
||||
) -> None:
|
||||
import tvm_ffi
|
||||
|
||||
@@ -61,8 +77,12 @@ class DynamicParamPackCallProvider(CallProvider, TVMFFIBuilder):
|
||||
self.target_func = target_func
|
||||
self.include_num_args = include_num_args
|
||||
self.struct_call = struct_call
|
||||
self.flatten_tuple_params = flatten_tuple_params
|
||||
self.float4x2_dtype = tvm_ffi.dtype("float4_e2m1fnx2")
|
||||
|
||||
if not self.flatten_tuple_params:
|
||||
raise RuntimeError("flatten_tuple_params=False is not supported yet")
|
||||
|
||||
def get_callee_struct_for_param_tensor(
|
||||
self,
|
||||
param: spec.Tensor,
|
||||
@@ -157,8 +177,15 @@ class DynamicParamPackCallProvider(CallProvider, TVMFFIBuilder):
|
||||
self, current_block: ir.Block, context: CallContext
|
||||
) -> list[tuple[ir.Type, ir.Value]]:
|
||||
"""Pack a parameter to a struct."""
|
||||
# Flatten TupleParam into list of params if enabled
|
||||
if self.flatten_tuple_params:
|
||||
flattened_params = _flatten_tuple_params(context.params)
|
||||
else:
|
||||
flattened_params = context.params
|
||||
|
||||
# Pack each parameter
|
||||
packed_params = []
|
||||
for param in context.params:
|
||||
for param in flattened_params:
|
||||
if isinstance(param, spec.Tensor):
|
||||
packed_params.append(
|
||||
self.pack_param_tensor(current_block, context, param)
|
||||
@@ -177,6 +204,9 @@ class DynamicParamPackCallProvider(CallProvider, TVMFFIBuilder):
|
||||
packed_params.append(
|
||||
self.pack_param_var(current_block, context, param.var)
|
||||
)
|
||||
elif isinstance(param, spec.ConstNone):
|
||||
# const none is not packed
|
||||
continue
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported parameter type: {type(param)}")
|
||||
return packed_params
|
||||
|
||||
@@ -302,7 +302,10 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
cond: ir.Value,
|
||||
true_block: ir.Block,
|
||||
false_block: ir.Block,
|
||||
branch_weights=None
|
||||
*,
|
||||
branch_weights=None,
|
||||
true_dest_operands: Sequence[ir.Value] = (),
|
||||
false_dest_operands: Sequence[ir.Value] = (),
|
||||
) -> None:
|
||||
"""Create a conditional branch.
|
||||
|
||||
@@ -318,6 +321,10 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
Optional branch weights [true_weight, false_weight] for optimization hints.
|
||||
Higher values indicate higher probability. For example, (99, 1) indicates
|
||||
the true branch is much more likely than the false branch.
|
||||
true_dest_operands : Sequence[ir.Value]
|
||||
Operands to pass to the true destination block.
|
||||
false_dest_operands : Sequence[ir.Value]
|
||||
Operands to pass to the false destination block.
|
||||
"""
|
||||
if branch_weights is not None:
|
||||
# Branch weights should be a tuple/list of two integers [true_weight, false_weight]
|
||||
@@ -325,8 +332,8 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
raise ValueError("branch_weights must have exactly 2 elements")
|
||||
llvm.cond_br(
|
||||
cond,
|
||||
true_dest_operands=[],
|
||||
false_dest_operands=[],
|
||||
true_dest_operands=true_dest_operands,
|
||||
false_dest_operands=false_dest_operands,
|
||||
true_dest=true_block,
|
||||
false_dest=false_block,
|
||||
branch_weights=ir.DenseI32ArrayAttr.get(list(branch_weights))
|
||||
@@ -334,8 +341,8 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
else:
|
||||
llvm.cond_br(
|
||||
cond,
|
||||
true_dest_operands=[],
|
||||
false_dest_operands=[],
|
||||
true_dest_operands=true_dest_operands,
|
||||
false_dest_operands=false_dest_operands,
|
||||
true_dest=true_block,
|
||||
false_dest=false_block,
|
||||
)
|
||||
@@ -401,6 +408,17 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
)
|
||||
func_op.attributes["llvm.linkage"] = ir.StringAttr.get("external")
|
||||
|
||||
def create_alloca(self, entry_block: ir.Block, alloca_type: ir.Type, array_size: int) -> ir.Value:
|
||||
"""Create an alloca operation."""
|
||||
with ir.InsertionPoint(entry_block.operations[0]):
|
||||
# declare the struct type
|
||||
alloca = llvm.alloca(
|
||||
res=self.ptr_type,
|
||||
elem_type=alloca_type,
|
||||
array_size=self.i32(array_size),
|
||||
)
|
||||
return alloca
|
||||
|
||||
def pack_values_to_alloca(
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
@@ -423,14 +441,11 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
tuple[ir.Type, ir.Value]
|
||||
The struct type and the alloca.
|
||||
"""
|
||||
with ir.InsertionPoint(entry_block.operations[0]):
|
||||
# declare the struct type
|
||||
struct_type = self.struct_type(fields=[value.type for value in values])
|
||||
alloca = llvm.alloca(
|
||||
res=self.ptr_type,
|
||||
elem_type=struct_type,
|
||||
array_size=self.i32(1),
|
||||
)
|
||||
# Declare the struct type from the values
|
||||
struct_type = self.struct_type(fields=[value.type for value in values])
|
||||
|
||||
# Create alloca using the helper method
|
||||
alloca = self.create_alloca(entry_block, struct_type, array_size=1)
|
||||
|
||||
with ir.InsertionPoint(current_block):
|
||||
for index, value in enumerate(values):
|
||||
|
||||
@@ -136,7 +136,10 @@ class Shape(Param):
|
||||
name: str
|
||||
shape: list[Union[int, Var]]
|
||||
|
||||
def __init__(self, name: str, shape: list[Union[int, Var]]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str, shape: list[Union[int, Var]],
|
||||
) -> None:
|
||||
"""Initialize a Shape parameter.
|
||||
|
||||
Parameters
|
||||
@@ -146,6 +149,9 @@ class Shape(Param):
|
||||
shape : list[int | Var]
|
||||
The shape of the parameter.
|
||||
|
||||
unpack_shape: bool
|
||||
Whether to unpack the shape into list of arguments when calling
|
||||
the call provider function.
|
||||
"""
|
||||
self.name = name
|
||||
self.shape = shape
|
||||
@@ -301,6 +307,106 @@ class DataPointer(Param):
|
||||
self.address_space = address_space
|
||||
|
||||
|
||||
class ConstNone(Param):
|
||||
"""ConstNone parameter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The parameter name.
|
||||
"""
|
||||
name: str
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialize a ConstExpr parameter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The parameter name.
|
||||
"""
|
||||
self.name = name
|
||||
|
||||
|
||||
class TupleParam(Param):
|
||||
"""Tuple parameter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The parameter name.
|
||||
"""
|
||||
name: str
|
||||
params: list[Param]
|
||||
|
||||
def __init__(self, name: str, params: list[Param]) -> None:
|
||||
"""Initialize a TupleParam parameter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The parameter name.
|
||||
params : list[Param]
|
||||
The parameters of the tuple.
|
||||
"""
|
||||
self.name = name
|
||||
self.params = params
|
||||
|
||||
|
||||
def format_param_type(param: Param) -> str:
|
||||
"""Format a parameter type as a string, recursively handling nested types.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
param : Param
|
||||
The parameter to format.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The formatted type string.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
If an unsupported parameter type is encountered.
|
||||
"""
|
||||
if isinstance(param, Var):
|
||||
return str(param.dtype)
|
||||
elif isinstance(param, Tensor):
|
||||
# Format tensor shape
|
||||
shape_strs = []
|
||||
for dim in param.shape:
|
||||
if isinstance(dim, Var):
|
||||
shape_strs.append(dim.name)
|
||||
else:
|
||||
shape_strs.append(str(dim))
|
||||
shape_str = "[" + ", ".join(shape_strs) + "]"
|
||||
return f"Tensor({shape_str}, {param.dtype})"
|
||||
elif isinstance(param, Shape):
|
||||
# Format shape parameter
|
||||
shape_strs = []
|
||||
for dim in param.shape:
|
||||
if isinstance(dim, Var):
|
||||
shape_strs.append(dim.name)
|
||||
else:
|
||||
shape_strs.append(str(dim))
|
||||
shape_str = "[" + ", ".join(shape_strs) + "]"
|
||||
return f"Shape({shape_str})"
|
||||
elif isinstance(param, Stream):
|
||||
return "Stream"
|
||||
elif isinstance(param, DataPointer):
|
||||
return "DataPointer"
|
||||
elif isinstance(param, ConstNone):
|
||||
return "None"
|
||||
elif isinstance(param, TupleParam):
|
||||
# Recursively format tuple elements
|
||||
element_types = [format_param_type(p) for p in param.params]
|
||||
return f"Tuple[{', '.join(element_types)}]"
|
||||
else:
|
||||
raise TypeError(f"Unsupported parameter type: {type(param)}")
|
||||
|
||||
|
||||
def signature(name: str, params: list[Param]) -> str:
|
||||
"""Generate a function signature string from name and parameters.
|
||||
|
||||
@@ -325,39 +431,13 @@ def signature(name: str, params: list[Param]) -> str:
|
||||
param_strs = []
|
||||
|
||||
for param in params:
|
||||
if isinstance(param, Var):
|
||||
param_str = f"{param.name}: {param.dtype}"
|
||||
elif isinstance(param, Tensor):
|
||||
# Format tensor shape
|
||||
shape_strs = []
|
||||
for dim in param.shape:
|
||||
if isinstance(dim, Var):
|
||||
shape_strs.append(dim.name)
|
||||
else:
|
||||
shape_strs.append(str(dim))
|
||||
shape_str = "[" + ", ".join(shape_strs) + "]"
|
||||
param_str = f"{param.name}: Tensor({shape_str}, {param.dtype})"
|
||||
elif isinstance(param, Shape):
|
||||
# Format shape parameter
|
||||
shape_strs = []
|
||||
for dim in param.shape:
|
||||
if isinstance(dim, Var):
|
||||
shape_strs.append(dim.name)
|
||||
else:
|
||||
shape_strs.append(str(dim))
|
||||
shape_str = "[" + ", ".join(shape_strs) + "]"
|
||||
param_str = f"{param.name}: Shape({shape_str})"
|
||||
elif isinstance(param, Stream):
|
||||
param_str = f"{param.name}: Stream"
|
||||
elif isinstance(param, DataPointer):
|
||||
param_str = f"{param.name}: DataPointer"
|
||||
elif isinstance(param, EnvStream):
|
||||
if isinstance(param, EnvStream):
|
||||
# env stream is not part of the FFI function signature
|
||||
# continue to skip append
|
||||
continue
|
||||
else:
|
||||
raise TypeError(f"Unsupported parameter type: {type(param)}")
|
||||
|
||||
param_type = format_param_type(param)
|
||||
param_str = f"{param.name}: {param_type}"
|
||||
param_strs.append(param_str)
|
||||
|
||||
return f"{name}({', '.join(param_strs)})"
|
||||
|
||||
@@ -27,6 +27,64 @@ from .mlir_builder import MLIRBuilder
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArgContext:
|
||||
"""Context information for parameter decoding error messages.
|
||||
|
||||
:ivar param_name: The name of the parameter.
|
||||
:vartype param_name: str
|
||||
:ivar arg_index: The index of the argument in the function call.
|
||||
:vartype arg_index: int
|
||||
:ivar tuple_indices: List of tuple indices for nested tuple access (e.g., [0, 1] for tuple[0][1]).
|
||||
:vartype tuple_indices: list[int]
|
||||
"""
|
||||
param_name: str
|
||||
arg_index: int
|
||||
tuple_indices: list[int]
|
||||
|
||||
def get(self) -> list[str]:
|
||||
"""Get the context as a list of strings for error messages.
|
||||
|
||||
:returns: Context strings like ["on argument ", "#0"] or ["on my_tuple[0][1] in argument ", "#0"].
|
||||
:rtype: list[str]
|
||||
"""
|
||||
if not self.tuple_indices:
|
||||
# Top-level argument: "on argument #0"
|
||||
return ["on argument ", f"#{self.arg_index}"]
|
||||
else:
|
||||
# Nested tuple element: "on my_tuple[0][1] in argument #0"
|
||||
indices_str = "".join(f"[{i}]" for i in self.tuple_indices)
|
||||
return [f"on {self.param_name}{indices_str} in argument ", f"#{self.arg_index}"]
|
||||
|
||||
def get_field_name(self, field_suffix: str) -> str:
|
||||
"""Get the field name with tuple indices for shape/stride access.
|
||||
|
||||
:param field_suffix: The field suffix (e.g., ".shape", ".strides").
|
||||
:type field_suffix: str
|
||||
:returns: Field name like "my_param.shape" or "my_tuple[0][1].shape".
|
||||
:rtype: str
|
||||
"""
|
||||
if not self.tuple_indices:
|
||||
return f"{self.param_name}{field_suffix}"
|
||||
else:
|
||||
indices_str = "".join(f"[{i}]" for i in self.tuple_indices)
|
||||
return f"{self.param_name}{indices_str}{field_suffix}"
|
||||
|
||||
def get_element_context(self, element_index: int) -> "ArgContext":
|
||||
"""Create a nested context for a tuple element.
|
||||
|
||||
:param element_index: The index within the tuple.
|
||||
:type element_index: int
|
||||
:returns: New context for the nested element.
|
||||
:rtype: ArgContext
|
||||
"""
|
||||
return ArgContext(
|
||||
param_name=self.param_name,
|
||||
arg_index=self.arg_index,
|
||||
tuple_indices=self.tuple_indices + [element_index],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallContext:
|
||||
"""Call context that contains the information of the call."""
|
||||
@@ -119,7 +177,7 @@ class TVMFFIBuilder(MLIRBuilder):
|
||||
super().__init__()
|
||||
# this is a number we can tune to minimize the register size
|
||||
# it is 6 by default to minimize the register size
|
||||
self.set_raised_from_cstr_parts_max_num_parts = 6
|
||||
self.set_raised_from_cstr_parts_max_num_parts = 8
|
||||
self.set_raised_from_cstr_parts_cache: dict[int, str] = {}
|
||||
self.tvm_ffi_any_type = self.struct_type(
|
||||
name="TVMFFIAny",
|
||||
@@ -488,26 +546,28 @@ class TVMFFIBuilder(MLIRBuilder):
|
||||
) -> ir.Value:
|
||||
"""Downcast i64 to lower bits."""
|
||||
overflow_flags = llvm.IntegerOverflowFlags.none
|
||||
if (hasattr(tvm_ffi._dtype.DataTypeCode, "BOOL") and
|
||||
target_dtype.type_code == tvm_ffi._dtype.DataTypeCode.BOOL):
|
||||
# LLVM use i1 (boolean) for boolean
|
||||
return llvm.icmp(llvm.ICmpPredicate.ne, v_int64, self.i64(0))
|
||||
if target_dtype.bits == 64:
|
||||
result = v_int64
|
||||
elif target_dtype.bits == 32:
|
||||
result = llvm.trunc(
|
||||
return v_int64
|
||||
if target_dtype.bits == 32:
|
||||
return llvm.trunc(
|
||||
res=self.i32_type, arg=v_int64, overflow_flags=overflow_flags
|
||||
)
|
||||
elif target_dtype.bits == 16:
|
||||
result = llvm.trunc(
|
||||
if target_dtype.bits == 16:
|
||||
return llvm.trunc(
|
||||
res=self.i16_type, arg=v_int64, overflow_flags=overflow_flags
|
||||
)
|
||||
elif target_dtype.bits == 8:
|
||||
result = llvm.trunc(
|
||||
if target_dtype.bits == 8:
|
||||
return llvm.trunc(
|
||||
res=self.i8_type, arg=v_int64, overflow_flags=overflow_flags
|
||||
)
|
||||
elif target_dtype.bits == 1:
|
||||
if target_dtype.bits == 1:
|
||||
# For i1 (boolean), convert i64 to boolean by checking if non-zero
|
||||
result = llvm.icmp(llvm.ICmpPredicate.ne, v_int64, self.i64(0))
|
||||
else:
|
||||
raise ValueError(f"Unsupported Var dtype: {target_dtype}")
|
||||
return result
|
||||
return llvm.icmp(llvm.ICmpPredicate.ne, v_int64, self.i64(0))
|
||||
raise ValueError(f"Unsupported Var dtype: {target_dtype}")
|
||||
|
||||
def is_contiguous(
|
||||
self,
|
||||
@@ -767,8 +827,40 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
self.matched_var_binding = {}
|
||||
self.matched_var_source = {}
|
||||
|
||||
def find_or_declare_extern_func(
|
||||
self, name: str, params: Sequence[ir.Type], ret: ir.Type
|
||||
) -> None:
|
||||
"""Find an existing extern function or declare it if it doesn't exist.
|
||||
|
||||
This method checks if a function with the given name already exists in the module.
|
||||
If it does, the method returns without doing anything. Otherwise, it declares
|
||||
the function as an external function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the extern function.
|
||||
params : Sequence[ir.Type]
|
||||
The parameter types of the function.
|
||||
ret : ir.Type
|
||||
The return type of the function.
|
||||
"""
|
||||
# Check if the function already exists
|
||||
existing_func = self.find_func_in_module(self.module, name)
|
||||
if existing_func is not None:
|
||||
# Function already declared, nothing to do
|
||||
return
|
||||
|
||||
# Function doesn't exist, declare it
|
||||
self.declare_extern_func(name, params, ret)
|
||||
|
||||
def decode_param_int(
|
||||
self, current_block: ir.Block, param: spec.Var, args: ir.Value, arg_index: int
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
param: spec.Var,
|
||||
args: ir.Value,
|
||||
arg_index: int,
|
||||
arg_context: ArgContext,
|
||||
) -> ir.Block:
|
||||
"""Decode the integer parameter at the given index."""
|
||||
# read the type index
|
||||
@@ -785,8 +877,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
lambda: is_int_or_bool,
|
||||
"TypeError",
|
||||
[
|
||||
"Mismatched type on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched type ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
", expected int",
|
||||
],
|
||||
@@ -801,14 +893,19 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
param,
|
||||
v_int64,
|
||||
[
|
||||
"value on argument ",
|
||||
f"#{arg_index}",
|
||||
"value ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
],
|
||||
)
|
||||
|
||||
def decode_param_float(
|
||||
self, current_block: ir.Block, param: spec.Var, args: ir.Value, arg_index: int
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
param: spec.Var,
|
||||
args: ir.Value,
|
||||
arg_index: int,
|
||||
arg_context: ArgContext,
|
||||
) -> ir.Block:
|
||||
"""Decode the float parameter at the given index."""
|
||||
# read the type index
|
||||
@@ -891,8 +988,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
self.raise_error_and_return(
|
||||
"TypeError",
|
||||
[
|
||||
"Mismatched type on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched type ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
", expected float",
|
||||
],
|
||||
@@ -912,6 +1009,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
param: spec.Var,
|
||||
args: ir.Value,
|
||||
arg_index: int,
|
||||
arg_context: ArgContext,
|
||||
*,
|
||||
allow_int_as_ptr: bool = False,
|
||||
address_space: Optional[int] = None,
|
||||
@@ -941,8 +1039,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
lambda: is_opaque_ptr_or_nullptr,
|
||||
"TypeError",
|
||||
[
|
||||
"Mismatched type on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched type ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
expect_message,
|
||||
],
|
||||
@@ -959,6 +1057,37 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
|
||||
return current_block
|
||||
|
||||
def decode_param_const_none(
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
param: spec.ConstNone,
|
||||
args: ir.Value,
|
||||
arg_index: int,
|
||||
arg_context: ArgContext,
|
||||
) -> ir.Block:
|
||||
"""Decode the opaque handle parameter at the given index."""
|
||||
# read the type index
|
||||
with ir.InsertionPoint(current_block):
|
||||
type_index: ir.Value = self.load_ffi_any_array_item_type_index(args, arg_index)
|
||||
# Check if type is a nullptr
|
||||
is_nullptr = self.equal(type_index, self.i32(TVMFFITypeIndex.kTVMFFINone))
|
||||
|
||||
expect_message = ", expected None"
|
||||
|
||||
# Break error message into reusable parts for better string deduplication
|
||||
current_block = self.check_condition(
|
||||
current_block,
|
||||
lambda: is_nullptr,
|
||||
"TypeError",
|
||||
[
|
||||
"Mismatched type ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
expect_message,
|
||||
],
|
||||
)
|
||||
return current_block
|
||||
|
||||
def check_int_value_dtype_bound(
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
@@ -1105,9 +1234,16 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
*error_msg_context,
|
||||
f", expected to be {var}"
|
||||
]
|
||||
|
||||
def check_value_mismatch() -> ir.Value:
|
||||
cond = self.equal(value, expected_value)
|
||||
if skip_check_predicate is not None:
|
||||
cond = self.or_(skip_check_predicate, cond)
|
||||
return cond
|
||||
|
||||
return self.check_condition(
|
||||
current_block,
|
||||
lambda: self.equal(value, expected_value),
|
||||
check_value_mismatch,
|
||||
error_kind,
|
||||
error_msg_mismatch
|
||||
)
|
||||
@@ -1117,17 +1253,18 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
current_block: ir.Block,
|
||||
var: Union[spec.Var, int],
|
||||
value: ir.Value,
|
||||
field: str,
|
||||
arg_index: int,
|
||||
field_suffix: str,
|
||||
shape_index: int,
|
||||
arg_context: ArgContext,
|
||||
*,
|
||||
skip_check_predicate: Optional[ir.Value] = None,
|
||||
) -> ir.Block:
|
||||
"""Load the shape value from the argument or match the shape value from the parameter."""
|
||||
field_name = arg_context.get_field_name(field_suffix)
|
||||
error_msg = [
|
||||
field,
|
||||
f"[{shape_index}] on argument ",
|
||||
f"#{arg_index}",
|
||||
field_name,
|
||||
f"[{shape_index}] ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
]
|
||||
return self.set_or_check_matched_var_binding(
|
||||
@@ -1135,7 +1272,11 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
)
|
||||
|
||||
def decode_param_shape_from_ffi_array(
|
||||
self, current_block: ir.Block, param: spec.Shape, arg_index: int, array_cell: ir.Value
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
param: spec.Shape,
|
||||
arg_context: ArgContext,
|
||||
array_cell: ir.Value,
|
||||
) -> tuple[ir.Block, list[ir.Value]]:
|
||||
"""Decode the shape parameter from the TVMFFIArrayCell."""
|
||||
with ir.InsertionPoint(current_block):
|
||||
@@ -1148,8 +1289,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
lambda: self.equal(array_size, self.i64(len(param.shape))),
|
||||
"ValueError",
|
||||
[
|
||||
"Mismatched Shape on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched Shape ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
f", expected shape size={len(param.shape)}",
|
||||
],
|
||||
@@ -1157,35 +1298,49 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
|
||||
# Load and validate each element of the array
|
||||
load_shapes = []
|
||||
for i in range(len(param.shape)):
|
||||
with ir.InsertionPoint(current_block):
|
||||
type_index: ir.Value = self.load_ffi_any_array_item_type_index(array_data, i)
|
||||
|
||||
def validate_and_load_shape_element(
|
||||
block: ir.Block, index: int
|
||||
) -> tuple[ir.Block, ir.Value]:
|
||||
"""Validate and load a single shape element from the array."""
|
||||
with ir.InsertionPoint(block):
|
||||
type_index: ir.Value = self.load_ffi_any_array_item_type_index(array_data, index)
|
||||
# Check if type is int or bool (both use v_int64, bool can be converted to int)
|
||||
is_int = self.equal(type_index, self.i32(TVMFFITypeIndex.kTVMFFIInt))
|
||||
is_int_val = self.equal(type_index, self.i32(TVMFFITypeIndex.kTVMFFIInt))
|
||||
|
||||
# Check that the element is an integer
|
||||
current_block = self.check_condition(
|
||||
current_block,
|
||||
lambda: is_int,
|
||||
field_name = arg_context.get_field_name("")
|
||||
block = self.check_condition(
|
||||
block,
|
||||
lambda: is_int_val,
|
||||
"TypeError",
|
||||
[
|
||||
f"Invalid shape element type ",
|
||||
f"{param.name}[{i}]",
|
||||
f" on argument ",
|
||||
f"#{arg_index}",
|
||||
"Invalid shape element type ",
|
||||
f"{field_name}[{index}]",
|
||||
" ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
f", expected int",
|
||||
", expected int",
|
||||
],
|
||||
)
|
||||
|
||||
with ir.InsertionPoint(current_block):
|
||||
v_int64: ir.Value = self.load_ffi_any_array_item_v_int64(array_data, i)
|
||||
load_shapes.append(v_int64)
|
||||
with ir.InsertionPoint(block):
|
||||
v_int64: ir.Value = self.load_ffi_any_array_item_v_int64(array_data, index)
|
||||
|
||||
return block, v_int64
|
||||
|
||||
for i in range(len(param.shape)):
|
||||
current_block, v_int64 = validate_and_load_shape_element(current_block, i)
|
||||
load_shapes.append(v_int64)
|
||||
|
||||
return (current_block, load_shapes)
|
||||
|
||||
def decode_param_shape_from_ffi_shape(
|
||||
self, current_block: ir.Block, param: spec.Shape, arg_index: int, shape_cell: ir.Value
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
param: spec.Shape,
|
||||
arg_context: ArgContext,
|
||||
shape_cell: ir.Value,
|
||||
) -> tuple[ir.Block, list[ir.Value]]:
|
||||
"""Decode the shape parameter from the TVMFFIShapeCell."""
|
||||
with ir.InsertionPoint(current_block):
|
||||
@@ -1198,8 +1353,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
lambda: self.equal(shape_size, self.i64(len(param.shape))),
|
||||
"ValueError",
|
||||
[
|
||||
"Mismatched Shape on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched Shape ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
f", expected shape size={len(param.shape)}",
|
||||
],
|
||||
@@ -1211,7 +1366,12 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
return (current_block, load_shapes)
|
||||
|
||||
def decode_param_shape(
|
||||
self, current_block: ir.Block, param: spec.Shape, args: ir.Value, arg_index: int
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
param: spec.Shape,
|
||||
args: ir.Value,
|
||||
arg_index: int,
|
||||
arg_context: ArgContext,
|
||||
) -> ir.Block:
|
||||
"""Decode the shape parameter at the given index."""
|
||||
with ir.InsertionPoint(current_block):
|
||||
@@ -1238,7 +1398,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
self.load_ffi_any_array_item_v_ptr(args, arg_index)
|
||||
)
|
||||
ffi_shape_block, load_shapes = self.decode_param_shape_from_ffi_shape(
|
||||
ffi_shape_block, param, arg_index, shape_cell
|
||||
ffi_shape_block, param, arg_context, shape_cell
|
||||
)
|
||||
with ir.InsertionPoint(ffi_shape_block):
|
||||
self.br(subsequent_block, args=load_shapes)
|
||||
@@ -1256,7 +1416,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
self.load_ffi_any_array_item_v_ptr(args, arg_index)
|
||||
)
|
||||
ffi_array_block, load_shapes = self.decode_param_shape_from_ffi_array(
|
||||
ffi_array_block, param, arg_index, array_cell_ptr
|
||||
ffi_array_block, param, arg_context, array_cell_ptr
|
||||
)
|
||||
with ir.InsertionPoint(ffi_array_block):
|
||||
self.br(subsequent_block, args=load_shapes)
|
||||
@@ -1267,8 +1427,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
self.raise_error_and_return(
|
||||
"TypeError",
|
||||
[
|
||||
"Mismatched type on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched type ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
", expected ffi.Shape or ffi.Array",
|
||||
],
|
||||
@@ -1279,7 +1439,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
shape_values = list(subsequent_block.arguments)
|
||||
for i, dim in enumerate(param.shape):
|
||||
subsequent_block = self.set_or_check_matched_var_binding_from_shape(
|
||||
subsequent_block, dim, shape_values[i], f"{param.name}", arg_index, i
|
||||
subsequent_block, dim, shape_values[i], "", i, arg_context
|
||||
)
|
||||
|
||||
return subsequent_block
|
||||
@@ -1290,6 +1450,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
param: spec.Tensor,
|
||||
args: ir.Value,
|
||||
arg_index: int,
|
||||
arg_context: ArgContext,
|
||||
) -> tuple[ir.Block, ir.Value]:
|
||||
"""Decode tensor step0: check index and find out DLTensor*."""
|
||||
# read the type index
|
||||
@@ -1333,8 +1494,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
self.raise_error_and_return(
|
||||
"TypeError",
|
||||
[
|
||||
"Mismatched type on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched type ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
", expected Tensor",
|
||||
],
|
||||
@@ -1352,10 +1513,11 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
param: spec.Tensor,
|
||||
args: ir.Value,
|
||||
arg_index: int,
|
||||
arg_context: ArgContext,
|
||||
) -> ir.Block:
|
||||
"""Decode the tensor parameter at the given index."""
|
||||
current_block, dl_tensor_ptr = self.decode_param_tensor_dltensor_ptr(
|
||||
current_block, param, args, arg_index
|
||||
current_block, param, args, arg_index, arg_context
|
||||
)
|
||||
with ir.InsertionPoint(current_block):
|
||||
data = self.load_dltensor_data_ptr(dl_tensor_ptr)
|
||||
@@ -1381,8 +1543,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
check_alignment,
|
||||
"ValueError",
|
||||
[
|
||||
"Misaligned Tensor data on argument ",
|
||||
f"#{arg_index}",
|
||||
"Misaligned Tensor data ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
f", expected data alignment={param.data_alignment} bytes",
|
||||
],
|
||||
@@ -1401,8 +1563,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
lambda: self.equal(ndim, self.i32(expected_ndim)),
|
||||
"ValueError",
|
||||
[
|
||||
"Mismatched Tensor on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched Tensor ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
f", expected ndim={expected_ndim}",
|
||||
],
|
||||
@@ -1414,8 +1576,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
lambda: self.equal(device_type, self.i32(param.dlpack_device_type)),
|
||||
"ValueError",
|
||||
[
|
||||
"Mismatched Tensor on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched Tensor ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
f", expected device_type={param.device_type_name}",
|
||||
],
|
||||
@@ -1437,8 +1599,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
dtype_equal,
|
||||
"ValueError",
|
||||
[
|
||||
"Mismatched Tensor on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched Tensor ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
f", expected dtype={param.dtype}",
|
||||
],
|
||||
@@ -1450,8 +1612,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
lambda: self.equal(byte_offset, self.i64(0)),
|
||||
"ValueError",
|
||||
[
|
||||
"Mismatched Tensor on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched Tensor ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
", expected byte_offset=0",
|
||||
],
|
||||
@@ -1474,9 +1636,9 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
current_block,
|
||||
param.shape[index],
|
||||
load_shapes[index],
|
||||
f"{param.name}.shape",
|
||||
arg_index,
|
||||
".shape",
|
||||
index,
|
||||
arg_context,
|
||||
)
|
||||
|
||||
if param.strides is not None:
|
||||
@@ -1490,9 +1652,9 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
current_block,
|
||||
param.strides[index],
|
||||
load_strides[index],
|
||||
f"{param.name}.strides",
|
||||
arg_index,
|
||||
".strides",
|
||||
index,
|
||||
arg_context,
|
||||
skip_check_predicate=skip_check_predicate,
|
||||
)
|
||||
else:
|
||||
@@ -1502,8 +1664,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
lambda: self.is_contiguous(param.shape, load_shapes, load_strides),
|
||||
"ValueError",
|
||||
[
|
||||
"Mismatched Tensor on argument ",
|
||||
f"#{arg_index}",
|
||||
"Mismatched Tensor ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
", expected contiguous",
|
||||
],
|
||||
@@ -1516,11 +1678,12 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
param: spec.Stream,
|
||||
args: ir.Value,
|
||||
arg_index: int,
|
||||
arg_context: ArgContext,
|
||||
) -> ir.Block:
|
||||
"""Decode the stream parameter at the given index."""
|
||||
# stream is decoded as opaque handle
|
||||
return self.decode_param_opaque_handle(
|
||||
current_block, param.var, args, arg_index
|
||||
current_block, param.var, args, arg_index, arg_context
|
||||
)
|
||||
|
||||
def decode_param_data_pointer(
|
||||
@@ -1529,6 +1692,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
param: spec.DataPointer,
|
||||
args: ir.Value,
|
||||
arg_index: int,
|
||||
arg_context: ArgContext,
|
||||
) -> ir.Block:
|
||||
"""Decode the data pointer parameter at the given index."""
|
||||
# data pointer is decoded as opaque handle
|
||||
@@ -1537,6 +1701,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
param.var,
|
||||
args,
|
||||
arg_index,
|
||||
arg_context,
|
||||
allow_int_as_ptr=True,
|
||||
address_space=param.address_space,
|
||||
)
|
||||
@@ -1578,35 +1743,117 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
expected_num_args += 1
|
||||
return expected_num_args
|
||||
|
||||
def decode_param( # noqa: PLR0911
|
||||
self, current_block: ir.Block, param: spec.Param, args: ir.Value, arg_index: int
|
||||
def decode_param_tuple(
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
param: spec.TupleParam,
|
||||
args: ir.Value,
|
||||
arg_index: int,
|
||||
arg_context: ArgContext,
|
||||
) -> ir.Block:
|
||||
"""Decode the parameter at the given index."""
|
||||
"""Decode the tuple parameter at the given index."""
|
||||
# Check if type is kTVMFFIArray
|
||||
with ir.InsertionPoint(current_block):
|
||||
type_index: ir.Value = self.load_ffi_any_array_item_type_index(args, arg_index)
|
||||
is_ffi_array = self.equal(type_index, self.i32(TVMFFITypeIndex.kTVMFFIArray))
|
||||
|
||||
# Check that the type is an array
|
||||
current_block = self.check_condition(
|
||||
current_block,
|
||||
lambda: is_ffi_array,
|
||||
"TypeError",
|
||||
[
|
||||
"Mismatched type ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
", expected ffi.Array for tuple",
|
||||
],
|
||||
)
|
||||
|
||||
# Load the array cell
|
||||
with ir.InsertionPoint(current_block):
|
||||
array_cell_ptr = self.get_object_cell_ptr(
|
||||
self.load_ffi_any_array_item_v_ptr(args, arg_index)
|
||||
)
|
||||
array_data = self.load_array_cell_data_ptr(array_cell_ptr)
|
||||
array_size = self.load_array_cell_size_as_i64(array_cell_ptr)
|
||||
|
||||
# Check that the array size matches the expected tuple size
|
||||
current_block = self.check_condition(
|
||||
current_block,
|
||||
lambda: self.equal(array_size, self.i64(len(param.params))),
|
||||
"ValueError",
|
||||
[
|
||||
"Mismatched tuple size ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
f", expected tuple size={len(param.params)}",
|
||||
],
|
||||
)
|
||||
|
||||
# Recursively decode each element of the tuple
|
||||
for i, tuple_param in enumerate(param.params):
|
||||
# Create nested context for the tuple element
|
||||
nested_context = arg_context.get_element_context(i)
|
||||
current_block = self.decode_param(current_block, tuple_param, array_data, i, nested_context)
|
||||
|
||||
return current_block
|
||||
|
||||
def decode_param( # noqa: PLR0911
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
param: spec.Param,
|
||||
args: ir.Value,
|
||||
arg_index: int,
|
||||
arg_context: ArgContext,
|
||||
) -> ir.Block:
|
||||
"""Decode the parameter at the given index.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
current_block : ir.Block
|
||||
The current IR block.
|
||||
param : spec.Param
|
||||
The parameter specification to decode.
|
||||
args : ir.Value
|
||||
The FFI arguments array.
|
||||
arg_index : int
|
||||
The index in the args array.
|
||||
arg_context : ArgContext
|
||||
Context information for error messages.
|
||||
"""
|
||||
if isinstance(param, spec.Var):
|
||||
if param.dtype.type_code == tvm_ffi._dtype.DataTypeCode.INT:
|
||||
return self.decode_param_int(current_block, param, args, arg_index)
|
||||
return self.decode_param_int(current_block, param, args, arg_index, arg_context)
|
||||
elif param.dtype.type_code == tvm_ffi._dtype.DataTypeCode.UINT:
|
||||
# UINT uses the same logic as INT since both are stored in v_int64
|
||||
return self.decode_param_int(current_block, param, args, arg_index)
|
||||
return self.decode_param_int(current_block, param, args, arg_index, arg_context)
|
||||
elif (hasattr(tvm_ffi._dtype.DataTypeCode, "BOOL") and
|
||||
param.dtype.type_code == tvm_ffi._dtype.DataTypeCode.BOOL):
|
||||
return self.decode_param_int(current_block, param, args, arg_index, arg_context)
|
||||
elif param.dtype.type_code == tvm_ffi._dtype.DataTypeCode.FLOAT:
|
||||
return self.decode_param_float(current_block, param, args, arg_index)
|
||||
return self.decode_param_float(current_block, param, args, arg_index, arg_context)
|
||||
elif param.dtype.type_code == tvm_ffi._dtype.DataTypeCode.HANDLE:
|
||||
return self.decode_param_opaque_handle(
|
||||
current_block, param, args, arg_index
|
||||
current_block, param, args, arg_index, arg_context
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported parameter type: {param.dtype.type_code}")
|
||||
elif isinstance(param, spec.Shape):
|
||||
return self.decode_param_shape(current_block, param, args, arg_index)
|
||||
return self.decode_param_shape(current_block, param, args, arg_index, arg_context)
|
||||
elif isinstance(param, spec.Tensor):
|
||||
return self.decode_param_tensor(current_block, param, args, arg_index)
|
||||
return self.decode_param_tensor(current_block, param, args, arg_index, arg_context)
|
||||
elif isinstance(param, spec.Stream):
|
||||
return self.decode_param_stream(current_block, param, args, arg_index)
|
||||
return self.decode_param_stream(current_block, param, args, arg_index, arg_context)
|
||||
elif isinstance(param, spec.EnvStream):
|
||||
# decode of env stream is deferred after we go through all parameters
|
||||
return current_block
|
||||
elif isinstance(param, spec.DataPointer):
|
||||
return self.decode_param_data_pointer(current_block, param, args, arg_index)
|
||||
return self.decode_param_data_pointer(current_block, param, args, arg_index, arg_context)
|
||||
elif isinstance(param, spec.ConstNone):
|
||||
return self.decode_param_const_none(current_block, param, args, arg_index, arg_context)
|
||||
elif isinstance(param, spec.TupleParam):
|
||||
return self.decode_param_tuple(current_block, param, args, arg_index, arg_context)
|
||||
else:
|
||||
raise ValueError(f"Unsupported parameter type: {type(param)}")
|
||||
|
||||
@@ -1652,20 +1899,20 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
with ir.InsertionPoint(self.module.body): # type: ignore[union-attr]
|
||||
# void TVMFFIErrorSetRaisedFromCStr(
|
||||
# const char* error_kind, const char* message);
|
||||
self.declare_extern_func(
|
||||
self.find_or_declare_extern_func(
|
||||
"TVMFFIErrorSetRaisedFromCStr",
|
||||
[self.ptr_type, self.ptr_type],
|
||||
self.void_type,
|
||||
)
|
||||
# void TVMFFIErrorSetRaisedFromCStrParts(
|
||||
# const char* error_kind, const char* messages, int32_t num_parts);
|
||||
self.declare_extern_func(
|
||||
self.find_or_declare_extern_func(
|
||||
"TVMFFIErrorSetRaisedFromCStrParts",
|
||||
[self.ptr_type, self.ptr_type, self.i32_type],
|
||||
self.void_type,
|
||||
)
|
||||
# void* TVMFFIEnvGetStream(int32_t device_type, int32_t device_id);
|
||||
self.declare_extern_func(
|
||||
self.find_or_declare_extern_func(
|
||||
"TVMFFIEnvGetStream",
|
||||
[self.i32_type, self.i32_type],
|
||||
self.ptr_type,
|
||||
@@ -1697,7 +1944,12 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
|
||||
# decode parameters to populate the matched var binding
|
||||
for arg_index, param in enumerate(params_list):
|
||||
current_block = self.decode_param(current_block, param, args, arg_index)
|
||||
arg_context = ArgContext(
|
||||
param_name=param.name,
|
||||
arg_index=arg_index,
|
||||
tuple_indices=[],
|
||||
)
|
||||
current_block = self.decode_param(current_block, param, args, arg_index, arg_context)
|
||||
|
||||
with ir.InsertionPoint(current_block):
|
||||
env_stream = self.find_env_stream(params_list)
|
||||
|
||||
@@ -239,19 +239,6 @@ def get_c_pointers(obj):
|
||||
return []
|
||||
|
||||
|
||||
def arg_compatible_with_tvm_ffi(arg):
|
||||
"""
|
||||
Given the `arg`, check if it is a compatible argument for TVM FFI
|
||||
"""
|
||||
import tvm_ffi
|
||||
|
||||
return (
|
||||
hasattr(arg, "__tvm_ffi_object__")
|
||||
or isinstance(arg, (int, float, bool))
|
||||
or isinstance(arg, tvm_ffi.Shape)
|
||||
)
|
||||
|
||||
|
||||
def get_mlir_types(obj):
|
||||
"""
|
||||
Given the `obj`, recursively go through it to extract all contained MLIR types
|
||||
|
||||
@@ -205,6 +205,7 @@ KeepCUBIN = _dsl.KeepCUBIN
|
||||
KeepPTX = _dsl.KeepPTX
|
||||
GPUArch = _dsl.GPUArch
|
||||
LinkLibraries = _dsl.LinkLibraries
|
||||
EnableTVMFFI = _dsl.EnableTVMFFI
|
||||
|
||||
# attach the TVM FFI ABI interface postprocessor to the DSL
|
||||
from . import _tvm_ffi_args_spec_converter
|
||||
|
||||
@@ -42,7 +42,7 @@ from .typing import (
|
||||
)
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Dict, Any, Optional, Tuple, get_origin, get_args
|
||||
import inspect
|
||||
|
||||
NumericToTVMFFIDtype = {
|
||||
@@ -91,6 +91,11 @@ def _get_llvm_address_space_from_memspace(
|
||||
return 1
|
||||
return None
|
||||
|
||||
def _is_gpu_memspace(
|
||||
memspace: _cute_ir.AddressSpace,
|
||||
) -> bool:
|
||||
return memspace != _cute_ir.AddressSpace.generic
|
||||
|
||||
|
||||
class SymIntId:
|
||||
def __init__(self, sym_int: SymInt):
|
||||
@@ -103,118 +108,187 @@ class SymIntId:
|
||||
return self.sym_int is other.sym_int
|
||||
|
||||
|
||||
def _tvm_ffi_args_spec_converter(
|
||||
function_name: str,
|
||||
args_spec: inspect.FullArgSpec,
|
||||
dynamic_args: List[Any],
|
||||
dynamic_kwargs: Dict[str, Any],
|
||||
):
|
||||
"""Convert cute algebra args to tvm ffi spec params.
|
||||
|
||||
This function converts the cute arguments specs to tvm ffi spec params.
|
||||
"""
|
||||
exec_args = ExecutionArgs(args_spec, function_name)
|
||||
rectified_args = exec_args.get_rectified_args(dynamic_args, dynamic_kwargs)
|
||||
arg_names = exec_args.args_spec.args + exec_args.args_spec.kwonlyargs
|
||||
class ConverterContext:
|
||||
"""Context for managing variable allocation during TVM FFI args conversion."""
|
||||
|
||||
params = []
|
||||
num_dyn_shape_vars = 0
|
||||
num_dyn_stride_vars = 0
|
||||
sym_int_id_mapping = {}
|
||||
def __init__(self):
|
||||
self.num_dyn_shape_vars = 0
|
||||
self.num_dyn_stride_vars = 0
|
||||
self.sym_int_id_mapping = {}
|
||||
|
||||
|
||||
def alloc_shape_name():
|
||||
nonlocal num_dyn_shape_vars
|
||||
name = f"n{num_dyn_shape_vars}"
|
||||
num_dyn_shape_vars += 1
|
||||
def alloc_shape_name(self) -> str:
|
||||
"""Allocate a new dynamic shape variable name."""
|
||||
name = f"n{self.num_dyn_shape_vars}"
|
||||
self.num_dyn_shape_vars += 1
|
||||
return name
|
||||
|
||||
def alloc_stride_name():
|
||||
nonlocal num_dyn_stride_vars
|
||||
name = f"s{num_dyn_stride_vars}"
|
||||
num_dyn_stride_vars += 1
|
||||
def alloc_stride_name(self) -> str:
|
||||
"""Allocate a new dynamic stride variable name."""
|
||||
name = f"s{self.num_dyn_stride_vars}"
|
||||
self.num_dyn_stride_vars += 1
|
||||
return name
|
||||
|
||||
def alloc_or_reuse_symint_var(value, name_alloc_func):
|
||||
nonlocal sym_int_id_mapping
|
||||
def alloc_or_reuse_symint_var(self, value: SymInt, name_alloc_func):
|
||||
"""Allocate or reuse a symbolic integer variable."""
|
||||
sym_int_id = SymIntId(value)
|
||||
if sym_int_id in sym_int_id_mapping:
|
||||
return sym_int_id_mapping[sym_int_id]
|
||||
if sym_int_id in self.sym_int_id_mapping:
|
||||
return self.sym_int_id_mapping[sym_int_id]
|
||||
name = name_alloc_func()
|
||||
if value.width == 32:
|
||||
dtype = NumericToTVMFFIDtype[Int32]
|
||||
else:
|
||||
dtype = NumericToTVMFFIDtype[Int64]
|
||||
var = spec.Var(name, dtype, divisibility=value.divisibility)
|
||||
sym_int_id_mapping[sym_int_id] = var
|
||||
self.sym_int_id_mapping[sym_int_id] = var
|
||||
return var
|
||||
|
||||
for arg, arg_name in zip(rectified_args, arg_names):
|
||||
arg_type = args_spec.annotations.get(arg_name, None)
|
||||
if isinstance(arg, Numeric) and arg.dtype in AcceptableNumericTypesForScalar:
|
||||
params.append(spec.Var(arg_name, NumericToTVMFFIDtype[arg.dtype]))
|
||||
elif is_cute_algebra_type(arg_type):
|
||||
shape = []
|
||||
for i in range(len(arg)):
|
||||
if isinstance(arg[i], int):
|
||||
shape.append(arg[i])
|
||||
elif isinstance(arg[i], SymInt):
|
||||
shape.append(alloc_or_reuse_symint_var(arg[i], alloc_shape_name))
|
||||
else:
|
||||
shape.append(spec.Var(alloc_shape_name(), NumericToTVMFFIDtype[arg[i].dtype]))
|
||||
params.append(spec.Shape(arg_name, shape))
|
||||
elif isinstance(arg, Tensor):
|
||||
shapes = []
|
||||
for i, dyn_mask in enumerate(arg.dynamic_shapes_mask):
|
||||
if not dyn_mask:
|
||||
shapes.append(arg.shape[i])
|
||||
elif isinstance(arg.shape[i], SymInt):
|
||||
shapes.append(alloc_or_reuse_symint_var(arg.shape[i], alloc_shape_name))
|
||||
else:
|
||||
shapes.append(spec.Var(alloc_shape_name(), NumericToTVMFFIDtype[Int32]))
|
||||
strides = []
|
||||
|
||||
for i, dyn_mask in enumerate(arg.dynamic_strides_mask):
|
||||
if not dyn_mask:
|
||||
strides.append(arg.stride[i])
|
||||
elif isinstance(arg.stride[i], SymInt):
|
||||
strides.append(alloc_or_reuse_symint_var(arg.stride[i], alloc_stride_name))
|
||||
else:
|
||||
if hasattr(arg, "_use_32bit_stride") and arg._use_32bit_stride:
|
||||
dtype = NumericToTVMFFIDtype[Int32]
|
||||
else:
|
||||
dtype = NumericToTVMFFIDtype[Int64]
|
||||
strides.append(spec.Var(alloc_stride_name(), dtype))
|
||||
def _convert_single_arg(
|
||||
arg,
|
||||
arg_name: str,
|
||||
arg_type,
|
||||
ctx: ConverterContext
|
||||
) -> spec.Param:
|
||||
"""Convert a single argument to a spec.Param.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
arg : Any
|
||||
The argument value to convert.
|
||||
arg_name : str
|
||||
The name of the argument.
|
||||
arg_type : type
|
||||
The type annotation of the argument.
|
||||
ctx : ConverterContext
|
||||
The converter context for managing variable allocation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
spec.Param
|
||||
The converted parameter specification.
|
||||
"""
|
||||
if arg is None:
|
||||
return spec.ConstNone(arg_name)
|
||||
elif (isinstance(arg, Numeric) and arg.dtype in AcceptableNumericTypesForScalar):
|
||||
return spec.Var(arg_name, NumericToTVMFFIDtype[arg.dtype])
|
||||
elif arg_type in AcceptableNumericTypesForScalar:
|
||||
return spec.Var(arg_name, NumericToTVMFFIDtype[arg_type])
|
||||
elif is_cute_algebra_type(arg_type):
|
||||
shape = []
|
||||
for i in range(len(arg)):
|
||||
if isinstance(arg[i], int):
|
||||
shape.append(arg[i])
|
||||
elif isinstance(arg[i], SymInt):
|
||||
shape.append(ctx.alloc_or_reuse_symint_var(arg[i], ctx.alloc_shape_name))
|
||||
else:
|
||||
shape.append(spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[arg[i].dtype]))
|
||||
return spec.Shape(arg_name, shape)
|
||||
elif isinstance(arg, Tensor):
|
||||
shapes = []
|
||||
for i, dyn_mask in enumerate(arg.dynamic_shapes_mask):
|
||||
if not dyn_mask:
|
||||
shapes.append(arg.shape[i])
|
||||
elif isinstance(arg.shape[i], SymInt):
|
||||
shapes.append(ctx.alloc_or_reuse_symint_var(arg.shape[i], ctx.alloc_shape_name))
|
||||
else:
|
||||
shapes.append(spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[Int32]))
|
||||
strides = []
|
||||
|
||||
for i, dyn_mask in enumerate(arg.dynamic_strides_mask):
|
||||
if not dyn_mask:
|
||||
strides.append(arg.stride[i])
|
||||
elif isinstance(arg.stride[i], SymInt):
|
||||
strides.append(ctx.alloc_or_reuse_symint_var(arg.stride[i], ctx.alloc_stride_name))
|
||||
else:
|
||||
if hasattr(arg, "_use_32bit_stride") and arg._use_32bit_stride:
|
||||
dtype = NumericToTVMFFIDtype[Int32]
|
||||
else:
|
||||
dtype = NumericToTVMFFIDtype[Int64]
|
||||
strides.append(spec.Var(ctx.alloc_stride_name(), dtype))
|
||||
if hasattr(arg, "_tvm_ffi_tensor"):
|
||||
tvm_ffi_tensor = arg._tvm_ffi_tensor
|
||||
dtype = tvm_ffi_tensor.dtype
|
||||
tvm_ffi_cute_tensor = spec.Tensor(
|
||||
arg_name,
|
||||
shapes,
|
||||
arg._tvm_ffi_tensor.dtype,
|
||||
strides=strides,
|
||||
data_alignment=arg._assumed_align,
|
||||
device_type=tvm_ffi_tensor.device.type
|
||||
)
|
||||
else:
|
||||
# for FakeTensor, strictly follow the shape and stride from the cute tensor
|
||||
device_type = "cuda" if _is_gpu_memspace(arg.memspace) else "cpu"
|
||||
tvm_ffi_cute_tensor = spec.Tensor(
|
||||
arg_name,
|
||||
shapes,
|
||||
NumericToTVMFFIDtype[arg.element_type],
|
||||
strides=strides,
|
||||
data_alignment=arg._assumed_align,
|
||||
device_type=device_type,
|
||||
)
|
||||
if arg.element_type == Float4E2M1FN:
|
||||
tvm_ffi_cute_tensor = spec.create_map_tensor_dtype_f4x2_to_f4_spec(
|
||||
tvm_ffi_cute_tensor
|
||||
)
|
||||
params.append(tvm_ffi_cute_tensor)
|
||||
elif isinstance(arg, Pointer):
|
||||
address_space = None
|
||||
if hasattr(arg, "memspace"):
|
||||
address_space = _get_llvm_address_space_from_memspace(arg.memspace)
|
||||
params.append(spec.DataPointer(arg_name, address_space=address_space))
|
||||
elif isinstance(arg, _FakeStream):
|
||||
if arg.use_tvm_ffi_env_stream:
|
||||
params.append(spec.EnvStream(arg_name))
|
||||
else:
|
||||
params.append(spec.Stream(arg_name))
|
||||
elif isinstance(arg, cuda.CUstream):
|
||||
params.append(spec.Stream(arg_name))
|
||||
return tvm_ffi_cute_tensor
|
||||
elif isinstance(arg, Pointer) or arg_type == Pointer:
|
||||
address_space = None
|
||||
if hasattr(arg, "memspace"):
|
||||
address_space = _get_llvm_address_space_from_memspace(arg.memspace)
|
||||
return spec.DataPointer(arg_name, address_space=address_space)
|
||||
elif isinstance(arg, _FakeStream):
|
||||
if arg.use_tvm_ffi_env_stream:
|
||||
return spec.EnvStream(arg_name)
|
||||
else:
|
||||
raise DSLRuntimeError(f"Unsupported argument type: {type(arg)}")
|
||||
# The following code can obtain signature of the function
|
||||
# that maybe useful for future debugging and usecases.
|
||||
# signature = spec.signature(function_name, params)
|
||||
return spec.Stream(arg_name)
|
||||
elif isinstance(arg, cuda.CUstream):
|
||||
return spec.Stream(arg_name)
|
||||
elif arg_type is not None and get_origin(arg_type) is tuple:
|
||||
# Handle Tuple[X, Y, ...] type annotations
|
||||
tuple_element_types = get_args(arg_type)
|
||||
if not isinstance(arg, (tuple, list)):
|
||||
raise DSLRuntimeError(f"Expected tuple for argument {arg_name}, got {type(arg)}")
|
||||
if len(arg) != len(tuple_element_types):
|
||||
raise DSLRuntimeError(
|
||||
f"Tuple length mismatch for argument {arg_name}: "
|
||||
f"expected {len(tuple_element_types)}, got {len(arg)}"
|
||||
)
|
||||
|
||||
# Recursively convert each tuple element
|
||||
tuple_params = []
|
||||
for i, (elem, elem_type) in enumerate(zip(arg, tuple_element_types)):
|
||||
elem_name = f"{arg_name}[{i}]"
|
||||
elem_param = _convert_single_arg(elem, elem_name, elem_type, ctx)
|
||||
tuple_params.append(elem_param)
|
||||
|
||||
return spec.TupleParam(arg_name, tuple_params)
|
||||
else:
|
||||
raise DSLRuntimeError(f"Unsupported argument type: {type(arg)}")
|
||||
|
||||
|
||||
def _tvm_ffi_args_spec_converter(
|
||||
function_name: str,
|
||||
args_spec: inspect.FullArgSpec,
|
||||
full_args: List[Any],
|
||||
full_kwargs: Dict[str, Any],
|
||||
):
|
||||
"""Convert cute algebra args to tvm ffi spec params.
|
||||
|
||||
This function converts the cute arguments specs to tvm ffi spec params.
|
||||
"""
|
||||
exec_args = ExecutionArgs(args_spec, function_name)
|
||||
rectified_args = exec_args.get_rectified_args_from_original_args(full_args, full_kwargs)
|
||||
arg_names = exec_args.args_spec.args + exec_args.args_spec.kwonlyargs
|
||||
params = []
|
||||
ctx = ConverterContext()
|
||||
|
||||
for arg, arg_name in zip(rectified_args, arg_names):
|
||||
arg_type = args_spec.annotations.get(arg_name, None)
|
||||
param = _convert_single_arg(arg, arg_name, arg_type, ctx)
|
||||
params.append(param)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
|
||||
@@ -209,6 +209,7 @@ def mbarrier_conditional_try_wait(
|
||||
def mbarrier_arrive(
|
||||
mbar_ptr: Pointer,
|
||||
peer_cta_rank_in_cluster: Optional[Int] = None,
|
||||
arrive_count: Int = 1,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
@@ -239,7 +240,7 @@ def mbarrier_arrive(
|
||||
|
||||
nvvm.mbarrier_txn(
|
||||
mbar_llvm_ptr,
|
||||
Int32(1).ir_value(loc=loc, ip=ip),
|
||||
Int32(arrive_count).ir_value(loc=loc, ip=ip),
|
||||
kind=nvvm.MBarrierTxnKind.ARRIVE,
|
||||
space=space,
|
||||
loc=loc,
|
||||
|
||||
@@ -201,13 +201,13 @@ class MmaOp(Tcgen05MmaOp):
|
||||
if self.cta_group == CtaGroup.ONE:
|
||||
if m not in [64, 128]:
|
||||
raise OpError(self, f"expects the M-mode to be 64 or 128, but got {m}")
|
||||
if m == 64:
|
||||
if (n < 8) or (n > 256) or (n % 8 != 0):
|
||||
if self.b_dtype.width == 8 and self.b_major_mode == OperandMajorMode.MN:
|
||||
if (n < 16) or (n > 256) or (n % 16 != 0):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}",
|
||||
f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}",
|
||||
)
|
||||
elif m == 128:
|
||||
else:
|
||||
if (n < 8) or (n > 256) or (n % 8 != 0):
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -216,11 +216,18 @@ class MmaOp(Tcgen05MmaOp):
|
||||
else:
|
||||
if m not in [128, 256]:
|
||||
raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}")
|
||||
if (n < 16) or (n > 256) or (n % 16 != 0):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}",
|
||||
)
|
||||
if self.b_dtype.width == 8 and self.b_major_mode == OperandMajorMode.MN:
|
||||
if (n < 32) or (n > 256) or (n % 32 != 0):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the N-mode to satisfy 32 <= N <= 256 and N % 32 == 0, but got {n}",
|
||||
)
|
||||
else:
|
||||
if (n < 16) or (n > 256) or (n % 16 != 0):
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
@@ -302,6 +309,7 @@ class BlockScaledMmaOp(Tcgen05MmaOp):
|
||||
|
||||
admissible_archs = [
|
||||
Arch.sm_100a,
|
||||
Arch.sm_103a,
|
||||
]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
# is strictly prohibited.
|
||||
|
||||
import ctypes
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from functools import lru_cache
|
||||
@@ -20,6 +19,7 @@ from typing import Union, Optional, Type, List
|
||||
|
||||
# MLIR modules imports
|
||||
from cutlass._mlir import ir
|
||||
from cutlass.base_dsl.env_manager import get_prefix_dsl_libs
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cuda as _cuda_dialect
|
||||
|
||||
@@ -144,6 +144,7 @@ class _Tensor(Tensor):
|
||||
|
||||
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor)
|
||||
self._dlpack_data = self._tvm_ffi_tensor.__dlpack__()
|
||||
|
||||
self._dltensor_wrapper = None
|
||||
self._assumed_align = assumed_align
|
||||
self._is_dynamic = False
|
||||
@@ -387,7 +388,16 @@ class _Tensor(Tensor):
|
||||
return CoreTensor(values[0].value, self._dtype)
|
||||
|
||||
def __tvm_ffi_object__(self):
|
||||
return self._tvm_ffi_tensor
|
||||
try:
|
||||
return self._tvm_ffi_tensor
|
||||
except AttributeError:
|
||||
raise DSLRuntimeError(
|
||||
(
|
||||
"runtime._Tensor is not a TVM-FFI tensor. "
|
||||
"Enable TVM-FFI with `from_dlpack(..., enable_tvm_ffi=True)` "
|
||||
"or `CUTE_DSL_ENABLE_TVM_FFI=1`."
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _get_cute_type_str(inp):
|
||||
@@ -411,7 +421,8 @@ class _FakeCompactTensor(Tensor):
|
||||
self._dtype = dtype
|
||||
self._shape = shape
|
||||
self._stride_order = stride_order or tuple(range(len(shape)))
|
||||
self._memspace = memspace or AddressSpace.gmem
|
||||
# cannot use memspace or AddressSpace.gmem because AddressSpace.generic is 0
|
||||
self._memspace = memspace if memspace is not None else AddressSpace.gmem
|
||||
self._assumed_align = assumed_align or -(-dtype.width // 8)
|
||||
self._use_32bit_stride = use_32bit_stride
|
||||
|
||||
@@ -510,7 +521,8 @@ class _FakeTensor(Tensor):
|
||||
self._dtype = dtype
|
||||
self._shape = shape
|
||||
self._stride = stride
|
||||
self._memspace = memspace or AddressSpace.generic
|
||||
# cannot use memspace or AddressSpace.generic because AddressSpace.generic is 0
|
||||
self._memspace = memspace if memspace is not None else AddressSpace.gmem
|
||||
self._assumed_align = assumed_align
|
||||
if assumed_align is None:
|
||||
# use the bytes width of the element dtype. The alignment is at least one byte align.
|
||||
@@ -605,7 +617,7 @@ def make_fake_compact_tensor(
|
||||
:param shape: Shape of the tensor.
|
||||
:type shape: tuple[int, ...]
|
||||
:param stride_order: Order in which strides (memory layout) are assigned to the tensor dimensions.
|
||||
If None, the default layout is row-major. Otherwise, it should be a permutation of the dimension indices.
|
||||
If None, the default layout is col-major. Otherwise, it should be a permutation of the dimension indices.
|
||||
:type stride_order: tuple[int, ...], optional
|
||||
:param memspace: Memory space where the fake tensor resides. Optional.
|
||||
:type memspace: str, optional
|
||||
@@ -644,6 +656,7 @@ def make_fake_compact_tensor(
|
||||
use_32bit_stride=use_32bit_stride,
|
||||
)
|
||||
|
||||
|
||||
def make_fake_tensor(dtype, shape, stride, *, memspace=None, assumed_align=None):
|
||||
"""
|
||||
Create a fake tensor with the specified element type, shape, and stride.
|
||||
@@ -859,21 +872,22 @@ def find_runtime_libraries(*, enable_tvm_ffi: bool = True) -> List[str]:
|
||||
"""
|
||||
|
||||
def _get_cuda_dialect_runtime_path():
|
||||
libs = os.environ.get("CUTE_DSL_LIBS")
|
||||
if libs:
|
||||
sep = ";" if sys.platform.startswith("win32") else ":"
|
||||
for path in libs.split(sep):
|
||||
if path.endswith("libcuda_dialect_runtime.so"):
|
||||
return path
|
||||
try:
|
||||
# find package library from wheel package
|
||||
pkg_base = Path(__file__).resolve().parent.parent
|
||||
lib_path = pkg_base / "lib" / "libcuda_dialect_runtime.so"
|
||||
if lib_path.is_file():
|
||||
return str(lib_path)
|
||||
except OSError:
|
||||
libs = get_prefix_dsl_libs("CUTE_DSL")
|
||||
if libs is None:
|
||||
return None
|
||||
|
||||
# check if the separator is ; for windows
|
||||
if sys.platform.startswith("win32") and ";" in libs:
|
||||
libs = libs.split(";")
|
||||
else:
|
||||
libs = libs.split(":")
|
||||
|
||||
for path in libs:
|
||||
if path.endswith("libcuda_dialect_runtime.so"):
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
libs = []
|
||||
cuda_dialect_runtime_path = _get_cuda_dialect_runtime_path()
|
||||
if cuda_dialect_runtime_path:
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import Type, Union, Callable, Optional, Dict, List, Any
|
||||
import cuda.bindings.driver as cuda_driver
|
||||
import cuda.bindings.runtime as cuda_runtime
|
||||
|
||||
import cutlass
|
||||
import cutlass.base_dsl.jit_executor
|
||||
from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, dsl_user_op
|
||||
|
||||
@@ -233,6 +234,7 @@ def convert(src: cute.Tensor, dst: cute.Tensor):
|
||||
src.shape[leading_mode] % elem_per_copy == 0
|
||||
and dst.shape[leading_mode] % elem_per_copy == 0
|
||||
)
|
||||
|
||||
_convert(src, dst, leading_mode, elem_per_copy)
|
||||
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ from ..base_dsl.compiler import (
|
||||
KeepPTX,
|
||||
GPUArch,
|
||||
LinkLibraries,
|
||||
EnableTVMFFI,
|
||||
)
|
||||
from ..base_dsl.runtime.jit_arg_adapters import *
|
||||
|
||||
|
||||
@@ -21,7 +21,12 @@ import cuda.bindings.runtime as cuda_runtime
|
||||
import cuda.bindings.driver as cuda_driver
|
||||
|
||||
# Local modules imports
|
||||
from ..base_dsl.jit_executor import JitExecutor, ExecutionArgs, JitFunctionArtifacts
|
||||
from ..base_dsl.jit_executor import (
|
||||
JitExecutor,
|
||||
JitCompiledFunction,
|
||||
ExecutionArgs,
|
||||
JitFunctionArtifacts,
|
||||
)
|
||||
from ..base_dsl.utils.logger import log
|
||||
from ..base_dsl.common import DSLCudaRuntimeError, DSLRuntimeError
|
||||
from ..base_dsl.typing import Int32
|
||||
@@ -57,7 +62,7 @@ class CudaDialectJitModule:
|
||||
self.unload()
|
||||
|
||||
|
||||
class CudaDialectJitCompiledFunction:
|
||||
class CudaDialectJitCompiledFunction(JitCompiledFunction):
|
||||
"""Holds a compiled function and its module."""
|
||||
|
||||
def __init__(
|
||||
@@ -103,21 +108,6 @@ class CudaDialectJitCompiledFunction:
|
||||
self._executor_lock = threading.RLock()
|
||||
self._default_executor = None
|
||||
|
||||
@property
|
||||
def __ptx__(self):
|
||||
"""Returns the PTX code of the JIT-compiled function."""
|
||||
return self.artifacts.PTX if self.artifacts is not None else None
|
||||
|
||||
@property
|
||||
def __cubin__(self):
|
||||
"""Returns the CUBIN data of the JIT-compiled function."""
|
||||
return self.artifacts.CUBIN if self.artifacts is not None else None
|
||||
|
||||
@property
|
||||
def __mlir__(self):
|
||||
"""Returns the MLIR code of the JIT-compiled function."""
|
||||
return self.artifacts.MLIR if self.artifacts is not None else None
|
||||
|
||||
@functools.cached_property
|
||||
def num_devices(self):
|
||||
"""Returns the number of CUDA devices available."""
|
||||
@@ -285,6 +275,7 @@ class CudaDialectJitCompiledFunction:
|
||||
:return: A callable executor function.
|
||||
:rtype: JitExecutor
|
||||
"""
|
||||
super()._validate_engine()
|
||||
with self._executor_lock:
|
||||
# We need to ensure that the modules are loaded if not already
|
||||
if self.jit_module is None or self.jit_module.is_unloaded():
|
||||
@@ -297,37 +288,3 @@ class CudaDialectJitCompiledFunction:
|
||||
)
|
||||
|
||||
return JitExecutor(self.jit_module, None, self.jit_time_profiling)
|
||||
|
||||
def generate_execution_args(self, *args, **kwargs):
|
||||
return self.args_spec.generate_execution_args(args, kwargs)
|
||||
|
||||
def set_dynamic_args(self, dynamic_args, dynamic_kwargs):
|
||||
"""Sets the dynamic argument information required for export to c code generation."""
|
||||
self.dynamic_args = dynamic_args
|
||||
self.dynamic_kwargs = dynamic_kwargs
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Executes the jit-compiled function under the currently active CUDA context.
|
||||
|
||||
Calling this method multiple devices is not allowed and will result in unexpected
|
||||
CUDA errors. If you need to call the kernel on multiple devices use `to`
|
||||
to return a per-device function.
|
||||
"""
|
||||
exe_args, adapted_args = self.generate_execution_args(*args, **kwargs)
|
||||
return self.run_compiled_program(exe_args)
|
||||
|
||||
def run_compiled_program(self, exe_args):
|
||||
"""Executes the jit-compiled function under the currently active CUDA context.
|
||||
|
||||
Calling this method multiple devices is not allowed and will result in unexpected
|
||||
CUDA errors. If you need to call the kernel on multiple devices use `to`
|
||||
to return a per-device function.
|
||||
"""
|
||||
with self._executor_lock:
|
||||
if self._default_executor is None:
|
||||
log().debug("Creating default executor.")
|
||||
# We use a weak reference here so that this instance does not keep this
|
||||
# object alive as it hold a reference to self.
|
||||
proxy_self = weakref.proxy(self)
|
||||
self._default_executor = proxy_self.to(None)
|
||||
return self._default_executor.run_compiled_program(exe_args)
|
||||
|
||||
@@ -42,3 +42,7 @@ class CudaDialectStreamAdapter:
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
return [cuda.StreamType.get()]
|
||||
|
||||
def __cuda_stream__(self):
|
||||
# support cuda stream protocol
|
||||
return (0, int(self._arg))
|
||||
|
||||
@@ -32,6 +32,7 @@ import pkgutil
|
||||
from dataclasses import is_dataclass, fields
|
||||
from math import ceil
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from collections.abc import Sequence
|
||||
import builtins
|
||||
import ctypes
|
||||
@@ -334,8 +335,15 @@ class CutlassBaseDSL(BaseDSL):
|
||||
)
|
||||
try:
|
||||
# update the version hash of the cutlass shared library
|
||||
giant_dso_name = str(
|
||||
next(
|
||||
(Path(dsl_path) / "_mlir" / "_mlir_libs").glob(
|
||||
"_cutlass_ir.cpython*"
|
||||
)
|
||||
).name
|
||||
)
|
||||
with open(
|
||||
os.path.join(dsl_path, "_mlir/_mlir_libs/libCutlassIRPythonCAPI.so"),
|
||||
os.path.join(dsl_path, f"_mlir/_mlir_libs/{giant_dso_name}"),
|
||||
"rb",
|
||||
) as f:
|
||||
while True:
|
||||
@@ -345,7 +353,7 @@ class CutlassBaseDSL(BaseDSL):
|
||||
version_hash.update(chunk)
|
||||
except Exception:
|
||||
raise DSLRuntimeError(
|
||||
"Failed to read the shared library file libCutlassIRPythonCAPI.so."
|
||||
f"Failed to read the shared library file {giant_dso_name}."
|
||||
"The file may not exist or may not be readable."
|
||||
"Please re-install the package."
|
||||
)
|
||||
@@ -386,6 +394,8 @@ class CutlassBaseDSL(BaseDSL):
|
||||
args_spec,
|
||||
no_cache,
|
||||
*,
|
||||
full_args=None,
|
||||
full_kwargs=None,
|
||||
dynamic_args=None,
|
||||
dynamic_kwargs=None,
|
||||
original_function_name=None,
|
||||
@@ -399,6 +409,8 @@ class CutlassBaseDSL(BaseDSL):
|
||||
:param pipeline: The pipeline to use for compilation.
|
||||
:param args_spec: The args spec to use for compilation.
|
||||
:param no_cache: Whether to cache the result.
|
||||
:param full_args: The full arguments to use for compilation.
|
||||
:param full_kwargs: The full keyword arguments to use for compilation.
|
||||
:param dynamic_args: The dynamic arguments to use for compilation.
|
||||
:param dynamic_kwargs: The dynamic keyword arguments to use for compilation.
|
||||
:param original_function_name: The name of the original function without mangling.
|
||||
@@ -415,7 +427,7 @@ class CutlassBaseDSL(BaseDSL):
|
||||
|
||||
assert self._tvm_ffi_args_spec_converter is not None
|
||||
tvm_ffi_spec_params = self._tvm_ffi_args_spec_converter(
|
||||
function_name, args_spec, dynamic_args, dynamic_kwargs
|
||||
function_name, args_spec, full_args, full_kwargs
|
||||
)
|
||||
tvm_ffi_provider = TVMFFICuteCallProvider(function_name)
|
||||
|
||||
@@ -445,6 +457,8 @@ class CutlassBaseDSL(BaseDSL):
|
||||
args_spec,
|
||||
no_cache,
|
||||
TVMFFIJitCompiledFunction,
|
||||
full_args=full_args,
|
||||
full_kwargs=full_kwargs,
|
||||
dynamic_args=dynamic_args,
|
||||
dynamic_kwargs=dynamic_kwargs,
|
||||
)
|
||||
@@ -457,6 +471,8 @@ class CutlassBaseDSL(BaseDSL):
|
||||
args_spec,
|
||||
no_cache,
|
||||
CudaDialectJitCompiledFunction,
|
||||
full_args=full_args,
|
||||
full_kwargs=full_kwargs,
|
||||
dynamic_args=dynamic_args,
|
||||
dynamic_kwargs=dynamic_kwargs,
|
||||
original_function_name=original_function_name,
|
||||
|
||||
@@ -17,18 +17,24 @@ from cutlass.base_dsl.tvm_ffi_builder import (
|
||||
)
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import llvm
|
||||
from cutlass._mlir._mlir_libs import _execution_engine_extra
|
||||
from cutlass._mlir._mlir_libs._cutlass_ir import _execution_engine_extra
|
||||
from cutlass.cutlass_dsl.cuda_jit_executor import CudaDialectJitCompiledFunction
|
||||
from cutlass.base_dsl.common import DSLRuntimeError
|
||||
from typing import Optional
|
||||
import tvm_ffi
|
||||
|
||||
|
||||
class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
"""Cute call provider that uses cute call convention."""
|
||||
|
||||
cuda_device_index: Optional[ir.Value]
|
||||
cuda_error_handle_block: Optional[ir.Block]
|
||||
|
||||
def __init__(self, target_func: str):
|
||||
super().__init__(target_func, struct_call=True)
|
||||
self.cuda_global_state_symbol = f"__{target_func}_cuda_state"
|
||||
self.cuda_device_index = None
|
||||
self.cuda_error_handle_block = None
|
||||
|
||||
def get_callee_struct_for_param_tensor(
|
||||
self,
|
||||
@@ -41,7 +47,10 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
) -> ir.Type:
|
||||
"""Routine used to override the tensor passing struct convention"""
|
||||
with ir.InsertionPoint(current_block):
|
||||
data_type = self.gpu_ptr_type
|
||||
if param.dlpack_device_type == tvm_ffi.DLDeviceType.kDLCPU:
|
||||
data_type = self.ptr_type
|
||||
else:
|
||||
data_type = self.gpu_ptr_type
|
||||
strides_type = (
|
||||
self.struct_type(fields=[x.type for x in strides])
|
||||
if len(strides) != 1
|
||||
@@ -79,22 +88,27 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
def declare_extern_funcs(self, current_block: ir.Block, context: CallContext):
|
||||
"""Append the error handling function to the current block."""
|
||||
with ir.InsertionPoint(context.module.body):
|
||||
self.declare_extern_func(
|
||||
context.builder.find_or_declare_extern_func(
|
||||
"cuda_dialect_get_error_name",
|
||||
[self.i32_type],
|
||||
self.ptr_type,
|
||||
)
|
||||
self.declare_extern_func(
|
||||
context.builder.find_or_declare_extern_func(
|
||||
"_cudaGetDevice",
|
||||
[self.ptr_type],
|
||||
self.i32_type,
|
||||
)
|
||||
context.builder.find_or_declare_extern_func(
|
||||
"_cudaSetDevice",
|
||||
[self.i32_type],
|
||||
self.i32_type,
|
||||
)
|
||||
self.declare_extern_func(
|
||||
context.builder.find_or_declare_extern_func(
|
||||
"cuda_dialect_init_library_once",
|
||||
[self.ptr_type, self.ptr_type, self.ptr_type, self.ptr_type],
|
||||
self.i32_type,
|
||||
)
|
||||
self.declare_extern_func(
|
||||
context.builder.find_or_declare_extern_func(
|
||||
"cuda_dialect_unload_library_once",
|
||||
[self.ptr_type],
|
||||
self.void_type,
|
||||
@@ -116,7 +130,7 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
self.cuda_global_state_symbol, self.ptr_type
|
||||
)
|
||||
cuda_init_ptr = self.address_of("cuda_init", self.ptr_type)
|
||||
cuda_load_ptr = self.address_of("cuda_load", self.ptr_type)
|
||||
cuda_load_to_device_ptr = self.address_of("cuda_load_to_device", self.ptr_type)
|
||||
set_error_ptr = self.address_of(
|
||||
"TVMFFIErrorSetRaisedFromCStr", self.ptr_type
|
||||
)
|
||||
@@ -127,7 +141,7 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
callee_operands=[
|
||||
cuda_global_state_ptr,
|
||||
cuda_init_ptr,
|
||||
cuda_load_ptr,
|
||||
cuda_load_to_device_ptr,
|
||||
set_error_ptr,
|
||||
],
|
||||
op_bundle_sizes=[],
|
||||
@@ -207,34 +221,70 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
def check_cuda_error(
|
||||
self, code: ir.Value, current_block: ir.Block, context: CallContext
|
||||
):
|
||||
"""Check if the CUDA error is raised and return the error string if so."""
|
||||
"""Check if the CUDA error is raised and return the error string if so.
|
||||
|
||||
Uses a shared error handling block to avoid code duplication. The error code
|
||||
is passed as a block argument to the shared error handler.
|
||||
"""
|
||||
assert self.cuda_error_handle_block is not None
|
||||
with ir.InsertionPoint(current_block):
|
||||
# check if the call is successful
|
||||
error_block = current_block.create_after()
|
||||
success_block = error_block.create_after()
|
||||
# Check if call is successful (non-zero return code)
|
||||
success_block = current_block.create_after()
|
||||
# Check if call is successful (zero return code means success)
|
||||
self.cond_br(
|
||||
cond=self.equal(code, self.i32(0)),
|
||||
true_block=success_block,
|
||||
false_block=error_block,
|
||||
false_block=self.cuda_error_handle_block,
|
||||
branch_weights=self.BRANCH_WEIGHTS_LIKELY,
|
||||
false_dest_operands=[code], # Pass error code to shared error block
|
||||
)
|
||||
return success_block
|
||||
|
||||
def set_cuda_device_if_mismatch(
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
context: CallContext,
|
||||
current_device: Optional[ir.Value],
|
||||
target_device: Optional[ir.Value],
|
||||
) -> ir.Block:
|
||||
"""Set the CUDA device index if it differs from the target device.
|
||||
"""
|
||||
# If either device is None, no switching needed
|
||||
if current_device is None:
|
||||
assert target_device is None
|
||||
return current_block
|
||||
|
||||
with ir.InsertionPoint(current_block):
|
||||
# Check if devices are different
|
||||
devices_differ = self.not_equal(current_device, target_device)
|
||||
|
||||
# Create blocks for conditional device switching
|
||||
switch_device_block = current_block.create_after()
|
||||
continuation_block = switch_device_block.create_after()
|
||||
# For this specific case, avoid branch weights for now
|
||||
# mainly to avoid too drastic reordering of the code
|
||||
self.cond_br(
|
||||
cond=devices_differ,
|
||||
true_block=switch_device_block,
|
||||
false_block=continuation_block
|
||||
)
|
||||
|
||||
# Error block: raise error and return
|
||||
with ir.InsertionPoint(error_block):
|
||||
error_str = llvm.call(
|
||||
result=self.ptr_type,
|
||||
callee="cuda_dialect_get_error_name",
|
||||
callee_operands=[code],
|
||||
# Switch device block: call cudaSetDevice
|
||||
with ir.InsertionPoint(switch_device_block):
|
||||
result = llvm.call(
|
||||
result=self.i32_type,
|
||||
callee="_cudaSetDevice",
|
||||
callee_operands=[target_device],
|
||||
op_bundle_sizes=[],
|
||||
op_bundle_operands=[],
|
||||
)
|
||||
# Raise error and return -1
|
||||
context.builder.raise_error_and_return(
|
||||
error_kind="RuntimeError",
|
||||
error_message_parts=["CUDA Error: ", error_str],
|
||||
)
|
||||
return success_block
|
||||
|
||||
# Check for errors and branch to continuation
|
||||
switch_device_block = self.check_cuda_error(result, switch_device_block, context)
|
||||
with ir.InsertionPoint(switch_device_block):
|
||||
self.br(continuation_block)
|
||||
|
||||
return continuation_block
|
||||
|
||||
def generate_llvm_call(
|
||||
self,
|
||||
@@ -243,6 +293,36 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
context: CallContext,
|
||||
) -> ir.Block:
|
||||
"""Generate the LLVM call operation and check if the call is successful."""
|
||||
old_cuda_device_index: Optional[ir.Value] = None
|
||||
|
||||
# If we need to manage CUDA device context
|
||||
if self.cuda_device_index is not None:
|
||||
# Create an alloca in the entry block to store the current device index
|
||||
device_index_alloca = context.builder.create_alloca(
|
||||
context.entry_block, self.i32_type, array_size=1
|
||||
)
|
||||
|
||||
# Get the current device
|
||||
with ir.InsertionPoint(current_block):
|
||||
get_device_result = llvm.call(
|
||||
result=self.i32_type,
|
||||
callee="_cudaGetDevice",
|
||||
callee_operands=[device_index_alloca],
|
||||
op_bundle_sizes=[],
|
||||
op_bundle_operands=[],
|
||||
)
|
||||
current_block = self.check_cuda_error(get_device_result, current_block, context)
|
||||
|
||||
# Load the current device index from the alloca
|
||||
with ir.InsertionPoint(current_block):
|
||||
old_cuda_device_index = llvm.load(self.i32_type, device_index_alloca)
|
||||
|
||||
# Switch to target device if different
|
||||
current_block = self.set_cuda_device_if_mismatch(
|
||||
current_block, context, old_cuda_device_index, self.cuda_device_index
|
||||
)
|
||||
|
||||
# Execute the main call
|
||||
with ir.InsertionPoint(current_block):
|
||||
result = llvm.call(
|
||||
result=self.i32_type,
|
||||
@@ -251,42 +331,72 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
op_bundle_sizes=[],
|
||||
op_bundle_operands=[],
|
||||
)
|
||||
return self.check_cuda_error(result, current_block, context)
|
||||
|
||||
def insert_set_cuda_device(self, current_block: ir.Block, context: CallContext):
|
||||
"""Call the _cudaSetDevice function if we can find device id from tensor parameters."""
|
||||
# Restore the original device BEFORE checking for errors
|
||||
# This ensures device is restored even if the main call failed
|
||||
current_block = self.set_cuda_device_if_mismatch(
|
||||
current_block, context, self.cuda_device_index, old_cuda_device_index
|
||||
)
|
||||
|
||||
def find_cuda_device_index_from_params():
|
||||
for param in context.params:
|
||||
if (
|
||||
isinstance(param, spec.Tensor)
|
||||
and param.dlpack_device_type != tvm_ffi.DLDeviceType.kDLCPU
|
||||
):
|
||||
return context.matched_var_binding[param.device_id]
|
||||
return None
|
||||
# Now check for errors from the main call
|
||||
current_block = self.check_cuda_error(result, current_block, context)
|
||||
|
||||
cuda_device_index = find_cuda_device_index_from_params()
|
||||
return current_block
|
||||
|
||||
if cuda_device_index is None:
|
||||
return current_block
|
||||
|
||||
with ir.InsertionPoint(current_block):
|
||||
result = llvm.call(
|
||||
result=self.i32_type,
|
||||
callee="_cudaSetDevice",
|
||||
callee_operands=[cuda_device_index],
|
||||
def find_cuda_device_index_from_params(self, context: CallContext):
|
||||
"""Find the CUDA device index from tensor parameters."""
|
||||
for param in context.params:
|
||||
if (
|
||||
isinstance(param, spec.Tensor)
|
||||
and param.dlpack_device_type != tvm_ffi.DLDeviceType.kDLCPU
|
||||
):
|
||||
return context.matched_var_binding[param.device_id]
|
||||
return None
|
||||
|
||||
def create_shared_cuda_error_block(
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
context: CallContext
|
||||
) -> ir.Block:
|
||||
"""Create a shared error handling block for all CUDA errors.
|
||||
"""
|
||||
# Create the shared error block after the current block (setup phase)
|
||||
# This block will be branched to from multiple error checking sites
|
||||
# It accepts the error code as a block argument
|
||||
error_block = current_block.create_after()
|
||||
error_code = error_block.add_argument(self.i32_type, ir.Location.unknown())
|
||||
|
||||
# Populate the error block
|
||||
with ir.InsertionPoint(error_block):
|
||||
error_str = llvm.call(
|
||||
result=self.ptr_type,
|
||||
callee="cuda_dialect_get_error_name",
|
||||
callee_operands=[error_code],
|
||||
op_bundle_sizes=[],
|
||||
op_bundle_operands=[],
|
||||
)
|
||||
# Raise error and return -1
|
||||
context.builder.raise_error_and_return(
|
||||
error_kind="RuntimeError",
|
||||
error_message_parts=["CUDA Error: ", error_str],
|
||||
)
|
||||
|
||||
return self.check_cuda_error(result, current_block, context)
|
||||
return error_block
|
||||
|
||||
def __call__(self, current_block: ir.Block, context: CallContext) -> ir.Block:
|
||||
current_block = self.declare_extern_funcs(current_block, context)
|
||||
current_block = self.insert_lazy_init_cuda(current_block, context)
|
||||
current_block = self.append_unload_to_global_dtors(current_block, context)
|
||||
current_block = self.insert_set_cuda_device(current_block, context)
|
||||
# Create shared CUDA error handling block after the setup blocks
|
||||
# This reduces code duplication - all CUDA errors branch to this single block
|
||||
self.cuda_error_handle_block = self.create_shared_cuda_error_block(current_block, context)
|
||||
# setup device index, will be set around the call to the target function
|
||||
self.cuda_device_index = self.find_cuda_device_index_from_params(context)
|
||||
current_block = super().__call__(current_block, context)
|
||||
self.cuda_device_index = None
|
||||
self.cuda_error_handle_block = None
|
||||
# reset the device index and error block
|
||||
return current_block
|
||||
|
||||
|
||||
@@ -316,12 +426,15 @@ class TVMFFIJitCompiledFunction(tvm_ffi.Function, CudaDialectJitCompiledFunction
|
||||
if self.__chandle__() != 0:
|
||||
raise DSLRuntimeError("TVM FFI function is already initialized")
|
||||
# get the MLIR function pointer from the execution engine
|
||||
tvm_ffi_function_ptr = self.engine.raw_lookup("__tvm_ffi_" + self.function_name)
|
||||
tvm_ffi_function = tvm_ffi.Function.__from_mlir_packed_safe_call__(
|
||||
tvm_ffi_function_ptr
|
||||
)
|
||||
# move the handle from the tvm_ffi.Function to the current instance
|
||||
self.__move_handle_from__(tvm_ffi_function)
|
||||
if self.engine is not None:
|
||||
tvm_ffi_function_ptr = self.engine.raw_lookup(
|
||||
"__tvm_ffi_" + self.function_name
|
||||
)
|
||||
tvm_ffi_function = tvm_ffi.Function.__from_mlir_packed_safe_call__(
|
||||
tvm_ffi_function_ptr
|
||||
)
|
||||
# move the handle from the tvm_ffi.Function to the current instance
|
||||
self.__move_handle_from__(tvm_ffi_function)
|
||||
|
||||
def to(self, device=None):
|
||||
"""TVM FFI function itself is already support all devices."""
|
||||
|
||||
@@ -165,6 +165,16 @@ def create_and_permute_torch_tensor(
|
||||
return dtype_torch_tensor
|
||||
|
||||
|
||||
def get_leading_dim(torch_tensor: torch.Tensor) -> int:
|
||||
"""
|
||||
Get the leading dimension of a torch tensor
|
||||
"""
|
||||
for i, stride in enumerate(torch_tensor.stride()):
|
||||
if stride == 1:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def convert_cute_tensor(
|
||||
f32_torch_tensor: "torch.Tensor",
|
||||
cute_tensor: Tensor,
|
||||
@@ -189,8 +199,10 @@ def convert_cute_tensor(
|
||||
}:
|
||||
fp32_cute_tensor = from_dlpack(f32_torch_tensor)
|
||||
if is_dynamic_layout:
|
||||
# note: dim_order to not always maps to leading dimension,
|
||||
# so we need to get the leading dimension from the torch tensor strides
|
||||
fp32_cute_tensor = fp32_cute_tensor.mark_layout_dynamic(
|
||||
f32_torch_tensor.dim_order()[-1]
|
||||
leading_dim=get_leading_dim(f32_torch_tensor)
|
||||
)
|
||||
# Copy and convert from f32 cute tensor to dtype cute tensor
|
||||
cute.testing.convert(fp32_cute_tensor, cute_tensor)
|
||||
@@ -297,11 +309,9 @@ def cute_tensor_like(
|
||||
# create cute tensor using the device buffer
|
||||
cute_tensor = from_dlpack(torch_tensor, assumed_align=assumed_align)
|
||||
cute_tensor.element_type = cutlass_dtype
|
||||
|
||||
if is_dynamic_layout:
|
||||
for i, stride in enumerate(torch_tensor.stride()):
|
||||
if stride == 1:
|
||||
leading_dim = i
|
||||
break
|
||||
leading_dim = get_leading_dim(torch_tensor)
|
||||
cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
|
||||
|
||||
# initialize the cute tensor data
|
||||
@@ -316,5 +326,4 @@ def cute_tensor_like(
|
||||
)
|
||||
else:
|
||||
torch_tensor.copy_(data_ref.to(dtype=torch_dtype))
|
||||
|
||||
return cute_tensor, torch_tensor
|
||||
|
||||
@@ -664,6 +664,7 @@ def make_smem_layout_a(
|
||||
a_dtype: Type[Numeric],
|
||||
num_stages: int,
|
||||
*,
|
||||
is_k_major=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Union[cute.Layout, cute.ComposedLayout]:
|
||||
@@ -687,7 +688,8 @@ def make_smem_layout_a(
|
||||
:rtype: Union[cute.Layout, cute.ComposedLayout]
|
||||
"""
|
||||
|
||||
is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K
|
||||
is_k_major = (tiled_mma.op.a_major_mode == OperandMajorMode.K) if is_k_major is None else is_k_major
|
||||
a_major_mode = OperandMajorMode.K if is_k_major else OperandMajorMode.MN
|
||||
a_smem_shape = tiled_mma.partition_shape_A(
|
||||
cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip), loc=loc, ip=ip
|
||||
)
|
||||
@@ -696,7 +698,7 @@ def make_smem_layout_a(
|
||||
cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2],
|
||||
)
|
||||
smem_layout_atom_kind = get_smem_layout_atom_ab(
|
||||
tiled_mma.op.a_major_mode, a_dtype, a_smem_shape_mn_k, loc=loc, ip=ip
|
||||
a_major_mode, a_dtype, a_smem_shape_mn_k, loc=loc, ip=ip
|
||||
)
|
||||
a_smem_layout_atom = make_smem_layout_atom(
|
||||
smem_layout_atom_kind, a_dtype, loc=loc, ip=ip
|
||||
@@ -716,6 +718,7 @@ def make_smem_layout_b(
|
||||
b_dtype: Type[Numeric],
|
||||
num_stages: int,
|
||||
*,
|
||||
is_k_major=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Union[cute.Layout, cute.ComposedLayout]:
|
||||
@@ -739,7 +742,8 @@ def make_smem_layout_b(
|
||||
:rtype: Union[cute.Layout, cute.ComposedLayout]
|
||||
"""
|
||||
|
||||
is_k_major = tiled_mma.op.b_major_mode == OperandMajorMode.K
|
||||
is_k_major = (tiled_mma.op.b_major_mode == OperandMajorMode.K) if is_k_major is None else is_k_major
|
||||
b_major_mode = OperandMajorMode.K if is_k_major else OperandMajorMode.MN
|
||||
b_smem_shape = tiled_mma.partition_shape_B(
|
||||
cute.dice(mma_tiler_mnk, (None, 1, 1), loc=loc, ip=ip), loc=loc, ip=ip
|
||||
)
|
||||
@@ -749,7 +753,7 @@ def make_smem_layout_b(
|
||||
)
|
||||
|
||||
smem_layout_atom_kind = get_smem_layout_atom_ab(
|
||||
tiled_mma.op.b_major_mode, b_dtype, b_smem_shape_nk, loc=loc, ip=ip
|
||||
b_major_mode, b_dtype, b_smem_shape_nk, loc=loc, ip=ip
|
||||
)
|
||||
b_smem_layout_atom = make_smem_layout_atom(
|
||||
smem_layout_atom_kind, b_dtype, loc=loc, ip=ip
|
||||
|
||||
@@ -8,12 +8,10 @@
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from cuda.bindings import driver, nvrtc, runtime
|
||||
from cutlass.cutlass_dsl.cuda_jit_executor import CudaDialectJitModule
|
||||
from cuda.bindings import driver, runtime
|
||||
from cutlass.base_dsl.common import DSLRuntimeError
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass import cute
|
||||
import tempfile
|
||||
|
||||
"""
|
||||
This class is used to get the hardware info of given GPU device.
|
||||
@@ -53,8 +51,8 @@ class HardwareInfo:
|
||||
f"Cluster size must be between 1 and 32, {cluster_size} is not supported"
|
||||
)
|
||||
|
||||
self._get_device_function(self.device)
|
||||
|
||||
# must do get kernel after set device so runtime context is set correctly
|
||||
self.kernel = self._get_device_function()
|
||||
max_shared_memory_per_block = self._checkCudaErrors(
|
||||
driver.cuDeviceGetAttribute(
|
||||
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
|
||||
@@ -152,8 +150,6 @@ class HardwareInfo:
|
||||
if isinstance(error, driver.CUresult):
|
||||
err, name = driver.cuGetErrorName(error)
|
||||
return name if err == driver.CUresult.CUDA_SUCCESS else "<unknown>"
|
||||
elif isinstance(error, nvrtc.nvrtcResult):
|
||||
return nvrtc.nvrtcGetErrorString(error)[1]
|
||||
else:
|
||||
raise RuntimeError("Unknown error type: {}".format(error))
|
||||
|
||||
@@ -175,14 +171,20 @@ class HardwareInfo:
|
||||
)
|
||||
|
||||
# get a empty kernel to compute occupancy
|
||||
def _get_device_function(self, device) -> driver.CUfunction:
|
||||
self.compiled_kernel = cute.compile(self._host_function).to(device)
|
||||
assert isinstance(self.compiled_kernel.jit_module, CudaDialectJitModule)
|
||||
err, kernels = runtime.cudaLibraryEnumerateKernels(
|
||||
1, self.compiled_kernel.jit_module.cuda_library[0]
|
||||
)
|
||||
if err is not runtime.cudaError_t.cudaSuccess:
|
||||
raise DSLRuntimeError(f"Failed to enumerate kernels: {err}")
|
||||
self.kernel = kernels[0]
|
||||
self.kernel = self._checkCudaErrors(driver.cuKernelGetFunction(self.kernel))
|
||||
return self.kernel
|
||||
def _get_device_function(self) -> driver.CUfunction:
|
||||
"""
|
||||
Get a device function by compiling a dummy kernel using cuteDSL pipeline.
|
||||
"""
|
||||
# Create a temporary directory for dumping artifacts
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# keep-cubin will keep the cubin in the artifacts
|
||||
compiled_func = cute.compile(self._host_function, options=f"--dump-dir={temp_dir} --keep-cubin")
|
||||
# Get the CUBIN from artifacts
|
||||
cubin_data = compiled_func.artifacts.CUBIN
|
||||
cuda_library = self._checkCudaErrors(
|
||||
driver.cuLibraryLoadData(cubin_data, None, None, 0, None, None, 0)
|
||||
)
|
||||
# Enumerate kernels from the library
|
||||
kernels = self._checkCudaErrors(driver.cuLibraryEnumerateKernels(1, cuda_library))
|
||||
# Get the function from the kernel
|
||||
return self._checkCudaErrors(driver.cuKernelGetFunction(kernels[0]))
|
||||
|
||||
Reference in New Issue
Block a user