v4.3.2 update. (#2840)

This commit is contained in:
Junkai-Wu
2025-12-04 23:14:50 +08:00
committed by GitHub
parent f88806b1e3
commit ff35fa561d
19 changed files with 164 additions and 257 deletions

View File

@@ -22,6 +22,7 @@ import time
from pathlib import Path
import hashlib
from functools import lru_cache
import tempfile
from .utils.logger import log
from .jit_executor import JitCompiledFunction
@@ -46,15 +47,23 @@ def get_current_user():
# default_generated_ir_path is the path to the cache directory.
# It is set to /tmp/{user}/cutlass_python_cache/ by default.
# If the user is not found, the default path is used or /tmp/cutlass_python_cache/ is used.
try:
default_generated_ir_path = f"/tmp/{get_current_user()}/cutlass_python_cache/"
except Exception as e:
# If all else fails, provide a default fallback path
default_generated_ir_path = "/tmp/cutlass_python_cache/"
print(f"Could not determine user, using default path. Error: {e}")
# If `CUTE_DSL_CACHE_DIR` is set, it is used as the cache directory.
# Otherwise, it is set to a directory controlled by TMPDIR, defaulting
# to /tmp/${USER}/cutlass_python_cache.
if not (default_generated_ir_path := os.getenv("CUTE_DSL_CACHE_DIR", None)):
tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
def get_reusable_temp_dir(name):
path = tmp_dir / f"{get_current_user()}/{name}"
path.mkdir(parents=True, exist_ok=True)
return str(path)
try:
default_generated_ir_path = get_reusable_temp_dir("cutlass_python_cache")
except Exception as e:
default_generated_ir_path = str(tmp_dir / "cutlass_python_cache")
print(f"Could not determine user, using default path. Error: {e}")
@lru_cache(maxsize=1)
def get_default_file_dump_root():
@@ -223,6 +232,8 @@ def dump_cache_to_path(
:type bytecode_writer: callable, optional
"""
log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
if not path:
path = default_generated_ir_path
os.makedirs(path, exist_ok=True)
try:
for idx, [key, value] in enumerate(jit_cache.items()):

View File

@@ -372,10 +372,10 @@ class BaseDSL:
atexit.register(restore_excepthook, origin_excepthook)
def dump_cache(self):
def dump_cache(self, path=None):
if not self.envar.disable_file_caching:
dump_cache_to_path(
self.name, self.jit_cache, self.envar.file_caching_capacity
self.name, self.jit_cache, self.envar.file_caching_capacity, path=path
)
@lru_cache(maxsize=1)

View File

@@ -296,6 +296,7 @@ class EnvironmentVarManager(LogEnvironmentManager):
- [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True)
File options:
- [DSL_NAME]_DUMP_DIR: Directory to dump the generated files (default: current working directory)
- [DSL_NAME]_CACHE_DIR: Cache directory (default: /tmp/{dsl_name}_python_cache_{tmpfile_suffix})
- [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False)
- [DSL_NAME]_KEEP_PTX: Save generated PTX in a file (default: False)
- [DSL_NAME]_KEEP_CUBIN: Save generated CUBIN in a file (default: False)
@@ -333,6 +334,7 @@ class EnvironmentVarManager(LogEnvironmentManager):
# File options
self.keep_ir = get_bool_env_var(f"{prefix}_KEEP_IR", False)
self.cache_dir = get_str_env_var(f"{prefix}_CACHE_DIR", None)
# Other options
self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False)
self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix))

View File

@@ -192,6 +192,7 @@ class Tensor(Param):
dtype: Union[str, "tvm_ffi.dtype"],
*,
device_type: Optional[str] = None,
device_id: Optional[Var] = None,
strides: Optional[Sequence[Var]] = None,
map_tensor_dtype_f4x2_to_f4: bool = False,
data_alignment: Optional[int] = None,
@@ -229,7 +230,10 @@ class Tensor(Param):
example_device = tvm_ffi.device(device_type, 0)
self.dlpack_device_type = example_device.dlpack_device_type()
self.device_type_name = example_device.type
self.device_id = Var(name + ".device_id", tvm_ffi.dtype("int32"))
if device_id is None:
self.device_id = Var(name + ".device.index", tvm_ffi.dtype("int32"))
else:
self.device_id = device_id
self.map_tensor_dtype_f4x2_to_f4 = map_tensor_dtype_f4x2_to_f4

View File

@@ -818,6 +818,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
_fn_call_context: str
matched_var_binding: dict[spec.Var, ir.Value]
matched_var_source: dict[spec.Var, ir.Value]
matched_var_arg_field_name: dict[spec.Var, str]
def __init__(self, module: ir.Module) -> None:
super().__init__()
@@ -826,6 +827,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
self._fn_call_context: str = ""
self.matched_var_binding = {}
self.matched_var_source = {}
self.matched_var_arg_field_name = {}
def find_or_declare_extern_func(
self, name: str, params: Sequence[ir.Type], ret: ir.Type
@@ -897,6 +899,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
*arg_context.get(),
self._fn_call_context,
],
arg_context.get_field_name(""),
)
def decode_param_float(
@@ -1000,6 +1003,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
result = result_block.arguments[0]
self.matched_var_binding[param] = result
self.matched_var_source[param] = v_float64
self.matched_var_arg_field_name[param] = arg_context.get_field_name("")
return result_block
@@ -1054,6 +1058,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
# For opaque handles, we store the pointer directly
self.matched_var_binding[param] = v_ptr
self.matched_var_source[param] = v_ptr
self.matched_var_arg_field_name[param] = arg_context.get_field_name("")
return current_block
@@ -1191,8 +1196,10 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
var: Union[spec.Var, int],
value: ir.Value,
error_msg_context: list[str],
arg_field_name: str,
*,
skip_check_predicate: Optional[ir.Value] = None,
skip_cast_and_check: bool = False,
) -> ir.Block:
"""Set or check the matched var binding."""
error_kind = "ValueError"
@@ -1202,33 +1209,48 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
if isinstance(var, spec.Var):
# if var contains llvm_value and is not populated, populate it
if var not in self.matched_var_binding:
current_block = self.check_int_value_dtype_bound(
current_block, value, var.dtype, error_msg_context
)
# check divisibility if specified
if var.divisibility is not None:
current_block = self.check_int_value_divisibility(
current_block, value, var.divisibility, error_msg_context,
skip_check_predicate=skip_check_predicate,
)
# store the source value with parameter info
with ir.InsertionPoint(current_block):
self.matched_var_source[var] = value
self.matched_var_binding[var] = self.downcast_i64_to_lower_bits(
value, var.dtype
if not skip_cast_and_check:
current_block = self.check_int_value_dtype_bound(
current_block, value, var.dtype, error_msg_context
)
# check divisibility if specified
if var.divisibility is not None:
current_block = self.check_int_value_divisibility(
current_block, value, var.divisibility, error_msg_context,
skip_check_predicate=skip_check_predicate,
)
# store the source value with parameter info
with ir.InsertionPoint(current_block):
target_value = self.downcast_i64_to_lower_bits(
value, var.dtype
)
else:
target_value = value
# store the source value
self.matched_var_source[var] = value
# store the target value (already cast to the target dtype)
self.matched_var_binding[var] = target_value
# store arg_field_name
self.matched_var_arg_field_name[var] = arg_field_name
return current_block
# otherwise, it appears more than once, we need to check if the value matches
expected_value = self.matched_var_source[var]
prev_arg_field_name = self.matched_var_arg_field_name[var]
error_msg_mismatch = [
error_prefix_mismatch,
*error_msg_context,
", symbolic constraint violated"
f", expected to match {prev_arg_field_name}",
]
else:
assert isinstance(var, int)
with ir.InsertionPoint(current_block):
expected_value = self.i64(var)
if not skip_cast_and_check:
expected_value = self.i64(var)
else:
expected_value = self.downcast_i64_to_lower_bits(
self.i64(var), var.dtype
)
error_msg_mismatch = [
error_prefix_mismatch,
*error_msg_context,
@@ -1261,6 +1283,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
) -> ir.Block:
"""Load the shape value from the argument or match the shape value from the parameter."""
field_name = arg_context.get_field_name(field_suffix)
arg_field_name = f"{field_name}[{shape_index}]"
error_msg = [
field_name,
f"[{shape_index}] ",
@@ -1268,7 +1291,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
self._fn_call_context,
]
return self.set_or_check_matched_var_binding(
current_block, var, value, error_msg, skip_check_predicate=skip_check_predicate
current_block, var, value, error_msg, arg_field_name,
skip_check_predicate=skip_check_predicate
)
def decode_param_shape_from_ffi_array(
@@ -1553,8 +1577,22 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
# store the matched values, these do not need constraint checks
self.matched_var_binding[param.data] = data
self.matched_var_source[param.data] = param.data
self.matched_var_binding[param.device_id] = device_id
self.matched_var_source[param.device_id] = param.device_id
self.matched_var_arg_field_name[param.data] = arg_context.get_field_name(".data")
# check device_id constraint if user specifies a device_id variable
current_block = self.set_or_check_matched_var_binding(
current_block,
param.device_id,
device_id,
[
"device index ",
*arg_context.get(),
self._fn_call_context,
],
arg_context.get_field_name(".device.index"),
skip_cast_and_check=True,
)
# check ndim
expected_ndim = len(param.shape)
# Break error message into reusable parts for better string deduplication
@@ -1683,7 +1721,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
"""Decode the stream parameter at the given index."""
# stream is decoded as opaque handle
return self.decode_param_opaque_handle(
current_block, param.var, args, arg_index, arg_context
current_block, param.var, args, arg_index, arg_context,
allow_int_as_ptr=True
)
def decode_param_data_pointer(
@@ -1873,6 +1912,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
)
self.matched_var_binding[param.var] = env_stream
self.matched_var_source[param.var] = env_stream
self.matched_var_arg_field_name[param.var] = param.name
return current_block

View File

@@ -115,7 +115,9 @@ class ConverterContext:
def __init__(self):
    """Start with zeroed name counters and empty reuse tables."""
    # Counters for generated variables, one per category.
    self.num_dyn_shape_vars = self.num_dyn_stride_vars = 0
    self.num_device_id_vars = 0
    # Lookup tables used to hand back an already-allocated variable.
    self.sym_int_id_mapping = dict()
    self.vdevice_to_device_id_mapping = dict()
def alloc_shape_name(self) -> str:
"""Allocate a new dynamic shape variable name."""
@@ -143,6 +145,25 @@ class ConverterContext:
self.sym_int_id_mapping[sym_int_id] = var
return var
def alloc_or_reuse_device_id(self, device_type: str, vdevice_id: int) -> Optional[spec.Var]:
    """Return the device_id variable tied to a (device_type, vdevice_id) pair.

    The first request for a pair creates an ``int32`` ``spec.Var`` named
    ``device_id<N>`` (N is the running ``num_device_id_vars`` counter) and
    records it; later requests for the same pair get that same variable back.

    :param device_type: Device type string; ``"cpu"`` never gets a variable.
    :param vdevice_id: Virtual device index used as the second half of the key.
    :return: The shared variable, or ``None`` for CPU tensors.
    """
    # CPU tensors carry no device_id variable at all.
    if device_type == "cpu":
        return None

    known = self.vdevice_to_device_id_mapping
    lookup_key = (device_type, vdevice_id)
    if lookup_key not in known:
        # First sighting of this virtual device: mint a fresh variable.
        fresh_name = f"device_id{self.num_device_id_vars}"
        self.num_device_id_vars += 1
        known[lookup_key] = spec.Var(fresh_name, "int32")
    return known[lookup_key]
def _convert_single_arg(
arg,
@@ -209,17 +230,28 @@ def _convert_single_arg(
if hasattr(arg, "_tvm_ffi_tensor"):
tvm_ffi_tensor = arg._tvm_ffi_tensor
dtype = tvm_ffi_tensor.dtype
device_type = tvm_ffi_tensor.device.type
# Allocate device_id (returns None for CPU tensors)
vdevice_id = tvm_ffi_tensor.device.index
device_id = ctx.alloc_or_reuse_device_id(device_type, vdevice_id)
tvm_ffi_cute_tensor = spec.Tensor(
arg_name,
shapes,
arg._tvm_ffi_tensor.dtype,
strides=strides,
data_alignment=arg._assumed_align,
device_type=tvm_ffi_tensor.device.type
device_type=device_type,
device_id=device_id
)
else:
# for FakeTensor, strictly follow the shape and stride from the cute tensor
device_type = "cuda" if _is_gpu_memspace(arg.memspace) else "cpu"
# Allocate device_id (returns None for CPU tensors)
vdevice_id = 0 # For now, use vdevice_id = 0 for all GPU tensors
device_id = ctx.alloc_or_reuse_device_id(device_type, vdevice_id)
tvm_ffi_cute_tensor = spec.Tensor(
arg_name,
shapes,
@@ -227,6 +259,7 @@ def _convert_single_arg(
strides=strides,
data_alignment=arg._assumed_align,
device_type=device_type,
device_id=device_id
)
if arg.element_type == Float4E2M1FN:
tvm_ffi_cute_tensor = spec.create_map_tensor_dtype_f4x2_to_f4_spec(

View File

@@ -515,6 +515,7 @@ class _FakeTensor(Tensor):
when the dimension is dynamic.
:type use_32bit_stride: bool, optional
"""
def __init__(self, dtype, shape, *, stride, memspace=None, assumed_align=None):

View File

@@ -55,6 +55,8 @@ class CudaDialectJitModule:
for library in self.cuda_library:
cuda_runtime.cudaLibraryUnload(library)
self.cuda_library.clear()
except Exception as e:
pass
finally:
self._unloaded = True

View File

@@ -1,208 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from functools import partial
from typing import Tuple
import cutlass.cute as cute
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm, nvvm
from cutlass._mlir.dialects.nvvm import (
MemOrderKind,
MemScopeKind,
AtomicOpKind,
)
from cutlass.cute.typing import Pointer, Int32
@dsl_user_op
def atomicAdd(dst_ptr: Pointer, val: Int32, loc=None, ip=None) -> Int32:
    """Emit an ``nvvm.atomicrmw`` ADD of ``val`` at ``dst_ptr``.

    The op is built with an i32 result type, ``MemOrderKind.RELAXED`` memory
    order and ``MemScopeKind.SYS`` sync scope.

    :param dst_ptr: Pointer whose ``llvm_ptr`` is the target of the atomic.
    :param val: Addend; converted with ``ir_value`` before being passed on.
    :param loc: Optional source location forwarded to the emitted ops.
    :param ip: Optional insertion point forwarded to the emitted ops.
    :return: The value produced by ``nvvm.atomicrmw`` (presumably the value
        stored before the add -- confirm against the NVVM dialect).
    """
    result_type = T.i32()
    target = dst_ptr.llvm_ptr
    addend = val.ir_value(loc=loc, ip=ip)
    return nvvm.atomicrmw(
        result_type,
        AtomicOpKind.ADD,
        target,
        addend,
        mem_order=MemOrderKind.RELAXED,
        syncscope=MemScopeKind.SYS,
        loc=loc,
        ip=ip,
    )
@cute.jit
def ld_bypass(input_tensor: cute.Tensor):
    """Copy ``input_tensor`` into an rmem tensor and return the loaded values.

    The copy goes through a ``CopyUniversalOp`` atom configured with
    ``MemoryOrder.VOLATILE`` and ``MemoryScope.SYS``.

    :param input_tensor: Source tensor; its layout and element type are reused
        for the destination created by ``cute.make_rmem_tensor``.
    :return: Result of calling ``load()`` on the destination tensor.
    """
    # Destination with the same layout and element type as the source.
    staging = cute.make_rmem_tensor(input_tensor.layout, input_tensor.element_type)
    volatile_sys_load = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        input_tensor.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    cute.copy_atom_call(volatile_sys_load, input_tensor, staging)
    return staging.load()
@cute.jit
def spin_lock_wait(
lock_ptr: Pointer,
expect_count: Int32,
mem_order: str = "relaxed",
mem_scope: str = "gpu",
loc=None,
ip=None,
) -> None:
"""
Wait on a spin lock until the expected count is reached.

Each iteration issues an ``nvvm.atomicrmw`` with ``AtomicOpKind.CAS`` on
``lock_ptr.llvm_ptr`` and keeps looping while the value it returns differs
from ``expect_count``.

NOTE(review): presumably the CAS compares the lock word against
``expect_count`` and swaps in 0 on a match, which would reset the lock for
reuse -- confirm the operand roles against the NVVM dialect.

:param lock_ptr: Pointer to the lock word.
:param expect_count: Value the atomic must return for the wait to end.
:param mem_order: ``"acquire"`` selects ``MemOrderKind.ACQUIRE``; any other
    string selects ``MemOrderKind.RELAXED``.
:param mem_scope: ``"gpu"`` selects ``MemScopeKind.GPU``; any other string
    selects ``MemScopeKind.SYS``.
:param loc: Optional source location, used for the two constant operands.
:param ip: Optional insertion point, used for the two constant operands.
"""
res = 0
# Poll until the atomic hands back expect_count.
while res != expect_count:
# loc/ip are applied to the operands only; they are not passed to atomicrmw.
res = nvvm.atomicrmw(
T.i32(),
AtomicOpKind.CAS,
lock_ptr.llvm_ptr,
Int32(0).ir_value(loc=loc, ip=ip),
b=Int32(expect_count).ir_value(loc=loc, ip=ip),
mem_order=(
MemOrderKind.ACQUIRE if mem_order == "acquire" else MemOrderKind.RELAXED
),
syncscope=MemScopeKind.GPU if mem_scope == "gpu" else MemScopeKind.SYS,
)
@dsl_user_op
def multimem_red_add_sys_release(mc_ptr: Pointer, loc=None, ip=None) -> None:
    """Add 1 to the u32 at a multimem address (release order, sys scope).

    Emits the inline PTX
    ``multimem.red.release.sys.global.add.u32 [addr], 1;`` with the address
    passed as a 64-bit (``"l"``) operand.

    :param mc_ptr: Pointer to the multimem location to increment.
    :param loc: Optional source location forwarded to every emitted op.
    :param ip: Optional insertion point forwarded to every emitted op.
    """
    # Pass loc/ip to toint() too, so the pointer-to-integer conversion is
    # emitted at the caller's insertion point with the caller's location.
    mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip)
    llvm.inline_asm(
        None,
        [mc_ptr_int],
        "multimem.red.release.sys.global.add.u32 [$0], 1;",
        "l",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )
@dsl_user_op
def multimem_red_add_gpu_relaxed(mc_ptr: Pointer, loc=None, ip=None) -> None:
    """Add 1 to the u32 at a multimem address (relaxed order, gpu scope).

    Emits the inline PTX
    ``multimem.red.relaxed.gpu.global.add.u32 [addr], 1;`` with the address
    passed as a 64-bit (``"l"``) operand.

    :param mc_ptr: Pointer to the multimem location to increment.
    :param loc: Optional source location forwarded to every emitted op.
    :param ip: Optional insertion point forwarded to every emitted op.
    """
    # Pass loc/ip to toint() too, so the pointer-to-integer conversion is
    # emitted at the caller's insertion point with the caller's location.
    mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip)
    llvm.inline_asm(
        None,
        [mc_ptr_int],
        "multimem.red.relaxed.gpu.global.add.u32 [$0], 1;",
        "l",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )
