diff --git a/CHANGELOG.md b/CHANGELOG.md index 6029a66ba..4cfd61e07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,17 @@ # CUTLASS 4.x +## [4.3.3](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.3) (2025-12-12) + +### CuTe DSL +* New features + - Supported namedtuple and kwargs for JIT function arguments in tvm-ffi + - Supported variadic tuples for JIT function argument in tvm-ffi + +* Bug fixing and improvements + - Fixed an issue when JIT function argument with union type annotation for tvm-ffi + - Clearer error message for the case of runtime error cudaErrorInsufficientDriver + ## [4.3.2](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.2) (2025-12-05) ### CuTe DSL diff --git a/README.md b/README.md index ec4d0bd15..c87b0e784 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") # Overview -# CUTLASS 4.3.2 +# CUTLASS 4.3.3 -_CUTLASS 4.3.2 - Dec 2025_ +_CUTLASS 4.3.3 - Dec 2025_ CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for @@ -54,6 +54,8 @@ To get started quickly - please refer : - Added Blackwell SM103 support. - Multiple dependent DSOs in the wheel have been merged into one single DSO. - New env var `CUTE_DSL_CACHE_DIR` to specify the path for dumping caches. + - Supported namedtuple and kwargs for JIT function arguments in tvm-ffi. + - Supported variadic tuples for JIT function argument in tvm-ffi. * Debuggability improvements: - Supported source location tracking for DSL APIs (Allow tools like ``nsight`` profiling to correlate perf metrics with Python source code) - Supported dumping PTX and CUBIN code: [Hello World Example](https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/notebooks/hello_world.ipynb) @@ -102,6 +104,8 @@ To get started quickly - please refer : - Fixed tvm-ffi export compiled function - Fixed an issue of CUDA JitExecutor when unloading kernels - Fixed an issue of allocating max smem when there's statically allocated smem + - Fixed an issue when JIT function argument with union type annotation for tvm-ffi + - Clearer error message for the case of runtime error cudaErrorInsufficientDriver ## CUTLASS C++ * Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). diff --git a/include/cute/util/print_tensor.hpp b/include/cute/util/print_tensor.hpp index c5eb39a1d..526fa63d3 100644 --- a/include/cute/util/print_tensor.hpp +++ b/include/cute/util/print_tensor.hpp @@ -50,7 +50,6 @@ print_layout(Layout const& layout) // (m,n) -> idx CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); int idx_width = num_digits(cosize(layout)) + 2; - const char* delim = "+-----------------------"; print(layout); print("\n"); @@ -63,7 +62,12 @@ print_layout(Layout const& layout) // (m,n) -> idx for (int m = 0; m < size<0>(layout); ++m) { // Header print(" "); - for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } + for (int n = 0; n < size<1>(layout); ++n) { + printf("+"); + for (int i = 0; i < idx_width; ++i) { + printf("-"); + } + } printf("+\n"); // Values printf("%2d ", m); // Row indices @@ -72,7 +76,12 @@ print_layout(Layout const& layout) // (m,n) -> idx } // Footer print(" "); - for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } + for (int n = 0; n < size<1>(layout); ++n) { + printf("+"); + for (int i = 0; i < idx_width; ++i) { + printf("-"); + } + } printf("+\n"); } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index c69defbfa..e758059e1 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -994,6 +994,11 @@ public: // Get next work tile auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); + + if (!next_work_tile_info.is_valid()) { + cutlass::arch::launch_dependent_grids(); + } + work_tile_info = next_work_tile_info; if (increment_pipe) { ++tile_scheduler_pipe_consumer_state; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index d380b2cd0..7120db78d 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -1038,6 +1038,11 @@ public: // Get next work tile auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); + + if (!next_work_tile_info.is_valid()) { + cutlass::arch::launch_dependent_grids(); + } + work_tile_info = next_work_tile_info; if (increment_pipe) { ++tile_scheduler_pipe_consumer_state; diff --git a/include/cutlass/version.h b/include/cutlass/version.h index ce1c9b24b..476e3a25c 100644 --- a/include/cutlass/version.h +++ b/include/cutlass/version.h @@ -36,7 +36,7 @@ #define CUTLASS_MAJOR 4 #define CUTLASS_MINOR 3 -#define CUTLASS_PATCH 2 +#define CUTLASS_PATCH 3 #ifdef CUTLASS_VERSIONS_GENERATED #include "cutlass/version_extended.h" diff --git a/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst b/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst index 296ff8a8e..8459cbae5 100644 --- a/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst +++ b/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst @@ -288,6 +288,131 @@ composed of the types that are supported by TVM FFI. The example below shows how example_add_one_with_tuple() +Working with Variadic Tuples +---------------------------- + +Sometimes it is helpful to annotate a tuple with no explicit element types. +This can be useful to build up a generic template for a function that accepts +a variable number of elements. The compiled function's signature will be +determined by the tuple argument passed to the ``cute.compile`` function. +The following example shows how to use a variadic tuple to build such a +generic template. + +.. code-block:: python + + import cutlass + import torch + from cutlass import cute + + @cute.kernel + def device_add_one(a: cute.Tensor, b: cute.Tensor, extra_value: tuple): + threads_per_block = 128 + cta_x_, _, _ = cute.arch.block_idx() + tid_x, _, _ = cute.arch.thread_idx() + tid = cta_x_ * threads_per_block + tid_x + if tid < a.shape[0]: + if cutlass.const_expr(len(extra_value) != 0): + b[tid] = a[tid] + 1 + extra_value[0] + else: + b[tid] = a[tid] + 1 + + @cute.jit + def add_one_with_extra_value(a: cute.Tensor, b: cute.Tensor, extra_value: tuple): + n = a.shape[0] + threads_per_block = 128 + blocks = (n + threads_per_block - 1) // threads_per_block + device_add_one(a, b, extra_value).launch(grid=(blocks, 1, 1), block=(threads_per_block, 1, 1)) + + def example_add_one_with_variadic_tuple(): + n = cute.sym_int() + a_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,)) + b_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,)) + compiled_add_one_no_extra = cute.compile( + add_one_with_extra_value, a_cute, b_cute, (), + options="--enable-tvm-ffi" + ) + compiled_add_one_with_extra = cute.compile( + add_one_with_extra_value, a_cute, b_cute, (cute.Float32(4),), + options="--enable-tvm-ffi" + ) + a_torch = torch.arange(10, dtype=torch.float32, device="cuda") + b_torch = torch.empty(10, dtype=torch.float32, device="cuda") + compiled_add_one_no_extra(a_torch, b_torch, ()) + print("result of b_torch after compiled_add_one_no_extra(a_torch, b_torch, ())") + print(b_torch) + compiled_add_one_with_extra(a_torch, b_torch, (4,)) + print("result of b_torch after compiled_add_one_with_extra(a_torch, b_torch, (4,))") + print(b_torch) + + example_add_one_with_variadic_tuple() + + +Working with Named Tuples +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Named tuples are also supported and help logically group related arguments together. +The example below shows how to use named tuples as arguments. Under the hood, named tuples +are passed as unnamed tuples at the ABI level. When errors occur, the function signature in +error messages will display unnamed tuple arguments. +Ensure that the compile-time CuTe named tuple type definition has the same fields +as the runtime PyTorch named tuple. +Currently, users need to explicitly unpack the named tuple outside of conditionals and then +use the unpacked variables inside the conditionals. + +.. code-block:: python + + from typing import NamedTuple + from cutlass import cute + import torch + + class CuteNamedTuple(NamedTuple): + a: cute.Tensor + b: cute.Tensor + c: cute.Float32 = cute.Float32(1) + + def __new_from_mlir_values__(self, values): + return CuteNamedTuple(*values) + + class TorchNamedTuple(NamedTuple): + a: torch.Tensor + b: torch.Tensor + c: float = 1 + + @cute.kernel + def device_add_one_named_tuple(value: CuteNamedTuple): + tid = cute.arch.block_idx()[0] * 128 + cute.arch.thread_idx()[0] + # need to unpack namedtuple outside conditionals + a = value.a + b = value.b + c = value.c + if tid < a.shape[0]: + b[tid] = a[tid] + c + + @cute.jit + def add_one_with_named_tuple(value: CuteNamedTuple): + n = value.a.shape[0] + threads_per_block = 128 + blocks = (n + threads_per_block - 1) // threads_per_block + device_add_one_named_tuple(value).launch(grid=(blocks, 1, 1), block=(threads_per_block, 1, 1)) + + def example_add_one_with_named_tuple(): + n = cute.sym_int() + a_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,)) + b_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,)) + + compiled_add_one = cute.compile( + add_one_with_named_tuple, CuteNamedTuple(a=a_cute, b=b_cute), + options="--enable-tvm-ffi" + ) + a_torch = torch.arange(10, dtype=torch.float32, device="cuda") + b_torch = torch.empty(10, dtype=torch.float32, device="cuda") + compiled_add_one(TorchNamedTuple(a=a_torch, b=b_torch)) + print("result of b_torch") + print(b_torch) + + example_add_one_with_named_tuple() + + Supported types --------------- @@ -464,3 +589,97 @@ When you build your own libraries, make sure you link against the necessary runt You can use ``cute.runtime.find_runtime_libraries(enable_tvm_ffi=True)`` to get the path to these libraries. ``cute.runtime.load_module`` will load these libraries automatically before loading an exported module. You can also manually load these libraries in advanced use cases. + + +Keyword Arguments and Defaults +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The function returned by ``cute.compile`` supports keyword arguments and defaults. +The example below shows how to use keyword arguments and defaults: + +.. code-block:: python + + import torch + from cutlass import cute + + @cute.kernel + def device_add_scalar(a: cute.Tensor, b: cute.Tensor, offset: cutlass.Float32): + threads_per_block = 128 + cta_x_, _, _ = cute.arch.block_idx() + tid_x, _, _ = cute.arch.thread_idx() + tid = cta_x_ * threads_per_block + tid_x + if tid < a.shape[0]: + b[tid] = a[tid] + offset + + @cute.jit + def add_constant(a: cute.Tensor, b: cute.Tensor, offset: cutlass.Float32=cutlass.Float32(1)): + n = a.shape[0] + threads_per_block = 128 + blocks = (n + threads_per_block - 1) // threads_per_block + device_add_scalar(a, b, offset).launch(grid=(blocks, 1, 1), block=(threads_per_block, 1, 1)) + + def example_kwargs_and_defaults(): + n = cute.sym_int() + a_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,)) + b_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,)) + compiled_add_constant = cute.compile(add_constant, a_cute, b_cute, options="--enable-tvm-ffi") + a_torch = torch.arange(10, dtype=torch.float32, device="cuda") + b_torch = torch.empty(10, dtype=torch.float32, device="cuda") + compiled_add_constant(a_torch, b_torch) + print("result of b_torch after compiled_add_constant(a_torch, b_torch)") + print(b_torch) + compiled_add_constant(a_torch, b_torch, offset=4) + print("result of b_torch after compiled_add_constant(a_torch, b_torch, offset=4)") + print(b_torch) + +For efficiency and portability reasons, TVM FFI ABI supports functions with positional-only arguments. +If you export the compiled module to an object file and then load it back, the function +will only accept positional arguments in the order of the arguments in the function signature. +You can rewrap the function or use the TVM FFI wrapper generator to generate a kwargs wrapper. +The code block below shows how to do this: + +.. code-block:: python + + def example_kwargs_and_defaults(): + n = cute.sym_int() + a_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,)) + b_cute = cute.runtime.make_fake_compact_tensor(cute.Float32, (n,)) + compiled_add_constant = cute.compile(add_constant, a_cute, b_cute, options="--enable-tvm-ffi") + # export the compiled module to object file + compiled_add_constant.export_to_c("./add_constant.o", function_name="add_constant") + # obtain necessary runtime libs for loading the shared library + runtime_libs = cute.runtime.find_runtime_libraries(enable_tvm_ffi=True) + # compile the object file to a shared library + cmd = ["gcc", "-shared", "-o", "./add_constant.so", "./add_constant.o", *runtime_libs] + subprocess.run(cmd, check=True) + + a_torch = torch.arange(10, dtype=torch.float32, device="cuda") + b_torch = torch.empty(10, dtype=torch.float32, device="cuda") + + mod = cute.runtime.load_module("./add_constant.so") + try: + mod.add_constant(a_torch, b_torch) + except Exception as e: + # Raises a missing arguments error because kwargs and default information are lost + print(e) + # We rewrap the function to regain argument and kwargs support. + # Alternatively, use the TVM FFI wrapper generator to generate a kwargs wrapper function. + from tvm_ffi.utils import kwargs_wrapper + # arg_defaults are aligned to the end of the argument list + wrapped_func = kwargs_wrapper.make_kwargs_wrapper( + mod.add_constant, arg_names=["a", "b", "offset"], arg_defaults=(1,) + ) + wrapped_func(a_torch, b_torch) + print("result of b_torch after wrapped_func(a_torch, b_torch)") + print(b_torch) + # You can also use the signature of the original function + # to generate a kwargs wrapper function. Make sure to exclude + # arguments that are not included in the runtime, + # such as 'self', constexpr, and env stream arguments. + wrapped_func = kwargs_wrapper.make_kwargs_wrapper_from_signature( + mod.add_constant, signature=inspect.signature(add_constant), + exclude_arg_names=["self"] + ) + wrapped_func(a_torch, b_torch, offset=4) + print("result of b_torch after wrapped_func(a_torch, b_torch, offset=4)") + print(b_torch) diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst index 214ba2804..fa5ad75cd 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst @@ -128,7 +128,7 @@ Here is an example demonstrating automatic caching of the ``add`` kernel: The cache can be serialized to files for subsequent runs. After serialization, compiled MLIR bytecode is stored in file. The cache directory is ``/tmp/{current_user}/cutlass_python_cache``. -The cache loads from files into memory during |DSL| initialization and saves back to files when the process exits. +During compilation, the cache loads the corresponding kernel from file (if it exists) into memory as needed, and after compilation, it saves any newly compiled executables back to file. Note that for efficiency, the default cache directory is located in a temporary folder. However, this location is not persistent, it may be cleared by the system (for example, during a reboot or disk space cleanup). If you wish to preserve the cache across sessions, set the ``CUTE_DSL_CACHE_DIR`` environment variable to point to a persistent directory. @@ -140,9 +140,6 @@ The following environment variables control file caching: # Disable file caching while keeping in-memory cache available, defaults to False. export CUTE_DSL_DISABLE_FILE_CACHING=True - # Maximum number of cache files allowed, defaults to 1000. - export CUTE_DSL_FILE_CACHING_CAPACITY=1000 - # Cache directory, defaults to /tmp/{current_user}/cutlass_python_cache. export CUTE_DSL_CACHE_DIR=/home/user/local_cutlass_python_cache/dense_gemm_cache/ diff --git a/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst b/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst index 2fdda5fa1..c45fdcb4b 100644 --- a/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst +++ b/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst @@ -192,6 +192,11 @@ For example: - For a tensor with layout ``(2,2):(8,2)``, since no dimension has stride 1, all dimensions are marked as dynamic: ``(?,?):(?,?)``. +The leading dimension accepts negative index which means the dimension is counted from the last dimension. For example, + +- For a tensor with layout ``(2,2,3,4):(2,1,4,12)``, if ``leading_dim`` is specified to be -1, + the layout will be marked as ``(?,?,?,?):(?,?,?,1)``. + Code Example ~~~~~~~~~~~~ diff --git a/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py b/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py index 898fffb93..5ef1cbeb6 100644 --- a/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py +++ b/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py @@ -14,6 +14,8 @@ This module provides jit cache load/dump helper functions """ import os +import io +import sys import uuid import random import tempfile @@ -22,7 +24,7 @@ import time from pathlib import Path import hashlib from functools import lru_cache -import tempfile +import zlib from .utils.logger import log from .jit_executor import JitCompiledFunction @@ -74,13 +76,58 @@ def get_default_file_dump_root(): return dump_root -def load_ir(file, asBytecode=False): +def write_bytecode_with_crc32(f, module): + """Write the bytecode to the file and calculate the crc32 checksum. + + :param f: The file to write the bytecode to. + :type f: file + :param module: The IR module to write the bytecode to. + :type module: object + """ + s = io.BytesIO() + module.operation.write_bytecode(s) + content = s.getvalue() + crc = zlib.crc32(content) + s.write(crc.to_bytes(4, sys.byteorder)) + f.write(s.getvalue()) + return + + +def read_bytecode_and_check_crc32(f): + """ + Read the bytecode from the file and check the crc32 checksum. + + :param f: The file to read the bytecode with appended CRC32 from. + :type f: file + :return: The bytecode content if checksum matches. + :rtype: bytes + :raises DSLRuntimeError: If checksum does not match. + """ + content = f.read() + if len(content) < 4: + raise DSLRuntimeError( + f"File {f.name} does not contain enough data for CRC32 checksum." + ) + bytecode = content[:-4] + crc_appended = content[-4:] + crc_appended_int = int.from_bytes(crc_appended, sys.byteorder) + crc_computed = zlib.crc32(bytecode) + if crc_appended_int != crc_computed: + raise DSLRuntimeError( + f"CRC32 checksum mismatch! Expected {crc_computed}, got {crc_appended_int}" + ) + return ir.Module.parse(bytecode) + + +def load_ir(file, asBytecode=False, bytecode_reader=None): """Load generated IR from a file. :param file: The path to the file to load. :type file: str :param asBytecode: Whether to load the IR as bytecode, defaults to False :type asBytecode: bool, optional + :param bytecode_reader: The bytecode reader to use, defaults to None + :type bytecode_reader: callable, optional :return: The function name and the IR module :rtype: tuple[str, ir.Module] """ @@ -88,8 +135,10 @@ def load_ir(file, asBytecode=False): func_name = file.split(".mlir")[0].split("dsl_")[-1] with ir.Context() as ctx: with open(file, "rb" if asBytecode else "r") as f: - module = ir.Module.parse(f.read()) - + if bytecode_reader: + module = bytecode_reader(f) + else: + module = ir.Module.parse(f.read()) return func_name, module @@ -128,6 +177,14 @@ def save_ir( :type module: object :param fname: The name of the file to save. :type fname: str + :param output_dir: The path to the output directory, defaults to None + :type output_dir: str, optional + :param as_bytecode: Whether to save the IR as bytecode, defaults to False + :type as_bytecode: bool, optional + :param bytecode_writer: The bytecode writer to use, defaults to None + :type bytecode_writer: callable, optional + :return: The path to the saved file + :rtype: str """ initial_name = f"{dsl_name.lower()}_{fname}.mlir" save_path = Path(output_dir if output_dir else tempfile.gettempdir()) @@ -158,63 +215,45 @@ def save_ir( return save_fname -def check_func_name(jit_cache, func_name): - """Check if the function name is in the cache. - If not, create a new JitCompiledFunction object and add it to the cache. - - :param jit_cache: The cache to check. - :type jit_cache: dict - :param func_name: The name of the function to check. - :type func_name: str - :return: The cache - :rtype: dict - """ - if not func_name in jit_cache: - jit_cache[func_name] = JitCompiledFunction( - None, None, None, None, None, [], False, None - ) - return jit_cache - - -def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path): +def load_cache_from_path( + dsl_name, file, path=default_generated_ir_path, bytecode_reader=None +): """Load cache from a directory path. :param dsl_name: The name of the DSL. :type dsl_name: str - :param cache_limit: The limit of the cache. - :type cache_limit: int + :param file: The name of the file to load. + :type file: str :param path: The path to the cache directory, defaults to default_generated_ir_path :type path: str, optional + :param bytecode_reader: The bytecode reader to use, defaults to None + :type bytecode_reader: callable, optional :return: The cache :rtype: dict """ if not os.path.exists(path): - return dict() - files = os.listdir(path) - jit_cache = dict() + return None + ret = None try: - for idx, file in enumerate(files): - if idx >= int(cache_limit): - break - # identify dsl prefix - if not file.startswith(f"{dsl_name.lower()}"): - continue - if ".mlir" in file: - func_name, ir_module = load_ir( - os.path.join(path, file), asBytecode=True - ) - jit_cache = check_func_name(jit_cache, func_name) - jit_cache[func_name].ir_module = ir_module + file = f"{dsl_name.lower()}_{file}.mlir" + if os.path.exists(os.path.join(path, file)): + _, module = load_ir( + os.path.join(path, file), + asBytecode=True, + bytecode_reader=bytecode_reader, + ) + ret = JitCompiledFunction(module, None, None, None, None, [], False, None) except Exception as e: - print(f"{dsl_name} failed with loading generated IR cache.", e) - jit_cache = dict() - return jit_cache + log().warning( + f"{dsl_name} failed with loading generated IR cache for {file}.", e + ) + return ret def dump_cache_to_path( dsl_name, - jit_cache, - cache_limit, + jit_function, + file, path=default_generated_ir_path, bytecode_writer=None, ): @@ -222,30 +261,29 @@ def dump_cache_to_path( :param dsl_name: The name of the DSL. :type dsl_name: str - :param jit_cache: The cache to dump. - :type jit_cache: dict - :param cache_limit: The limit of the cache. - :type cache_limit: int + :param jit_function: The JitCompiledFunction to dump. + :type jit_function: JitCompiledFunction + :param file: The name of the file to dump. + :type file: str :param path: The path to the cache directory, defaults to default_generated_ir_path :type path: str, optional :param bytecode_writer: The bytecode writer to use, defaults to None :type bytecode_writer: callable, optional """ - log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache)) + log().info("JIT cache : dumping [%s] file=[%s]", dsl_name, file) if not path: path = default_generated_ir_path os.makedirs(path, exist_ok=True) try: - for idx, [key, value] in enumerate(jit_cache.items()): - if idx >= int(cache_limit): - break - save_ir( - dsl_name, - value.ir_module, - key, - output_dir=path, - as_bytecode=True, - bytecode_writer=bytecode_writer, - ) + save_ir( + dsl_name, + jit_function.ir_module, + file, + output_dir=path, + as_bytecode=True, + bytecode_writer=bytecode_writer, + ) except Exception as e: - print(f"{dsl_name} failed with caching generated IR", e) + log().warning( + f"{dsl_name} failed with dumping generated IR cache for {file}: {e}" + ) diff --git a/python/CuTeDSL/cutlass/base_dsl/common.py b/python/CuTeDSL/cutlass/base_dsl/common.py index 6bc5fca11..3ace5e1cb 100644 --- a/python/CuTeDSL/cutlass/base_dsl/common.py +++ b/python/CuTeDSL/cutlass/base_dsl/common.py @@ -195,6 +195,10 @@ def _get_friendly_cuda_error_message(error_code, error_name): f"2. SM ARCH setting", f"3. Steps to reproduce", ), + "cudaErrorInsufficientDriver": ( + f"1. Run nvidia-smi to confirm CUDA driver version", + f"2. Ensure the CUDA driver version meets the requirement of the installed cuda-python package", + ), } message = f"{error_name} (error code: {error_code}) \n" \ diff --git a/python/CuTeDSL/cutlass/base_dsl/dsl.py b/python/CuTeDSL/cutlass/base_dsl/dsl.py index cb3496370..ce1423914 100644 --- a/python/CuTeDSL/cutlass/base_dsl/dsl.py +++ b/python/CuTeDSL/cutlass/base_dsl/dsl.py @@ -318,11 +318,8 @@ class BaseDSL: self.envar = self._env_class(self.name) self.enable_preprocessor = preprocess # This cache uses hash of original ir and env as key, allows dump/load to/from file. Enabled by default - self.jit_cache = ( - dict() - if self.envar.disable_file_caching - else load_cache_from_path(self.name, self.envar.file_caching_capacity) - ) + self.jit_cache = dict() + self.host_jit_decorator_name = f"@{BaseDSL.jit.__name__}" self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}" @@ -372,12 +369,6 @@ class BaseDSL: atexit.register(restore_excepthook, origin_excepthook) - def dump_cache(self, path=None): - if not self.envar.disable_file_caching: - dump_cache_to_path( - self.name, self.jit_cache, self.envar.file_caching_capacity, path=path - ) - @lru_cache(maxsize=1) def print_warning_once(self, message): log().warning(f"Warning: {message}") @@ -392,9 +383,6 @@ class BaseDSL: def _get_dsl(cls): # Instantiate the DSL Class once main_dsl = cls() - if not main_dsl.no_cache: - # register atexit callback - atexit.register(main_dsl.dump_cache) return main_dsl @staticmethod @@ -1235,6 +1223,16 @@ class BaseDSL: log().debug(f"Using pipeline = {pipeline}") shared_libs = self.get_shared_libs() profiler = timer(enable=self.envar.jit_time_profiling) + # try load the file cache + load_from_file_cache = False + if not no_cache: + fn = load_cache_from_path( + self.name, module_hash, bytecode_reader=read_bytecode_and_check_crc32 + ) + if fn is not None: + load_from_file_cache = True + self.jit_cache[module_hash] = fn + if ( no_cache or module_hash not in self.jit_cache @@ -1288,6 +1286,16 @@ class BaseDSL: if not no_cache: # module stored in cache is compiled. self.jit_cache[module_hash] = fn + # write through the file cache if enabled. + if not self.envar.disable_file_caching and not load_from_file_cache: + dump_cache_to_path( + self.name, + fn, + module_hash, + bytecode_writer=lambda f: write_bytecode_with_crc32( + f, fn.ir_module + ), + ) return fn diff --git a/python/CuTeDSL/cutlass/base_dsl/env_manager.py b/python/CuTeDSL/cutlass/base_dsl/env_manager.py index bd170deb4..26419c814 100644 --- a/python/CuTeDSL/cutlass/base_dsl/env_manager.py +++ b/python/CuTeDSL/cutlass/base_dsl/env_manager.py @@ -311,7 +311,6 @@ class EnvironmentVarManager(LogEnvironmentManager): - [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False) - [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False) - [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False) - - [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000) - [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None) - [DSL_NAME]_ENABLE_TVM_FFI: Enable TVM-FFI or not (default: False) """ @@ -350,9 +349,6 @@ class EnvironmentVarManager(LogEnvironmentManager): self.disable_file_caching = get_bool_env_var( f"{prefix}_DISABLE_FILE_CACHING", False ) - self.file_caching_capacity = get_int_env_var( - f"{prefix}_FILE_CACHING_CAPACITY", 1000 - ) # set cuda self.cuda_toolkit = get_cuda_toolkit_path() diff --git a/python/CuTeDSL/cutlass/base_dsl/export/export.py b/python/CuTeDSL/cutlass/base_dsl/export/export.py index 836e658d0..3ccb452ee 100644 --- a/python/CuTeDSL/cutlass/base_dsl/export/export.py +++ b/python/CuTeDSL/cutlass/base_dsl/export/export.py @@ -27,17 +27,21 @@ from typing import Union cubin_suffix = "cubin" -def get_export_module(ir_module: ir.Module, symbol_prefix: str): +def get_export_module(ir_module: ir.Module, symbol_prefix: str, *, preserve_symbols = None): """Get the export module which is cloned from the original compiled ir module, and add the prefix to avoid the symbol conflict. @param ir_module: The original compiled ir module. Comes from the JitCompiledFunction.ir_module. @param symbol_prefix: The prefix name of the function. This is the unique identifier name of the function to avoid symbol conflict in the generated object file. + @param preserve_symbols: Optional symbols to preserve in the export module. @return: The export module of the function. """ # Add prefix for symbol names to avoid conflict with other functions defined_symbols = set() + if preserve_symbols is None: + preserve_symbols = set() + def walk_llvm_func_op(op): # not a declaration if ( @@ -45,23 +49,61 @@ def get_export_module(ir_module: ir.Module, symbol_prefix: str): and len(op.opview.operation.regions) > 0 and len(op.opview.operation.regions[0].blocks) > 0 ): - defined_symbols.add(op.attributes["sym_name"].value) + func_name = op.attributes["sym_name"].value + # skip preserving symbols + if func_name in preserve_symbols: + return ir.WalkResult.ADVANCE + defined_symbols.add(func_name) op.attributes["sym_name"] = ir.StringAttr.get( - symbol_prefix + "_" + op.attributes["sym_name"].value + symbol_prefix + "_" + func_name ) return ir.WalkResult.ADVANCE - def walk_llvm_call_op(op): + def walk_llvm_references(op): + # Rename function calls if op.name == "llvm.call" and op.attributes["callee"].value in defined_symbols: op.attributes["callee"] = ir.FlatSymbolRefAttr.get( symbol_prefix + "_" + op.attributes["callee"].value ) + # Rename addressof references + elif op.name == "llvm.mlir.addressof" and op.attributes["global_name"].value in defined_symbols: + op.attributes["global_name"] = ir.FlatSymbolRefAttr.get( + symbol_prefix + "_" + op.attributes["global_name"].value + ) + # Rename global_ctors references + elif op.name == "llvm.mlir.global_ctors" and "ctors" in op.attributes: + ctors = list(op.attributes["ctors"]) + renamed_ctors = [] + for ctor in ctors: + if ctor.value in defined_symbols: + renamed_ctors.append(ir.FlatSymbolRefAttr.get( + symbol_prefix + "_" + ctor.value + )) + else: + renamed_ctors.append(ctor) + if renamed_ctors: + op.attributes["ctors"] = ir.ArrayAttr.get(renamed_ctors) + # Rename global_dtors references + elif op.name == "llvm.mlir.global_dtors" and "dtors" in op.attributes: + dtors = list(op.attributes["dtors"]) + renamed_dtors = [] + for dtor in dtors: + if dtor.value in defined_symbols: + renamed_dtors.append(ir.FlatSymbolRefAttr.get( + symbol_prefix + "_" + dtor.value + )) + else: + renamed_dtors.append(dtor) + if renamed_dtors: + op.attributes["dtors"] = ir.ArrayAttr.get(renamed_dtors) return ir.WalkResult.ADVANCE with ir.Context(): export_module = ir.Module.parse(str(ir_module)) + # First pass: collect and rename function definitions export_module.operation.walk(walk_llvm_func_op) - export_module.operation.walk(walk_llvm_call_op) + # Second pass: rename call and addressof references + export_module.operation.walk(walk_llvm_references) return export_module diff --git a/python/CuTeDSL/cutlass/base_dsl/jit_executor.py b/python/CuTeDSL/cutlass/base_dsl/jit_executor.py index f36ad862a..fbe95e990 100644 --- a/python/CuTeDSL/cutlass/base_dsl/jit_executor.py +++ b/python/CuTeDSL/cutlass/base_dsl/jit_executor.py @@ -16,7 +16,7 @@ This module provides jit executor related classes import ctypes import inspect import io -from typing import Union, Optional +from typing import Union, Optional, NamedTuple, Any, Sequence import weakref import threading import collections @@ -132,6 +132,15 @@ def load_kernels_from_ir_module(module, kernel_info) -> list[CudaModuleAndKernel return list(kernel_modules.values()) + +class KwargsWrapperSpec(NamedTuple): + """A specification for keyword arguments wrapper.""" + arg_names: list[str] + arg_defaults: tuple[Any, ...] + kwonly_names: list[str] + kwonly_defaults: dict[str, Any] + + class ExecutionArgs: """Helper that wraps the function signature spec to filter exeuction and compile time arguments.""" @@ -216,6 +225,59 @@ class ExecutionArgs: return exe_args, adapted_args + def get_kwargs_wrapper_spec(self, exclude_arg_names: Sequence[str] = ()) -> KwargsWrapperSpec: + """ + This function is used to get the kwargs wrapper spec from the original args_spec. + """ + excluded_arg_names = set(exclude_arg_names) + arg_spec = self.original_args_spec + + if arg_spec.defaults: + defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults) + else: + defaults_start_idx = len(arg_spec.args) + + arg_names = [] + arg_defaults = [] + kwonly_names = [] + kwonly_defaults = {} + + # Filter arguments and maintain their properties + for i, arg_name in enumerate(arg_spec.args): + arg_type = arg_spec.annotations.get(arg_name, None) + + # Skip compile-time arguments + if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name): + continue + if arg_name in excluded_arg_names: + continue + arg_names.append(arg_name) + + if i >= defaults_start_idx: + arg_defaults.append(arg_spec.defaults[i - defaults_start_idx]) + + if arg_spec.kwonlyargs: + for i, kwarg in enumerate(arg_spec.kwonlyargs): + arg_type = arg_spec.annotations.get(kwarg, None) + + # Skip compile-time arguments + if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name): + continue + + if kwarg in excluded_arg_names: + continue + + kwonly_names.append(kwarg) + if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults: + kwonly_defaults[kwarg] = arg_spec.kwonlydefaults[kwarg] + + return KwargsWrapperSpec( + arg_names=arg_names, + arg_defaults=tuple(arg_defaults), + kwonly_names=kwonly_names, + kwonly_defaults=kwonly_defaults, + ) + def get_rectified_args_from_original_args(self, full_args, full_kwargs): """ This function is used to rectify the original arguments to the runtime @@ -233,6 +295,7 @@ class ExecutionArgs: defaults_start_idx = len(arg_spec.args) runtime_args = [] + # Filter arguments and maintain their properties for i, arg_name in enumerate(arg_spec.args): arg_type = arg_spec.annotations.get(arg_name, None) @@ -241,12 +304,24 @@ class ExecutionArgs: if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name): continue - # Keep corresponding default if it exists - if i >= defaults_start_idx: + # Check if argument was provided by user, otherwise use default + if i < len(full_args): + # User provided this argument - use it + runtime_args.append(full_args[i]) + elif i >= defaults_start_idx: + # Argument not provided, but has default - use default default_idx = i - defaults_start_idx runtime_args.append(arg_spec.defaults[default_idx]) else: - runtime_args.append(full_args[i]) + # Required argument missing + raise DSLRuntimeError( + f"Missing required argument '{arg_name}' at position {i}", + context={ + "function_name": self.function_name, + "expected_args": len(arg_spec.args), + "provided_args": len(full_args), + } + ) # Filter keyword-only arguments runtime_kwargs = {} diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py b/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py index 64bfa2f98..d430a877f 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py @@ -21,6 +21,7 @@ import os import ctypes import cuda.bindings.driver as cuda +import cuda.bindings.runtime as cudart import cuda.bindings.nvrtc as nvrtc # Local module imports @@ -44,6 +45,8 @@ def _cudaGetErrorEnum(error): if isinstance(error, cuda.CUresult): err, name = cuda.cuGetErrorName(error) return name if err == cuda.CUresult.CUDA_SUCCESS else "" + elif isinstance(error, cudart.cudaError_t): + return cudart.cudaGetErrorName(error)[1] elif isinstance(error, nvrtc.nvrtcResult): return nvrtc.nvrtcGetErrorString(error)[1] else: diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py index 36a3aadc7..4655f564c 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py @@ -38,6 +38,7 @@ class MLIRTypeBuilder: self.gpu_ptr_type = llvm.PointerType.get(address_space=1) # did not find a programmatic way to get the void type self.void_type = ir.Type.parse("!llvm.void") + self.llvm_internal_linkage = ir.Attribute.parse("#llvm.linkage") def ptr_type_with_address_space( self, address_space: Optional[int] = None @@ -374,8 +375,10 @@ class MLIRBuilder(MLIRTypeBuilder): params_type: Sequence[ir.Type], ret_type: ir.Type, internal: bool = False, + llvm_func_attrs: Sequence[str] = (), ) -> tuple[list[ir.Value], ir.Block]: - """Create a function with the given signature.""" + """Create a function with the given signature. + """ func_op = llvm.func( name, function_type=self.as_attr( @@ -383,10 +386,16 @@ class MLIRBuilder(MLIRTypeBuilder): ), ) if internal: - func_op.attributes["llvm.linkage"] = ir.StringAttr.get("private") + func_op.attributes["linkage"] = self.llvm_internal_linkage else: func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + # Add LLVM function attributes via passthrough + if llvm_func_attrs: + func_op.attributes["passthrough"] = ir.ArrayAttr.get( + [ir.StringAttr.get(attr) for attr in llvm_func_attrs] + ) + params = [] func_body: Any = func_op.body if func_body is not None: diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py index 91a4d08c4..faf7eb14c 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py @@ -668,12 +668,14 @@ class TVMFFIBuilder(MLIRBuilder): param_types.append(self.ptr_type) # p0, p1, ..., pN-1 # Create the helper function + # Mark as noinline since error handling is a slow path and benefits from not inlining with ir.InsertionPoint(self.module.body): # type: ignore[union-attr] params, entry_block = self.function( name=helper_name, params_type=param_types, ret_type=self.void_type, internal=True, + llvm_func_attrs=["noinline"], ) kind_param = params[0] @@ -1244,12 +1246,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): else: assert isinstance(var, int) with ir.InsertionPoint(current_block): - if not skip_cast_and_check: - expected_value = self.i64(var) - else: - expected_value = self.downcast_i64_to_lower_bits( - self.i64(var), var.dtype - ) + expected_value = self.i64(var) error_msg_mismatch = [ error_prefix_mismatch, @@ -1983,13 +1980,21 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): ) # decode parameters to populate the matched var binding - for arg_index, param in enumerate(params_list): + # Track the actual FFI argument index separately from parameter index + # since some parameters (like EnvStream) are not passed as FFI arguments + ffi_arg_index = 0 + for param in params_list: + # Skip EnvStream parameters as they are not in the FFI args array + if isinstance(param, spec.EnvStream): + continue + arg_context = ArgContext( param_name=param.name, - arg_index=arg_index, + arg_index=ffi_arg_index, tuple_indices=[], ) - current_block = self.decode_param(current_block, param, args, arg_index, arg_context) + current_block = self.decode_param(current_block, param, args, ffi_arg_index, arg_context) + ffi_arg_index += 1 with ir.InsertionPoint(current_block): env_stream = self.find_env_stream(params_list) @@ -2047,7 +2052,9 @@ def attach_ffi_func( builder.attach_ffi_func(symbol_name, params, call_provider, fn_display_name) -def rename_tvm_ffi_function(module: ir.Module, old_name: str, new_name: str) -> None: +def rename_tvm_ffi_function( + module: ir.Module, old_name: str, new_name: str, +) -> None: """Rename the TVM FFI function in the module. Parameters diff --git a/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py b/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py index 02fae4825..6bbd9682f 100644 --- a/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py +++ b/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py @@ -19,6 +19,7 @@ from .typing import Tensor, Pointer, SymInt from .typing import ( Numeric, Boolean, + Integer, Int4, Int8, Uint8, @@ -42,7 +43,17 @@ from .typing import ( ) import cuda.bindings.driver as cuda -from typing import List, Dict, Any, Optional, Tuple, get_origin, get_args +from typing import ( + List, + Dict, + Any, + Optional, + Tuple, + get_origin, + get_args, + get_type_hints, +) +from types import UnionType import inspect NumericToTVMFFIDtype = { @@ -91,6 +102,7 @@ def _get_llvm_address_space_from_memspace( return 1 return None + def _is_gpu_memspace( memspace: _cute_ir.AddressSpace, ) -> bool: @@ -108,7 +120,6 @@ class SymIntId: return self.sym_int is other.sym_int - class ConverterContext: """Context for managing variable allocation during TVM FFI args conversion.""" @@ -145,7 +156,9 @@ class ConverterContext: self.sym_int_id_mapping[sym_int_id] = var return var - def alloc_or_reuse_device_id(self, device_type: str, vdevice_id: int) -> Optional[spec.Var]: + def alloc_or_reuse_device_id( + self, device_type: str, vdevice_id: int + ) -> Optional[spec.Var]: """Allocate or reuse a device_id variable for a given virtual device. This function returns None for CPU tensors. @@ -166,10 +179,7 @@ class ConverterContext: def _convert_single_arg( - arg, - arg_name: str, - arg_type, - ctx: ConverterContext + arg, arg_name: str, arg_type, ctx: ConverterContext ) -> spec.Param: """Convert a single argument to a spec.Param. @@ -191,7 +201,7 @@ def _convert_single_arg( """ if arg is None: return spec.ConstNone(arg_name) - elif (isinstance(arg, Numeric) and arg.dtype in AcceptableNumericTypesForScalar): + elif isinstance(arg, Numeric) and arg.dtype in AcceptableNumericTypesForScalar: return spec.Var(arg_name, NumericToTVMFFIDtype[arg.dtype]) elif arg_type in AcceptableNumericTypesForScalar: return spec.Var(arg_name, NumericToTVMFFIDtype[arg_type]) @@ -201,9 +211,13 @@ def _convert_single_arg( if isinstance(arg[i], int): shape.append(arg[i]) elif isinstance(arg[i], SymInt): - shape.append(ctx.alloc_or_reuse_symint_var(arg[i], ctx.alloc_shape_name)) + shape.append( + ctx.alloc_or_reuse_symint_var(arg[i], ctx.alloc_shape_name) + ) else: - shape.append(spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[arg[i].dtype])) + shape.append( + spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[arg[i].dtype]) + ) return spec.Shape(arg_name, shape) elif isinstance(arg, Tensor): shapes = [] @@ -211,16 +225,22 @@ def _convert_single_arg( if not dyn_mask: shapes.append(arg.shape[i]) elif isinstance(arg.shape[i], SymInt): - shapes.append(ctx.alloc_or_reuse_symint_var(arg.shape[i], ctx.alloc_shape_name)) + shapes.append( + ctx.alloc_or_reuse_symint_var(arg.shape[i], ctx.alloc_shape_name) + ) else: - shapes.append(spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[Int32])) + shapes.append( + spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[Int32]) + ) strides = [] for i, dyn_mask in enumerate(arg.dynamic_strides_mask): if not dyn_mask: strides.append(arg.stride[i]) elif isinstance(arg.stride[i], SymInt): - strides.append(ctx.alloc_or_reuse_symint_var(arg.stride[i], ctx.alloc_stride_name)) + strides.append( + ctx.alloc_or_reuse_symint_var(arg.stride[i], ctx.alloc_stride_name) + ) else: if hasattr(arg, "_use_32bit_stride") and arg._use_32bit_stride: dtype = NumericToTVMFFIDtype[Int32] @@ -243,7 +263,7 @@ def _convert_single_arg( strides=strides, data_alignment=arg._assumed_align, device_type=device_type, - device_id=device_id + device_id=device_id, ) else: # for FakeTensor, strictly follow the shape and stride from the cute tensor @@ -259,7 +279,7 @@ def _convert_single_arg( strides=strides, data_alignment=arg._assumed_align, device_type=device_type, - device_id=device_id + device_id=device_id, ) if arg.element_type == Float4E2M1FN: tvm_ffi_cute_tensor = spec.create_map_tensor_dtype_f4x2_to_f4_spec( @@ -278,11 +298,38 @@ def _convert_single_arg( return spec.Stream(arg_name) elif isinstance(arg, cuda.CUstream): return spec.Stream(arg_name) + elif arg_type is not None and hasattr(arg_type, "_fields"): + # Handle NamedTuple - normalize to Tuple by order of fields, ignoring defaults + # Get field types from annotations + type_hints = get_type_hints(arg_type) + tuple_element_types = [type_hints[field] for field in arg_type._fields] + + # NamedTuples inherit from tuple, so we can check with isinstance(arg, tuple) + if not isinstance(arg, tuple): + raise DSLRuntimeError( + f"Expected namedtuple for argument {arg_name}, got {type(arg)}" + ) + if len(arg) != len(tuple_element_types): + raise DSLRuntimeError( + f"NamedTuple length mismatch for argument {arg_name}: " + f"expected {len(tuple_element_types)}, got {len(arg)}" + ) + + # Recursively convert each tuple element + tuple_params = [] + for i, (elem, elem_type) in enumerate(zip(arg, tuple_element_types)): + elem_name = f"{arg_name}[{i}]" + elem_param = _convert_single_arg(elem, elem_name, elem_type, ctx) + tuple_params.append(elem_param) + + return spec.TupleParam(arg_name, tuple_params) elif arg_type is not None and get_origin(arg_type) is tuple: # Handle Tuple[X, Y, ...] type annotations tuple_element_types = get_args(arg_type) if not isinstance(arg, (tuple, list)): - raise DSLRuntimeError(f"Expected tuple for argument {arg_name}, got {type(arg)}") + raise DSLRuntimeError( + f"Expected tuple for argument {arg_name}, got {type(arg)}" + ) if len(arg) != len(tuple_element_types): raise DSLRuntimeError( f"Tuple length mismatch for argument {arg_name}: " @@ -297,8 +344,24 @@ def _convert_single_arg( tuple_params.append(elem_param) return spec.TupleParam(arg_name, tuple_params) + elif isinstance(arg, (tuple, list)): + # Handle plain tuple type annotation without explicit element types + # Recursively convert each tuple element with None as elem_type (un-annotated) + tuple_params = [] + for i, elem in enumerate(arg): + elem_name = f"{arg_name}[{i}]" + elem_param = _convert_single_arg(elem, elem_name, None, ctx) + tuple_params.append(elem_param) + return spec.TupleParam(arg_name, tuple_params) + elif isinstance(arg, int): + # in cute.compile, unannotated const int is converted to int32 + return spec.Var(arg_name, NumericToTVMFFIDtype[Int32]) + elif isinstance(arg, float): + return spec.Var(arg_name, NumericToTVMFFIDtype[Float32]) else: - raise DSLRuntimeError(f"Unsupported argument type: {type(arg)}") + raise DSLRuntimeError( + f"Unsupported argument type: {type(arg)} for annotated type: {get_origin(arg_type)}" + ) def _tvm_ffi_args_spec_converter( @@ -312,17 +375,24 @@ def _tvm_ffi_args_spec_converter( This function converts the cute arguments specs to tvm ffi spec params. """ exec_args = ExecutionArgs(args_spec, function_name) - rectified_args = exec_args.get_rectified_args_from_original_args(full_args, full_kwargs) + rectified_args = exec_args.get_rectified_args_from_original_args( + full_args, full_kwargs + ) arg_names = exec_args.args_spec.args + exec_args.args_spec.kwonlyargs params = [] ctx = ConverterContext() + wrapper_extra_exclude_arg_names = [] for arg, arg_name in zip(rectified_args, arg_names): arg_type = args_spec.annotations.get(arg_name, None) param = _convert_single_arg(arg, arg_name, arg_type, ctx) params.append(param) - - return params + if isinstance(param, spec.EnvStream): + wrapper_extra_exclude_arg_names.append(arg_name) + kwargs_wrapper_spec = exec_args.get_kwargs_wrapper_spec( + wrapper_extra_exclude_arg_names + ) + return params, kwargs_wrapper_spec def attach_args_spec_converter(): diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py index b6a2435c9..f4c6bb474 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py @@ -28,8 +28,9 @@ from ..base_dsl.jit_executor import ( JitFunctionArtifacts, ) from ..base_dsl.utils.logger import log -from ..base_dsl.common import DSLCudaRuntimeError, DSLRuntimeError +from ..base_dsl.common import DSLRuntimeError from ..base_dsl.typing import Int32 +from ..base_dsl.runtime.cuda import checkCudaErrors class CudaDialectJitModule: """Holds the execution engine and cuda libraries.""" @@ -113,10 +114,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): @functools.cached_property def num_devices(self): """Returns the number of CUDA devices available.""" - dev_err, devs = cuda_runtime.cudaGetDeviceCount() - if dev_err != cuda_runtime.cudaError_t.cudaSuccess: - raise DSLCudaRuntimeError(dev_err, cuda_runtime.cudaGetErrorName(dev_err)) - return devs + return checkCudaErrors(cuda_runtime.cudaGetDeviceCount()) def _deserializer(self): """Load the cuda library from the binary execution engine. @@ -148,12 +146,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): packed_args[i] = ctypes.cast(cuda_init_args[i], ctypes.c_void_p) cuda_init(packed_args) - if err.value != 0: - error_code = err.value - error_name = cuda_runtime.cudaGetErrorName( - cuda_runtime.cudaError_t(error_code) - ) - raise DSLCudaRuntimeError(error_code, error_name) + checkCudaErrors((cuda_runtime.cudaError_t(err.value),)) cuda_load_args = [pointer_to_library, pointer_to_err] packed_args = (ctypes.c_void_p * len(cuda_load_args))() @@ -161,12 +154,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): packed_args[i] = ctypes.cast(cuda_load_args[i], ctypes.c_void_p) cuda_load(packed_args) - if err.value != 0: - error_code = err.value - error_name = cuda_runtime.cudaGetErrorName( - cuda_runtime.cudaError_t(error_code) - ) - raise DSLCudaRuntimeError(error_code, error_name) + checkCudaErrors((cuda_runtime.cudaError_t(err.value),)) return [cuda_runtime.cudaLibrary_t(library.value)] @@ -229,12 +217,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): packed_args[i] = ctypes.cast(cuda_init_args[i], ctypes.c_void_p) cuda_init(packed_args) - if err.value != 0: - error_code = err.value - error_name = cuda_runtime.cudaGetErrorName( - cuda_runtime.cudaError_t(error_code) - ) - raise DSLCudaRuntimeError(error_code, error_name) + checkCudaErrors((cuda_runtime.cudaError_t(err.value),)) device_id = ctypes.c_int32(0) pointer_to_device_id = ctypes.pointer(device_id) @@ -247,18 +230,9 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): for dev in range(self.num_devices): device_id.value = dev cuda_load_to_device(packed_args) - if err.value != 0: - raise DSLCudaRuntimeError( - err.value, - cuda_runtime.cudaGetErrorName(cuda_runtime.cudaError_t(err.value)), - ) + checkCudaErrors((cuda_runtime.cudaError_t(err.value),)) - if err.value != 0: - error_code = err.value - error_name = cuda_runtime.cudaGetErrorName( - cuda_runtime.cudaError_t(error_code) - ) - raise DSLCudaRuntimeError(error_code, error_name) + checkCudaErrors((cuda_runtime.cudaError_t(err.value),)) return [cuda_runtime.cudaLibrary_t(library.value)] diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py index 477b3152a..2ed2517ce 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py @@ -43,9 +43,9 @@ from ..base_dsl.dsl import is_dynamic_expression, extract_mlir_values from ..base_dsl.typing import * from ..base_dsl.typing import DynamicExpression, get_mlir_types from ..base_dsl.runtime.jit_arg_adapters import is_arg_spec_constexpr +from ..base_dsl.jit_executor import ExecutionArgs from ..base_dsl.runtime import cuda as cuda_helpers from .cuda_stream_adapter import CudaDialectStreamAdapter - from .cuda_jit_executor import CudaDialectJitCompiledFunction # MLIR Imports @@ -421,12 +421,13 @@ class CutlassBaseDSL(BaseDSL): # attach extra ABI function to the MLIR module from .tvm_ffi_provider import ( TVMFFIJitCompiledFunction, + TVMFFIJitCompiledFunctionWithKwargs, TVMFFICuteCallProvider, ) from cutlass.base_dsl.tvm_ffi_builder import attach_ffi_func assert self._tvm_ffi_args_spec_converter is not None - tvm_ffi_spec_params = self._tvm_ffi_args_spec_converter( + tvm_ffi_spec_params, kwargs_wrapper_spec = self._tvm_ffi_args_spec_converter( function_name, args_spec, full_args, full_kwargs ) tvm_ffi_provider = TVMFFICuteCallProvider(function_name) @@ -444,6 +445,15 @@ class CutlassBaseDSL(BaseDSL): ) module.operation.verify() + def _make_compiled_func(*args, **kwargs): + if kwargs_wrapper_spec.kwonly_names or kwargs_wrapper_spec.arg_defaults: + return TVMFFIJitCompiledFunctionWithKwargs( + *args, **kwargs, + kwargs_wrapper_spec=kwargs_wrapper_spec + ) + else: + return TVMFFIJitCompiledFunction(*args, **kwargs) + # ensure the compiler can run post-compile hook after its passes # the context will restore the previous post-compile hook after it exits with compiler.PostCompileHookContext( @@ -456,7 +466,7 @@ class CutlassBaseDSL(BaseDSL): pipeline, args_spec, no_cache, - TVMFFIJitCompiledFunction, + _make_compiled_func, full_args=full_args, full_kwargs=full_kwargs, dynamic_args=dynamic_args, diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py b/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py index 4636615cd..ea1674e8e 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py @@ -15,12 +15,14 @@ from cutlass.base_dsl.tvm_ffi_builder import ( rename_tvm_ffi_function, spec, ) +from cutlass.base_dsl.export import get_export_module from cutlass._mlir import ir from cutlass._mlir.dialects import llvm -from cutlass._mlir._mlir_libs._cutlass_ir import _execution_engine_extra +from cutlass._mlir._mlir_libs._cutlass_ir import _aot_support from cutlass.cutlass_dsl.cuda_jit_executor import CudaDialectJitCompiledFunction from cutlass.base_dsl.common import DSLRuntimeError -from typing import Optional +from cutlass.base_dsl.jit_executor import ExecutionArgs +from typing import Optional, Callable import tvm_ffi @@ -400,41 +402,52 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): return current_block -class TVMFFIJitCompiledFunction(tvm_ffi.Function, CudaDialectJitCompiledFunction): - """TVM FFI Function that contains metadata of the compiled function and interface to the FFI layer. +def _inplace_hide_symbols(ir_module: ir.Module, hide_check: Callable[[str], bool]): + """Walk through the IRModule, hide functions that do not yet have linkage set. - This function should not be directly used after + @param ir_module: The ir module to hide the symbols. + @param hide_check: The callback to check if the symbol should be hidden. + @return: The ir module with the symbols hidden. """ + defined_symbols = set() + def walk_llvm_func_op(op): + # not a declaration + if ( + op.name == "llvm.func" + and len(op.opview.operation.regions) > 0 + and len(op.opview.operation.regions[0].blocks) > 0 + ): + func_name = op.attributes["sym_name"].value + defined_symbols.add(func_name) + return ir.WalkResult.ADVANCE + + def walk_and_hide_symbols(op): + # Handle llvm.func operations + if op.name == "llvm.func": + func_name = op.attributes["sym_name"].value + # Only set linkage if it doesn't already have one + if func_name in defined_symbols and hide_check(func_name): + # Set to internal linkage to hide the symbol + op.attributes["linkage"] = ir.Attribute.parse("#llvm.linkage") + return ir.WalkResult.ADVANCE + + with ir_module.context: + ir_module.operation.walk(walk_llvm_func_op) + ir_module.operation.walk(walk_and_hide_symbols) + + +def _get_format_from_object_file_path(object_file_path: str) -> str: + format = object_file_path.split(".")[-1] + if format not in ("o", "ll", "bc"): + return "o" + return format + + +class TVMFFIJitCompiledFunctionBase(CudaDialectJitCompiledFunction): + """Base class for TVM FFI compiled function.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # initialize the tvm_ffi.Function from the current execution engine - self._init_ffi_function() - - # use direct call to the tvm_ffi.Function.__call__ - # to avoid most of python overhead - __call__ = tvm_ffi.Function.__call__ - - def _init_ffi_function(self): - """Initialize the tvm_ffi.Function from the current execution engine. - - This function must be called at once during compilation time. - The reason why it is not called during init is because the original - flow may already created an execution engine and the function is not - guaranteed to be initialized at that time. - """ - if self.__chandle__() != 0: - raise DSLRuntimeError("TVM FFI function is already initialized") - # get the MLIR function pointer from the execution engine - if self.engine is not None: - tvm_ffi_function_ptr = self.engine.raw_lookup( - "__tvm_ffi_" + self.function_name - ) - tvm_ffi_function = tvm_ffi.Function.__from_mlir_packed_safe_call__( - tvm_ffi_function_ptr - ) - # move the handle from the tvm_ffi.Function to the current instance - self.__move_handle_from__(tvm_ffi_function) def to(self, device=None): """TVM FFI function itself is already support all devices.""" @@ -444,18 +457,111 @@ class TVMFFIJitCompiledFunction(tvm_ffi.Function, CudaDialectJitCompiledFunction """Run the compiled program. This override is needed for implicit compile and execution.""" return self.__call__(*exe_args) - def export_to_c(self, object_file_path: str, function_name: str = None): + def export_to_c( + self, object_file_path: str, function_name: str = None, + *, + enable_pic: bool = True, + export_only_tvm_ffi_symbols: bool = False + ): """Export the TVM FFI function to an object file. :param object_file_path: The path to the object file. :param function_name: The name of the function to export. + :param enable_pic: Whether to enable PIC relocation needed for shared library loading. + :param export_only_tvm_ffi_symbols: Only export TVM FFI symbols (hide all others). + :param host_target_triple: If not provided, the current host target is used. """ - if function_name is not None and function_name != self.function_name: - mod = self.ir_module - rename_tvm_ffi_function(mod, self.function_name, function_name) - else: - mod = self.ir_module - - _execution_engine_extra.dump_object_file_pic( - mod, object_file_path, "__tvm_ffi_" + function_name, 2 + # prefix internal function by function name + internal_symbol_prefix = "__cute_internal_" + function_name + mod = self.ir_module + mod = get_export_module( + self.ir_module, internal_symbol_prefix, + preserve_symbols=[f"__tvm_ffi_{self.function_name}"] ) + + rename_tvm_ffi_function(mod, self.function_name, function_name) + if export_only_tvm_ffi_symbols: + _inplace_hide_symbols(mod, lambda x: not x.startswith("__tvm_ffi")) + + format = _get_format_from_object_file_path(object_file_path) + out_bytes = _aot_support.export_module_to_bytes( + mod, format=format, opt_level=3, enable_pic=enable_pic + ) + + with open(object_file_path, "wb") as f: + f.write(out_bytes) + + def _create_tvm_ffi_function(self): + """Create the tvm_ffi.Function from the current execution engine. + """ + if self.engine is not None: + tvm_ffi_function_ptr = self.engine.raw_lookup( + "__tvm_ffi_" + self.function_name + ) + tvm_ffi_function = tvm_ffi.Function.__from_mlir_packed_safe_call__( + tvm_ffi_function_ptr, keep_alive_object=self.engine) + return tvm_ffi_function + return None + + +class TVMFFIJitCompiledFunction(tvm_ffi.Function, TVMFFIJitCompiledFunctionBase): + """TVM FFI Function that directly subclasses the tvm_ffi.Function for pos only arguments. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # initialize the tvm_ffi.Function from the current execution engine + if self.__chandle__() != 0: + raise DSLRuntimeError("TVM FFI function is already initialized") + tvm_ffi_function = self._create_tvm_ffi_function() + if tvm_ffi_function is not None: + # move the handle from the tvm_ffi.Function to the current instance + self.__move_handle_from__(tvm_ffi_function) + + # use direct call to the tvm_ffi.Function.__call__ + # to avoid most of python overhead + __call__ = tvm_ffi.Function.__call__ + + +class TVMFFIJitCompiledFunctionWithKwargs(TVMFFIJitCompiledFunctionBase): + """TVM FFI Function with kwargs wrapper support + """ + + def __init__(self, *args, **kwargs): + assert "kwargs_wrapper_spec" in kwargs, "kwargs_wrapper_spec is required" + kwargs_wrapper_spec = kwargs.pop("kwargs_wrapper_spec") + super().__init__(*args, **kwargs) + # initialize the tvm_ffi.Function from the current execution engine + self._tvm_ffi_function = self._create_tvm_ffi_function() + if kwargs_wrapper_spec.kwonly_names or kwargs_wrapper_spec.arg_defaults: + try: + from tvm_ffi.utils import kwargs_wrapper # type: ignore + self._kwargs_wrapper = kwargs_wrapper.make_kwargs_wrapper( + self._tvm_ffi_function, + arg_names=kwargs_wrapper_spec.arg_names, + arg_defaults=kwargs_wrapper_spec.arg_defaults, + kwonly_names=kwargs_wrapper_spec.kwonly_names, + kwonly_defaults=kwargs_wrapper_spec.kwonly_defaults, + ) + except ImportError: + raise DSLRuntimeError("install apache-tvm-ffi>=0.1.5 to enable kwargs/defaults") + else: + # positional only is probably fine + self._kwargs_wrapper = self._tvm_ffi_function + + def __call__(self, *args, **kwargs): + """Call the TVM FFI function with kwargs wrapper. + """ + return self._kwargs_wrapper(*args, **kwargs) + + def __tvm_ffi_object__(self): + return self._tvm_ffi_function + + +def supports_kwargs_wrapper() -> bool: + """Check if the kwargs wrapper is supported.""" + try: + from tvm_ffi.utils import kwargs_wrapper # type: ignore + return True + except ImportError: + return False diff --git a/python/CuTeDSL/requirements.txt b/python/CuTeDSL/requirements.txt index 19e72110b..7945c9cea 100644 --- a/python/CuTeDSL/requirements.txt +++ b/python/CuTeDSL/requirements.txt @@ -1,3 +1,3 @@ # Use `pip install -r requirements.txt` with the present file to install a # wheel consistent with the present state of the github repository -nvidia-cutlass-dsl==4.3.0 +nvidia-cutlass-dsl==4.3.3 diff --git a/python/cutlass_cppgen/__init__.py b/python/cutlass_cppgen/__init__.py index 63cece3d1..fbc26a47d 100644 --- a/python/cutlass_cppgen/__init__.py +++ b/python/cutlass_cppgen/__init__.py @@ -133,7 +133,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '4.3.2' +this.__version__ = '4.3.3' from cutlass_cppgen.backend import create_memory_pool from cutlass_cppgen.emit.pytorch import pytorch diff --git a/python/setup_cutlass.py b/python/setup_cutlass.py index faa20bfb2..a3944c4f3 100644 --- a/python/setup_cutlass.py +++ b/python/setup_cutlass.py @@ -51,7 +51,7 @@ setup_pycute.perform_setup() setup( name='cutlass_cppgen', - version='4.3.2', + version='4.3.3', description='CUTLASS Pythonic Interface', package_dir={'': '.'}, packages=[ diff --git a/python/setup_pycute.py b/python/setup_pycute.py index 671afe050..47d8edc71 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='pycute', - version='4.3.2', + version='4.3.3', description='Python implementation of CuTe', packages=['pycute'], )