From ff35fa561d673c527bfc0bfa0fe7e4377c688a32 Mon Sep 17 00:00:00 2001 From: Junkai-Wu Date: Thu, 4 Dec 2025 23:14:50 +0800 Subject: [PATCH] v4.3.2 update. (#2840) --- CHANGELOG.md | 8 + README.md | 7 +- .../python/CuTeDSL/blackwell/dense_gemm.py | 2 +- include/cutlass/version.h | 2 +- .../cute_dsl_general/compile_with_tvm_ffi.rst | 14 +- .../cute_dsl_general/dsl_jit_caching.rst | 6 + .../CuTeDSL/cutlass/base_dsl/cache_helpers.py | 27 ++- python/CuTeDSL/cutlass/base_dsl/dsl.py | 4 +- .../CuTeDSL/cutlass/base_dsl/env_manager.py | 2 + .../cutlass/base_dsl/tvm_ffi_builder/spec.py | 6 +- .../tvm_ffi_builder/tvm_ffi_builder.py | 80 +++++-- .../cute/_tvm_ffi_args_spec_converter.py | 35 ++- python/CuTeDSL/cutlass/cute/runtime.py | 1 + .../cutlass/cutlass_dsl/cuda_jit_executor.py | 2 + .../cutlass/utils/distributed_helpers.py | 208 ------------------ .../CuTeDSL/cutlass/utils/smem_allocator.py | 11 +- python/cutlass_cppgen/__init__.py | 2 +- python/setup_cutlass.py | 2 +- python/setup_pycute.py | 2 +- 19 files changed, 164 insertions(+), 257 deletions(-) delete mode 100644 python/CuTeDSL/cutlass/utils/distributed_helpers.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1eb5a29a1..58d6c841b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,14 @@ # CUTLASS 4.x +## [4.3.2](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.2) (2025-12-05) +* New features + - New env var `CUTE_DSL_CACHE_DIR` to specify the path for dumping caches + +* Bug fixing and improvements + - Fixed an issue of CUDA JitExecutor when unloading kernels + - Fixed an issue of allocating max smem when there's statically allocated smem + ## [4.3.1](https://github.com/NVIDIA/cutlass/releases/tag/v4.3.1) (2025-11-26) ### CuTe DSL diff --git a/README.md b/README.md index 6e4c30e60..ec4d0bd15 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") # Overview -# CUTLASS 4.3.1 +# CUTLASS 4.3.2 -_CUTLASS 4.3.1 - Nov 2025_ +_CUTLASS 4.3.2 - Dec 2025_ CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for @@ -53,6 +53,7 @@ To get started quickly - please refer : - Added l2 cache evict priority for tma related ops. Users could do fine-grain l2 cache control. - Added Blackwell SM103 support. - Multiple dependent DSOs in the wheel have been merged into one single DSO. + - New env var `CUTE_DSL_CACHE_DIR` to specify the path for dumping caches. * Debuggability improvements: - Supported source location tracking for DSL APIs (Allow tools like ``nsight`` profiling to correlate perf metrics with Python source code) - Supported dumping PTX and CUBIN code: [Hello World Example](https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/notebooks/hello_world.ipynb) @@ -99,6 +100,8 @@ To get started quickly - please refer : - Fixed an issue with mark_compact_shape_dynamic - Fixed device reset issue with tvm-ffi - Fixed tvm-ffi export compiled function + - Fixed an issue of CUDA JitExecutor when unloading kernels + - Fixed an issue of allocating max smem when there's statically allocated smem ## CUTLASS C++ * Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm.py b/examples/python/CuTeDSL/blackwell/dense_gemm.py index 4f2e93ea8..c5ff6bb19 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm.py @@ -215,7 +215,7 @@ class DenseGemmKernel: self.occupancy = 1 self.threads_per_cta = 128 - self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + self.smem_capacity = utils.get_smem_capacity_in_bytes() def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs diff --git a/include/cutlass/version.h b/include/cutlass/version.h index aa90b1713..ce1c9b24b 100644 --- a/include/cutlass/version.h +++ b/include/cutlass/version.h @@ -36,7 +36,7 @@ #define CUTLASS_MAJOR 4 #define CUTLASS_MINOR 3 -#define CUTLASS_PATCH 1 +#define CUTLASS_PATCH 2 #ifdef CUTLASS_VERSIONS_GENERATED #include "cutlass/version_extended.h" diff --git a/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst b/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst index 42d0fdea1..296ff8a8e 100644 --- a/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst +++ b/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst @@ -55,7 +55,7 @@ To maximize performance benefits, we recommend setting up your workflow as follo - **Declare shape constraints using fake tensors** and reuse the compiled function throughout your execution. - **Pass PyTorch tensors directly** to the compiled function to avoid explicit DLPack conversion. -- **Use the environment stream flag** to implicitly synchronize with the current PyTorch stream. +- **Use the environment stream flag** to implicitly pass the current PyTorch stream. - **Rely on compiled argument validation** instead of Python-side attribute validation, as TVM FFI functions perform fast compiled checks. @@ -210,11 +210,11 @@ The following example demonstrates this approach; the function accepts ``torch.c Using Environment Stream ~~~~~~~~~~~~~~~~~~~~~~~~ -The second option is to rely on the environment-stream flag. -Pass ``use_tvm_ffi_env_stream=True`` to ``make_fake_stream`` to mark the argument as an -environment stream so it no longer has to be provided explicitly. -TVM FFI will reuse its environment stream, synchronizing it with ``torch.cuda.current_stream()`` -before each call. The example below shows this flow: +The second option is to rely on the environment stream flag. +Pass ``use_tvm_ffi_env_stream=True`` to ``make_fake_stream`` to mark the stream argument as an +environment stream, which means it no longer needs to be provided explicitly. +TVM FFI will automatically use its environment stream (i.e., the current PyTorch stream) +as the stream argument. The example below demonstrates this flow: .. code-block:: python @@ -351,7 +351,7 @@ example error cases that can be checked: except ValueError as e: # Mismatched b.shape[0] on argument #1 when calling: # `add_one(a: Tensor([n0], float32), b: Tensor([n0], float32))`, - # symbolic constraint violated + # expected to match a.shape[0] print(f"ValueError: {e}") try: diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst index ecaea52b5..214ba2804 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst @@ -130,6 +130,9 @@ After serialization, compiled MLIR bytecode is stored in file. The cache directory is ``/tmp/{current_user}/cutlass_python_cache``. The cache loads from files into memory during |DSL| initialization and saves back to files when the process exits. +Note that for efficiency, the default cache directory is located in a temporary folder. However, this location is not persistent, it may be cleared by the system (for example, during a reboot or disk space cleanup). +If you wish to preserve the cache across sessions, set the ``CUTE_DSL_CACHE_DIR`` environment variable to point to a persistent directory. + The following environment variables control file caching: .. code-block:: bash @@ -140,6 +143,9 @@ The following environment variables control file caching: # Maximum number of cache files allowed, defaults to 1000. export CUTE_DSL_FILE_CACHING_CAPACITY=1000 + # Cache directory, defaults to /tmp/{current_user}/cutlass_python_cache. + export CUTE_DSL_CACHE_DIR=/home/user/local_cutlass_python_cache/dense_gemm_cache/ + Limitations ~~~~~~~~~~~~~~~~~~~~~ diff --git a/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py b/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py index a272497f8..898fffb93 100644 --- a/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py +++ b/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py @@ -22,6 +22,7 @@ import time from pathlib import Path import hashlib from functools import lru_cache +import tempfile from .utils.logger import log from .jit_executor import JitCompiledFunction @@ -46,15 +47,23 @@ def get_current_user(): # default_generated_ir_path is the path to the cache directory. -# It is set to /tmp/{user}/cutlass_python_cache/ by default. -# If the user is not found, the default path is used or /tmp/cutlass_python_cache/ is used. -try: - default_generated_ir_path = f"/tmp/{get_current_user()}/cutlass_python_cache/" -except Exception as e: - # If all else fails, provide a default fallback path - default_generated_ir_path = "/tmp/cutlass_python_cache/" - print(f"Could not determine user, using default path. Error: {e}") +# If `CUTE_DSL_CACHE_DIR` is set, it is used as the cache directory. +# Otherwise, it is set to a directory controled by TMPDIR defaulting +# to /tmp/${USER}/cutlass_python_cache. +if not (default_generated_ir_path := os.getenv("CUTE_DSL_CACHE_DIR", None)): + tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) + + def get_reusable_temp_dir(name): + path = tmp_dir / f"{get_current_user()}/{name}" + path.mkdir(parents=True, exist_ok=True) + return str(path) + + try: + default_generated_ir_path = get_reusable_temp_dir("cutlass_python_cache") + except Exception as e: + default_generated_ir_path = str(tmp_dir / "cutlass_python_cache") + print(f"Could not determine user, using default path. Error: {e}") @lru_cache(maxsize=1) def get_default_file_dump_root(): @@ -223,6 +232,8 @@ def dump_cache_to_path( :type bytecode_writer: callable, optional """ log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache)) + if not path: + path = default_generated_ir_path os.makedirs(path, exist_ok=True) try: for idx, [key, value] in enumerate(jit_cache.items()): diff --git a/python/CuTeDSL/cutlass/base_dsl/dsl.py b/python/CuTeDSL/cutlass/base_dsl/dsl.py index e19b08889..cb3496370 100644 --- a/python/CuTeDSL/cutlass/base_dsl/dsl.py +++ b/python/CuTeDSL/cutlass/base_dsl/dsl.py @@ -372,10 +372,10 @@ class BaseDSL: atexit.register(restore_excepthook, origin_excepthook) - def dump_cache(self): + def dump_cache(self, path=None): if not self.envar.disable_file_caching: dump_cache_to_path( - self.name, self.jit_cache, self.envar.file_caching_capacity + self.name, self.jit_cache, self.envar.file_caching_capacity, path=path ) @lru_cache(maxsize=1) diff --git a/python/CuTeDSL/cutlass/base_dsl/env_manager.py b/python/CuTeDSL/cutlass/base_dsl/env_manager.py index cd0aa79bb..bd170deb4 100644 --- a/python/CuTeDSL/cutlass/base_dsl/env_manager.py +++ b/python/CuTeDSL/cutlass/base_dsl/env_manager.py @@ -296,6 +296,7 @@ class EnvironmentVarManager(LogEnvironmentManager): - [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True) File options: - [DSL_NAME]_DUMP_DIR: Directory to dump the generated files (default: current working directory) + - [DSL_NAME]_CACHE_DIR: Cache directory (default: /tmp/{dsl_name}_python_cache_{tmpfile_suffix}) - [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False) - [DSL_NAME]_KEEP_PTX: Save generated PTX in a file (default: False) - [DSL_NAME]_KEEP_CUBIN: Save generated CUBIN in a file (default: False) @@ -333,6 +334,7 @@ class EnvironmentVarManager(LogEnvironmentManager): # File options self.keep_ir = get_bool_env_var(f"{prefix}_KEEP_IR", False) + self.cache_dir = get_str_env_var(f"{prefix}_CACHE_DIR", None) # Other options self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False) self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix)) diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py index be66771d2..6fc8812ff 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py @@ -192,6 +192,7 @@ class Tensor(Param): dtype: Union[str, "tvm_ffi.dtype"], *, device_type: Optional[str] = None, + device_id: Optional[Var] = None, strides: Optional[Sequence[Var]] = None, map_tensor_dtype_f4x2_to_f4: bool = False, data_alignment: Optional[int] = None, @@ -229,7 +230,10 @@ class Tensor(Param): example_device = tvm_ffi.device(device_type, 0) self.dlpack_device_type = example_device.dlpack_device_type() self.device_type_name = example_device.type - self.device_id = Var(name + ".device_id", tvm_ffi.dtype("int32")) + if device_id is None: + self.device_id = Var(name + ".device.index", tvm_ffi.dtype("int32")) + else: + self.device_id = device_id self.map_tensor_dtype_f4x2_to_f4 = map_tensor_dtype_f4x2_to_f4 diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py index 7353210ca..91a4d08c4 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py @@ -818,6 +818,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): _fn_call_context: str matched_var_binding: dict[spec.Var, ir.Value] matched_var_source: dict[spec.Var, ir.Value] + matched_var_arg_field_name: dict[spec.Var, str] def __init__(self, module: ir.Module) -> None: super().__init__() @@ -826,6 +827,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): self._fn_call_context: str = "" self.matched_var_binding = {} self.matched_var_source = {} + self.matched_var_arg_field_name = {} def find_or_declare_extern_func( self, name: str, params: Sequence[ir.Type], ret: ir.Type @@ -897,6 +899,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): *arg_context.get(), self._fn_call_context, ], + arg_context.get_field_name(""), ) def decode_param_float( @@ -1000,6 +1003,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): result = result_block.arguments[0] self.matched_var_binding[param] = result self.matched_var_source[param] = v_float64 + self.matched_var_arg_field_name[param] = arg_context.get_field_name("") return result_block @@ -1054,6 +1058,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): # For opaque handles, we store the pointer directly self.matched_var_binding[param] = v_ptr self.matched_var_source[param] = v_ptr + self.matched_var_arg_field_name[param] = arg_context.get_field_name("") return current_block @@ -1191,8 +1196,10 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): var: Union[spec.Var, int], value: ir.Value, error_msg_context: list[str], + arg_field_name: str, *, skip_check_predicate: Optional[ir.Value] = None, + skip_cast_and_check: bool = False, ) -> ir.Block: """Set or check the matched var binding.""" error_kind = "ValueError" @@ -1202,33 +1209,48 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): if isinstance(var, spec.Var): # if var contains llvm_value and is not populated, populate it if var not in self.matched_var_binding: - current_block = self.check_int_value_dtype_bound( - current_block, value, var.dtype, error_msg_context - ) - # check divisibility if specified - if var.divisibility is not None: - current_block = self.check_int_value_divisibility( - current_block, value, var.divisibility, error_msg_context, - skip_check_predicate=skip_check_predicate, - ) - # store the source value with parameter info - with ir.InsertionPoint(current_block): - self.matched_var_source[var] = value - self.matched_var_binding[var] = self.downcast_i64_to_lower_bits( - value, var.dtype + if not skip_cast_and_check: + current_block = self.check_int_value_dtype_bound( + current_block, value, var.dtype, error_msg_context ) + # check divisibility if specified + if var.divisibility is not None: + current_block = self.check_int_value_divisibility( + current_block, value, var.divisibility, error_msg_context, + skip_check_predicate=skip_check_predicate, + ) + # store the source value with parameter info + with ir.InsertionPoint(current_block): + target_value = self.downcast_i64_to_lower_bits( + value, var.dtype + ) + else: + target_value = value + # store the source value + self.matched_var_source[var] = value + # store the target value (casted to target dtype aleady) + self.matched_var_binding[var] = target_value + # store arg_field_name + self.matched_var_arg_field_name[var] = arg_field_name return current_block # otherwise, it appears more than once, we need to check if the value matches expected_value = self.matched_var_source[var] + prev_arg_field_name = self.matched_var_arg_field_name[var] error_msg_mismatch = [ error_prefix_mismatch, *error_msg_context, - ", symbolic constraint violated" + f", expected to match {prev_arg_field_name}", ] else: assert isinstance(var, int) with ir.InsertionPoint(current_block): - expected_value = self.i64(var) + if not skip_cast_and_check: + expected_value = self.i64(var) + else: + expected_value = self.downcast_i64_to_lower_bits( + self.i64(var), var.dtype + ) + error_msg_mismatch = [ error_prefix_mismatch, *error_msg_context, @@ -1261,6 +1283,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): ) -> ir.Block: """Load the shape value from the argument or match the shape value from the parameter.""" field_name = arg_context.get_field_name(field_suffix) + arg_field_name = f"{field_name}[{shape_index}]" error_msg = [ field_name, f"[{shape_index}] ", @@ -1268,7 +1291,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): self._fn_call_context, ] return self.set_or_check_matched_var_binding( - current_block, var, value, error_msg, skip_check_predicate=skip_check_predicate + current_block, var, value, error_msg, arg_field_name, + skip_check_predicate=skip_check_predicate ) def decode_param_shape_from_ffi_array( @@ -1553,8 +1577,22 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): # store the matched values, these do not need constraint checks self.matched_var_binding[param.data] = data self.matched_var_source[param.data] = param.data - self.matched_var_binding[param.device_id] = device_id - self.matched_var_source[param.device_id] = param.device_id + self.matched_var_arg_field_name[param.data] = arg_context.get_field_name(".data") + + # check device_id constraint if user specifies a device_id variable + current_block = self.set_or_check_matched_var_binding( + current_block, + param.device_id, + device_id, + [ + "device index ", + *arg_context.get(), + self._fn_call_context, + ], + arg_context.get_field_name(".device.index"), + skip_cast_and_check=True, + ) + # check ndim expected_ndim = len(param.shape) # Break error message into reusable parts for better string deduplication @@ -1683,7 +1721,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): """Decode the stream parameter at the given index.""" # stream is decoded as opaque handle return self.decode_param_opaque_handle( - current_block, param.var, args, arg_index, arg_context + current_block, param.var, args, arg_index, arg_context, + allow_int_as_ptr=True ) def decode_param_data_pointer( @@ -1873,6 +1912,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): ) self.matched_var_binding[param.var] = env_stream self.matched_var_source[param.var] = env_stream + self.matched_var_arg_field_name[param.var] = param.name return current_block diff --git a/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py b/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py index d0499ab9c..02fae4825 100644 --- a/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py +++ b/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py @@ -115,7 +115,9 @@ class ConverterContext: def __init__(self): self.num_dyn_shape_vars = 0 self.num_dyn_stride_vars = 0 + self.num_device_id_vars = 0 self.sym_int_id_mapping = {} + self.vdevice_to_device_id_mapping = {} def alloc_shape_name(self) -> str: """Allocate a new dynamic shape variable name.""" @@ -143,6 +145,25 @@ class ConverterContext: self.sym_int_id_mapping[sym_int_id] = var return var + def alloc_or_reuse_device_id(self, device_type: str, vdevice_id: int) -> Optional[spec.Var]: + """Allocate or reuse a device_id variable for a given virtual device. + + This function returns None for CPU tensors. + """ + # Don't allocate device_id for CPU tensors + if device_type == "cpu": + return None + + vdevice_key = (device_type, vdevice_id) + if vdevice_key in self.vdevice_to_device_id_mapping: + return self.vdevice_to_device_id_mapping[vdevice_key] + + name = f"device_id{self.num_device_id_vars}" + self.num_device_id_vars += 1 + device_id_var = spec.Var(name, "int32") + self.vdevice_to_device_id_mapping[vdevice_key] = device_id_var + return device_id_var + def _convert_single_arg( arg, @@ -209,17 +230,28 @@ def _convert_single_arg( if hasattr(arg, "_tvm_ffi_tensor"): tvm_ffi_tensor = arg._tvm_ffi_tensor dtype = tvm_ffi_tensor.dtype + device_type = tvm_ffi_tensor.device.type + + # Allocate device_id (returns None for CPU tensors) + vdevice_id = tvm_ffi_tensor.device.index + device_id = ctx.alloc_or_reuse_device_id(device_type, vdevice_id) + tvm_ffi_cute_tensor = spec.Tensor( arg_name, shapes, arg._tvm_ffi_tensor.dtype, strides=strides, data_alignment=arg._assumed_align, - device_type=tvm_ffi_tensor.device.type + device_type=device_type, + device_id=device_id ) else: # for FakeTensor, strictly follow the shape and stride from the cute tensor device_type = "cuda" if _is_gpu_memspace(arg.memspace) else "cpu" + # Allocate device_id (returns None for CPU tensors) + vdevice_id = 0 # For now, use vdevice_id = 0 for all GPU tensors + device_id = ctx.alloc_or_reuse_device_id(device_type, vdevice_id) + tvm_ffi_cute_tensor = spec.Tensor( arg_name, shapes, @@ -227,6 +259,7 @@ def _convert_single_arg( strides=strides, data_alignment=arg._assumed_align, device_type=device_type, + device_id=device_id ) if arg.element_type == Float4E2M1FN: tvm_ffi_cute_tensor = spec.create_map_tensor_dtype_f4x2_to_f4_spec( diff --git a/python/CuTeDSL/cutlass/cute/runtime.py b/python/CuTeDSL/cutlass/cute/runtime.py index 19815b5cd..8cfbc72eb 100644 --- a/python/CuTeDSL/cutlass/cute/runtime.py +++ b/python/CuTeDSL/cutlass/cute/runtime.py @@ -515,6 +515,7 @@ class _FakeTensor(Tensor): when the dimension is dynamic. :type use_32bit_stride: bool, optional + """ def __init__(self, dtype, shape, *, stride, memspace=None, assumed_align=None): diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py index 24405c866..b6a2435c9 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py @@ -55,6 +55,8 @@ class CudaDialectJitModule: for library in self.cuda_library: cuda_runtime.cudaLibraryUnload(library) self.cuda_library.clear() + except Exception as e: + pass finally: self._unloaded = True diff --git a/python/CuTeDSL/cutlass/utils/distributed_helpers.py b/python/CuTeDSL/cutlass/utils/distributed_helpers.py deleted file mode 100644 index 6e569e0c5..000000000 --- a/python/CuTeDSL/cutlass/utils/distributed_helpers.py +++ /dev/null @@ -1,208 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Use of this software is governed by the terms and conditions of the -# NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html -# -# Any use, reproduction, disclosure, or distribution of this software -# and related documentation outside the scope permitted by the EULA -# is strictly prohibited. - -from functools import partial -from typing import Tuple - -import cutlass.cute as cute -from cutlass.cutlass_dsl import T, dsl_user_op - -from cutlass._mlir import ir -from cutlass._mlir.dialects import llvm, nvvm -from cutlass._mlir.dialects.nvvm import ( - MemOrderKind, - MemScopeKind, - AtomicOpKind, -) -from cutlass.cute.typing import Pointer, Int32 - - -@dsl_user_op -def atomicAdd(dst_ptr: Pointer, val: Int32, loc=None, ip=None) -> Int32: - return nvvm.atomicrmw( - T.i32(), - AtomicOpKind.ADD, - dst_ptr.llvm_ptr, - val.ir_value(loc=loc, ip=ip), - mem_order=MemOrderKind.RELAXED, - syncscope=MemScopeKind.SYS, - loc=loc, - ip=ip, - ) - - -@cute.jit -def ld_bypass(input_tensor: cute.Tensor): - fragment = cute.make_rmem_tensor(input_tensor.layout, input_tensor.element_type) - copy_atom_load = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - input_tensor.element_type, - memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE, - memory_scope=cute.nvgpu.common.MemoryScope.SYS, - ) - cute.copy_atom_call(copy_atom_load, input_tensor, fragment) - vals = fragment.load() - return vals - - -@cute.jit -def spin_lock_wait( - lock_ptr: Pointer, - expect_count: Int32, - mem_order: str = "relaxed", - mem_scope: str = "gpu", - loc=None, - ip=None, -) -> None: - """ - wait on a spin lock until the expected count is reached. - """ - res = 0 - while res != expect_count: - res = nvvm.atomicrmw( - T.i32(), - AtomicOpKind.CAS, - lock_ptr.llvm_ptr, - Int32(0).ir_value(loc=loc, ip=ip), - b=Int32(expect_count).ir_value(loc=loc, ip=ip), - mem_order=( - MemOrderKind.ACQUIRE if mem_order == "acquire" else MemOrderKind.RELAXED - ), - syncscope=MemScopeKind.GPU if mem_scope == "gpu" else MemScopeKind.SYS, - ) - - -@dsl_user_op -def multimem_red_add_sys_release(mc_ptr: Pointer, loc=None, ip=None) -> None: - """ - add 1 to the multimem address - """ - llvm.inline_asm( - None, - [mc_ptr.toint().ir_value(loc=loc, ip=ip)], - "multimem.red.release.sys.global.add.u32 [$0], 1;", - "l", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def multimem_red_add_gpu_relaxed(mc_ptr: Pointer, loc=None, ip=None) -> None: - """ - add 1 to the multimem address - """ - llvm.inline_asm( - None, - [mc_ptr.toint().ir_value(loc=loc, ip=ip)], - "multimem.red.relaxed.gpu.global.add.u32 [$0], 1;", - "l", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) - - -def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None: - """ - arrive a spin lock when the lock_ptr is a multimem address. - """ - multimem_red_add_gpu_relaxed(lock_ptr, loc=loc, ip=ip) - - -def sm_wise_inter_gpu_multimem_barrier( - barrier: Pointer, barrier_mc: Pointer, num_ranks, loc=None, ip=None -) -> None: - """ - barrier for inter-gpu sm-wise - """ - bidx, bidy, bidz = cute.arch.block_idx() - bdimx, bdimy, _ = cute.arch.grid_dim() - pid = bidx + bidy * bdimx + bidz * bdimx * bdimy - multimem_red_add_sys_release(barrier_mc + pid, loc=loc, ip=ip) - cute.arch.fence_proxy(cute.arch.ProxyKind.alias) - spin_lock_wait( - barrier + pid, num_ranks, mem_order="acquire", mem_scope="sys", loc=loc, ip=ip - ) - - -@dsl_user_op -def multimem_ld_reduce_base( - mc_ptr: Pointer, - *, - ptx_string: str = "", - loc=None, - ip=None, -) -> Tuple[Int32, Int32, Int32, Int32]: - # ld reduce 8xf16 elts - mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) - return_struct = llvm.inline_asm( - ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"), - [mc_ptr_int], - ptx_string, - "=r,=r,=r,=r,l", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) - return_regs = [llvm.extractvalue(T.i32(), return_struct, [i]) for i in range(4)] - return return_regs[0], return_regs[1], return_regs[2], return_regs[3] - - -multimem_ld_reduce_8xf16 = partial( - multimem_ld_reduce_base, - ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];", -) -multimem_ld_reduce_4xf32 = partial( - multimem_ld_reduce_base, - ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];", -) -multimem_ld_reduce_8xbf16 = partial( - multimem_ld_reduce_base, - ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];", -) -multimem_ld_reduce_16xe4m3 = partial( - multimem_ld_reduce_base, - ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];", -) -multimem_ld_reduce_16xe5m2 = partial( - multimem_ld_reduce_base, - ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];", -) - - -@dsl_user_op -def multimem_st_4xb32( - mc_ptr: Pointer, - x: Int32, - y: Int32, - z: Int32, - w: Int32, - *, - loc=None, - ip=None, -) -> None: - # st 4x32 bits of data - mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) - llvm.inline_asm( - T.i32(), - [mc_ptr_int, x, y, z, w], - "multimem.st.sys.relaxed.global.v4.f32 [$1], {$2, $3, $4, $5};", - "=r,l,r,r,r,r", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) diff --git a/python/CuTeDSL/cutlass/utils/smem_allocator.py b/python/CuTeDSL/cutlass/utils/smem_allocator.py index 7e801ddd3..fb140dbe0 100644 --- a/python/CuTeDSL/cutlass/utils/smem_allocator.py +++ b/python/CuTeDSL/cutlass/utils/smem_allocator.py @@ -14,10 +14,12 @@ import inspect import cutlass.cute as cute from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size -from cutlass.cutlass_dsl import CutlassBaseDSL, Int8, Numeric, NumericMeta, dsl_user_op +from cutlass.cutlass_dsl import CuTeDSL, Int8, Numeric, NumericMeta, dsl_user_op + SMEM_CAPACITY_MAP = { "sm_120": (100 - 1) * 1024, + "sm_103": (228 - 1) * 1024, "sm_100": (228 - 1) * 1024, "sm_90": (228 - 1) * 1024, "sm_80": (164 - 1) * 1024, @@ -71,7 +73,7 @@ class SmemAllocator: """ @staticmethod - def capacity_in_bytes(compute_capability: str) -> int: + def capacity_in_bytes(compute_capability: Optional[str] = None) -> int: """Get the shared memory capacity in bytes for a given compute capability. Returns the maximum shared memory capacity in bytes available for the specified @@ -83,6 +85,9 @@ class SmemAllocator: :rtype: int :raises ValueError: If the compute capability is not supported """ + if compute_capability is None: + arch = CuTeDSL._get_dsl().get_arch_enum() + compute_capability = f"sm_{arch.major}{arch.minor}" if compute_capability not in SMEM_CAPACITY_MAP: raise ValueError(f"Unsupported compute capability: {compute_capability}") return SMEM_CAPACITY_MAP[compute_capability] @@ -101,7 +106,7 @@ class SmemAllocator: """ self._base = get_dyn_smem(Int8, alignment=1024, loc=loc, ip=ip) self._allocated_bytes = 0 - CutlassBaseDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes) + CuTeDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes) @overload def allocate( diff --git a/python/cutlass_cppgen/__init__.py b/python/cutlass_cppgen/__init__.py index a7b9eb3d3..63cece3d1 100644 --- a/python/cutlass_cppgen/__init__.py +++ b/python/cutlass_cppgen/__init__.py @@ -133,7 +133,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '4.3.1' +this.__version__ = '4.3.2' from cutlass_cppgen.backend import create_memory_pool from cutlass_cppgen.emit.pytorch import pytorch diff --git a/python/setup_cutlass.py b/python/setup_cutlass.py index 097d89fdd..faa20bfb2 100644 --- a/python/setup_cutlass.py +++ b/python/setup_cutlass.py @@ -51,7 +51,7 @@ setup_pycute.perform_setup() setup( name='cutlass_cppgen', - version='4.3.1', + version='4.3.2', description='CUTLASS Pythonic Interface', package_dir={'': '.'}, packages=[ diff --git a/python/setup_pycute.py b/python/setup_pycute.py index d642b0afa..671afe050 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='pycute', - version='4.3.1', + version='4.3.2', description='Python implementation of CuTe', packages=['pycute'], )