def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None:
    """Signal arrival on a spin lock whose ``lock_ptr`` is a multimem address.

    This is a thin wrapper that hands ``lock_ptr`` to
    ``multimem_red_add_gpu_relaxed`` and discards its result.

    :param lock_ptr: Multimem pointer to the lock word.
    :param loc: Optional source location, forwarded unchanged.
    :param ip: Optional insertion point, forwarded unchanged.
    """
    multimem_red_add_gpu_relaxed(lock_ptr, ip=ip, loc=loc)
def sm_wise_inter_gpu_multimem_barrier(
    barrier: Pointer, barrier_mc: Pointer, num_ranks, loc=None, ip=None
) -> None:
    """Run a per-block barrier across GPUs using a multimem counter.

    Each block picks the slot at its linearised block index, increments that
    slot through ``barrier_mc`` with ``multimem_red_add_sys_release``, issues
    an alias proxy fence, then waits with ``spin_lock_wait`` on the same slot
    of ``barrier`` until it reports ``num_ranks``.

    :param barrier: Base pointer of the slots that are waited on.
    :param barrier_mc: Base pointer of the slots that are incremented.
    :param num_ranks: Count passed to ``spin_lock_wait`` as the expected value.
    :param loc: Optional source location, forwarded to the helpers.
    :param ip: Optional insertion point, forwarded to the helpers.
    """
    block_x, block_y, block_z = cute.arch.block_idx()
    grid_x, grid_y, _ = cute.arch.grid_dim()
    # Row-major linearisation of the block index over the grid.
    slot = block_x + block_y * grid_x + block_z * grid_x * grid_y

    multimem_red_add_sys_release(barrier_mc + slot, loc=loc, ip=ip)
    cute.arch.fence_proxy(cute.arch.ProxyKind.alias)
    spin_lock_wait(
        barrier + slot,
        num_ranks,
        mem_order="acquire",
        mem_scope="sys",
        loc=loc,
        ip=ip,
    )
@dsl_user_op
def multimem_ld_reduce_base(
    mc_ptr: Pointer,
    *,
    ptx_string: str = "",
    loc=None,
    ip=None,
) -> Tuple[Int32, Int32, Int32, Int32]:
    """Run a caller-supplied PTX instruction and return four i32 results.

    The instruction in ``ptx_string`` is emitted as inline asm with four
    32-bit register outputs (``=r``) and the integer form of ``mc_ptr`` as
    its single 64-bit input (``l``). The outputs come back as one LLVM struct,
    which is unpacked here.

    :param mc_ptr: Pointer converted to an integer address operand.
    :param ptx_string: PTX text to emit; it must use ``$0``-``$3`` for the
        outputs and ``$4`` for the address to match the constraint string.
    :param loc: Optional source location forwarded to the emitted ops.
    :param ip: Optional insertion point forwarded to the emitted ops.
    :return: The four extracted i32 values, in struct order.
    """
    address = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip)
    packed = llvm.inline_asm(
        ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"),
        [address],
        ptx_string,
        "=r,=r,=r,=r,l",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )
    # Pull each struct field out as its own i32 value.
    reg0, reg1, reg2, reg3 = (
        llvm.extractvalue(T.i32(), packed, [field]) for field in range(4)
    )
    return reg0, reg1, reg2, reg3
# Pre-bound variants of multimem_ld_reduce_base. Each one fixes only the PTX
# instruction string; callers still supply mc_ptr (and optionally loc/ip).
# All five use multimem.ld_reduce with .sys.relaxed.global.add and a .v4
# vector, and differ in element type and accumulator qualifier.

# f16x2 elements with the .acc::f32 qualifier.
multimem_ld_reduce_8xf16 = partial(
multimem_ld_reduce_base,
ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];",
)
# f32 elements, no .acc qualifier.
multimem_ld_reduce_4xf32 = partial(
multimem_ld_reduce_base,
ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];",
)
# bf16x2 elements with the .acc::f32 qualifier.
multimem_ld_reduce_8xbf16 = partial(
multimem_ld_reduce_base,
ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];",
)
# e4m3x4 elements with the .acc::f16 qualifier.
multimem_ld_reduce_16xe4m3 = partial(
multimem_ld_reduce_base,
ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];",
)
# e5m2x4 elements with the .acc::f16 qualifier.
multimem_ld_reduce_16xe5m2 = partial(
multimem_ld_reduce_base,
ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];",
)
@dsl_user_op
def multimem_st_4xb32(
    mc_ptr: Pointer,
    x: Int32,
    y: Int32,
    z: Int32,
    w: Int32,
    *,
    loc=None,
    ip=None,
) -> None:
    """Store four 32-bit values to a multimem address with one PTX store.

    Emits ``multimem.st.sys.relaxed.global.v4.f32`` as inline asm. The
    constraint string declares one i32 output (``$0``), which the instruction
    text never references and whose value is discarded here.

    :param mc_ptr: Destination pointer, converted to an integer address.
    :param x: First 32-bit word, passed to the asm as given.
    :param y: Second 32-bit word, passed to the asm as given.
    :param z: Third 32-bit word, passed to the asm as given.
    :param w: Fourth 32-bit word, passed to the asm as given.
    :param loc: Optional source location forwarded to the emitted ops.
    :param ip: Optional insertion point forwarded to the emitted ops.
    """
    address = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip)
    # Operand order matches $1..$5 in the asm template below.
    operands = [address, x, y, z, w]
    llvm.inline_asm(
        T.i32(),
        operands,
        "multimem.st.sys.relaxed.global.v4.f32 [$1], {$2, $3, $4, $5};",
        "=r,l,r,r,r,r",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )

View File

@@ -14,10 +14,12 @@ import inspect
import cutlass.cute as cute
from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size
from cutlass.cutlass_dsl import CutlassBaseDSL, Int8, Numeric, NumericMeta, dsl_user_op
from cutlass.cutlass_dsl import CuTeDSL, Int8, Numeric, NumericMeta, dsl_user_op
SMEM_CAPACITY_MAP = {
"sm_120": (100 - 1) * 1024,
"sm_103": (228 - 1) * 1024,
"sm_100": (228 - 1) * 1024,
"sm_90": (228 - 1) * 1024,
"sm_80": (164 - 1) * 1024,
@@ -71,7 +73,7 @@ class SmemAllocator:
"""
@staticmethod
def capacity_in_bytes(compute_capability: str) -> int:
def capacity_in_bytes(compute_capability: Optional[str] = None) -> int:
"""Get the shared memory capacity in bytes for a given compute capability.
Returns the maximum shared memory capacity in bytes available for the specified
@@ -83,6 +85,9 @@ class SmemAllocator:
:rtype: int
:raises ValueError: If the compute capability is not supported
"""
if compute_capability is None:
arch = CuTeDSL._get_dsl().get_arch_enum()
compute_capability = f"sm_{arch.major}{arch.minor}"
if compute_capability not in SMEM_CAPACITY_MAP:
raise ValueError(f"Unsupported compute capability: {compute_capability}")
return SMEM_CAPACITY_MAP[compute_capability]
@@ -101,7 +106,7 @@ class SmemAllocator:
"""
self._base = get_dyn_smem(Int8, alignment=1024, loc=loc, ip=ip)
self._allocated_bytes = 0
CutlassBaseDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes)
CuTeDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes)
@overload
def allocate(

View File

@@ -133,7 +133,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '4.3.1'
this.__version__ = '4.3.2'
from cutlass_cppgen.backend import create_memory_pool
from cutlass_cppgen.emit.pytorch import pytorch

View File

@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
setup(
name='cutlass_cppgen',
version='4.3.1',
version='4.3.2',
description='CUTLASS Pythonic Interface',
package_dir={'': '.'},
packages=[

View File

@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='4.3.1',
version='4.3.2',
description='Python implementation of CuTe',
packages=['pycute'],
)