mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-06-29 10:57:06 +00:00
v4.3.2 update. (#2840)
This commit is contained in:
@@ -22,6 +22,7 @@ import time
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
from functools import lru_cache
|
||||
import tempfile
|
||||
|
||||
from .utils.logger import log
|
||||
from .jit_executor import JitCompiledFunction
|
||||
@@ -46,15 +47,23 @@ def get_current_user():
|
||||
|
||||
|
||||
# default_generated_ir_path is the path to the cache directory.
|
||||
# It is set to /tmp/{user}/cutlass_python_cache/ by default.
|
||||
# If the user is not found, the default path is used or /tmp/cutlass_python_cache/ is used.
|
||||
try:
|
||||
default_generated_ir_path = f"/tmp/{get_current_user()}/cutlass_python_cache/"
|
||||
except Exception as e:
|
||||
# If all else fails, provide a default fallback path
|
||||
default_generated_ir_path = "/tmp/cutlass_python_cache/"
|
||||
print(f"Could not determine user, using default path. Error: {e}")
|
||||
# If `CUTE_DSL_CACHE_DIR` is set, it is used as the cache directory.
|
||||
# Otherwise, it is set to a directory controled by TMPDIR defaulting
|
||||
# to /tmp/${USER}/cutlass_python_cache.
|
||||
if not (default_generated_ir_path := os.getenv("CUTE_DSL_CACHE_DIR", None)):
|
||||
|
||||
tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
|
||||
|
||||
def get_reusable_temp_dir(name):
|
||||
path = tmp_dir / f"{get_current_user()}/{name}"
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return str(path)
|
||||
|
||||
try:
|
||||
default_generated_ir_path = get_reusable_temp_dir("cutlass_python_cache")
|
||||
except Exception as e:
|
||||
default_generated_ir_path = str(tmp_dir / "cutlass_python_cache")
|
||||
print(f"Could not determine user, using default path. Error: {e}")
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_default_file_dump_root():
|
||||
@@ -223,6 +232,8 @@ def dump_cache_to_path(
|
||||
:type bytecode_writer: callable, optional
|
||||
"""
|
||||
log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
|
||||
if not path:
|
||||
path = default_generated_ir_path
|
||||
os.makedirs(path, exist_ok=True)
|
||||
try:
|
||||
for idx, [key, value] in enumerate(jit_cache.items()):
|
||||
|
||||
@@ -372,10 +372,10 @@ class BaseDSL:
|
||||
|
||||
atexit.register(restore_excepthook, origin_excepthook)
|
||||
|
||||
def dump_cache(self):
|
||||
def dump_cache(self, path=None):
|
||||
if not self.envar.disable_file_caching:
|
||||
dump_cache_to_path(
|
||||
self.name, self.jit_cache, self.envar.file_caching_capacity
|
||||
self.name, self.jit_cache, self.envar.file_caching_capacity, path=path
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
|
||||
@@ -296,6 +296,7 @@ class EnvironmentVarManager(LogEnvironmentManager):
|
||||
- [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True)
|
||||
File options:
|
||||
- [DSL_NAME]_DUMP_DIR: Directory to dump the generated files (default: current working directory)
|
||||
- [DSL_NAME]_CACHE_DIR: Cache directory (default: /tmp/{dsl_name}_python_cache_{tmpfile_suffix})
|
||||
- [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False)
|
||||
- [DSL_NAME]_KEEP_PTX: Save generated PTX in a file (default: False)
|
||||
- [DSL_NAME]_KEEP_CUBIN: Save generated CUBIN in a file (default: False)
|
||||
@@ -333,6 +334,7 @@ class EnvironmentVarManager(LogEnvironmentManager):
|
||||
|
||||
# File options
|
||||
self.keep_ir = get_bool_env_var(f"{prefix}_KEEP_IR", False)
|
||||
self.cache_dir = get_str_env_var(f"{prefix}_CACHE_DIR", None)
|
||||
# Other options
|
||||
self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False)
|
||||
self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix))
|
||||
|
||||
@@ -192,6 +192,7 @@ class Tensor(Param):
|
||||
dtype: Union[str, "tvm_ffi.dtype"],
|
||||
*,
|
||||
device_type: Optional[str] = None,
|
||||
device_id: Optional[Var] = None,
|
||||
strides: Optional[Sequence[Var]] = None,
|
||||
map_tensor_dtype_f4x2_to_f4: bool = False,
|
||||
data_alignment: Optional[int] = None,
|
||||
@@ -229,7 +230,10 @@ class Tensor(Param):
|
||||
example_device = tvm_ffi.device(device_type, 0)
|
||||
self.dlpack_device_type = example_device.dlpack_device_type()
|
||||
self.device_type_name = example_device.type
|
||||
self.device_id = Var(name + ".device_id", tvm_ffi.dtype("int32"))
|
||||
if device_id is None:
|
||||
self.device_id = Var(name + ".device.index", tvm_ffi.dtype("int32"))
|
||||
else:
|
||||
self.device_id = device_id
|
||||
self.map_tensor_dtype_f4x2_to_f4 = map_tensor_dtype_f4x2_to_f4
|
||||
|
||||
|
||||
|
||||
@@ -818,6 +818,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
_fn_call_context: str
|
||||
matched_var_binding: dict[spec.Var, ir.Value]
|
||||
matched_var_source: dict[spec.Var, ir.Value]
|
||||
matched_var_arg_field_name: dict[spec.Var, str]
|
||||
|
||||
def __init__(self, module: ir.Module) -> None:
|
||||
super().__init__()
|
||||
@@ -826,6 +827,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
self._fn_call_context: str = ""
|
||||
self.matched_var_binding = {}
|
||||
self.matched_var_source = {}
|
||||
self.matched_var_arg_field_name = {}
|
||||
|
||||
def find_or_declare_extern_func(
|
||||
self, name: str, params: Sequence[ir.Type], ret: ir.Type
|
||||
@@ -897,6 +899,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
],
|
||||
arg_context.get_field_name(""),
|
||||
)
|
||||
|
||||
def decode_param_float(
|
||||
@@ -1000,6 +1003,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
result = result_block.arguments[0]
|
||||
self.matched_var_binding[param] = result
|
||||
self.matched_var_source[param] = v_float64
|
||||
self.matched_var_arg_field_name[param] = arg_context.get_field_name("")
|
||||
|
||||
return result_block
|
||||
|
||||
@@ -1054,6 +1058,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
# For opaque handles, we store the pointer directly
|
||||
self.matched_var_binding[param] = v_ptr
|
||||
self.matched_var_source[param] = v_ptr
|
||||
self.matched_var_arg_field_name[param] = arg_context.get_field_name("")
|
||||
|
||||
return current_block
|
||||
|
||||
@@ -1191,8 +1196,10 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
var: Union[spec.Var, int],
|
||||
value: ir.Value,
|
||||
error_msg_context: list[str],
|
||||
arg_field_name: str,
|
||||
*,
|
||||
skip_check_predicate: Optional[ir.Value] = None,
|
||||
skip_cast_and_check: bool = False,
|
||||
) -> ir.Block:
|
||||
"""Set or check the matched var binding."""
|
||||
error_kind = "ValueError"
|
||||
@@ -1202,33 +1209,48 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
if isinstance(var, spec.Var):
|
||||
# if var contains llvm_value and is not populated, populate it
|
||||
if var not in self.matched_var_binding:
|
||||
current_block = self.check_int_value_dtype_bound(
|
||||
current_block, value, var.dtype, error_msg_context
|
||||
)
|
||||
# check divisibility if specified
|
||||
if var.divisibility is not None:
|
||||
current_block = self.check_int_value_divisibility(
|
||||
current_block, value, var.divisibility, error_msg_context,
|
||||
skip_check_predicate=skip_check_predicate,
|
||||
)
|
||||
# store the source value with parameter info
|
||||
with ir.InsertionPoint(current_block):
|
||||
self.matched_var_source[var] = value
|
||||
self.matched_var_binding[var] = self.downcast_i64_to_lower_bits(
|
||||
value, var.dtype
|
||||
if not skip_cast_and_check:
|
||||
current_block = self.check_int_value_dtype_bound(
|
||||
current_block, value, var.dtype, error_msg_context
|
||||
)
|
||||
# check divisibility if specified
|
||||
if var.divisibility is not None:
|
||||
current_block = self.check_int_value_divisibility(
|
||||
current_block, value, var.divisibility, error_msg_context,
|
||||
skip_check_predicate=skip_check_predicate,
|
||||
)
|
||||
# store the source value with parameter info
|
||||
with ir.InsertionPoint(current_block):
|
||||
target_value = self.downcast_i64_to_lower_bits(
|
||||
value, var.dtype
|
||||
)
|
||||
else:
|
||||
target_value = value
|
||||
# store the source value
|
||||
self.matched_var_source[var] = value
|
||||
# store the target value (casted to target dtype aleady)
|
||||
self.matched_var_binding[var] = target_value
|
||||
# store arg_field_name
|
||||
self.matched_var_arg_field_name[var] = arg_field_name
|
||||
return current_block
|
||||
# otherwise, it appears more than once, we need to check if the value matches
|
||||
expected_value = self.matched_var_source[var]
|
||||
prev_arg_field_name = self.matched_var_arg_field_name[var]
|
||||
error_msg_mismatch = [
|
||||
error_prefix_mismatch,
|
||||
*error_msg_context,
|
||||
", symbolic constraint violated"
|
||||
f", expected to match {prev_arg_field_name}",
|
||||
]
|
||||
else:
|
||||
assert isinstance(var, int)
|
||||
with ir.InsertionPoint(current_block):
|
||||
expected_value = self.i64(var)
|
||||
if not skip_cast_and_check:
|
||||
expected_value = self.i64(var)
|
||||
else:
|
||||
expected_value = self.downcast_i64_to_lower_bits(
|
||||
self.i64(var), var.dtype
|
||||
)
|
||||
|
||||
error_msg_mismatch = [
|
||||
error_prefix_mismatch,
|
||||
*error_msg_context,
|
||||
@@ -1261,6 +1283,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
) -> ir.Block:
|
||||
"""Load the shape value from the argument or match the shape value from the parameter."""
|
||||
field_name = arg_context.get_field_name(field_suffix)
|
||||
arg_field_name = f"{field_name}[{shape_index}]"
|
||||
error_msg = [
|
||||
field_name,
|
||||
f"[{shape_index}] ",
|
||||
@@ -1268,7 +1291,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
self._fn_call_context,
|
||||
]
|
||||
return self.set_or_check_matched_var_binding(
|
||||
current_block, var, value, error_msg, skip_check_predicate=skip_check_predicate
|
||||
current_block, var, value, error_msg, arg_field_name,
|
||||
skip_check_predicate=skip_check_predicate
|
||||
)
|
||||
|
||||
def decode_param_shape_from_ffi_array(
|
||||
@@ -1553,8 +1577,22 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
# store the matched values, these do not need constraint checks
|
||||
self.matched_var_binding[param.data] = data
|
||||
self.matched_var_source[param.data] = param.data
|
||||
self.matched_var_binding[param.device_id] = device_id
|
||||
self.matched_var_source[param.device_id] = param.device_id
|
||||
self.matched_var_arg_field_name[param.data] = arg_context.get_field_name(".data")
|
||||
|
||||
# check device_id constraint if user specifies a device_id variable
|
||||
current_block = self.set_or_check_matched_var_binding(
|
||||
current_block,
|
||||
param.device_id,
|
||||
device_id,
|
||||
[
|
||||
"device index ",
|
||||
*arg_context.get(),
|
||||
self._fn_call_context,
|
||||
],
|
||||
arg_context.get_field_name(".device.index"),
|
||||
skip_cast_and_check=True,
|
||||
)
|
||||
|
||||
# check ndim
|
||||
expected_ndim = len(param.shape)
|
||||
# Break error message into reusable parts for better string deduplication
|
||||
@@ -1683,7 +1721,8 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
"""Decode the stream parameter at the given index."""
|
||||
# stream is decoded as opaque handle
|
||||
return self.decode_param_opaque_handle(
|
||||
current_block, param.var, args, arg_index, arg_context
|
||||
current_block, param.var, args, arg_index, arg_context,
|
||||
allow_int_as_ptr=True
|
||||
)
|
||||
|
||||
def decode_param_data_pointer(
|
||||
@@ -1873,6 +1912,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder):
|
||||
)
|
||||
self.matched_var_binding[param.var] = env_stream
|
||||
self.matched_var_source[param.var] = env_stream
|
||||
self.matched_var_arg_field_name[param.var] = param.name
|
||||
|
||||
return current_block
|
||||
|
||||
|
||||
@@ -115,7 +115,9 @@ class ConverterContext:
|
||||
def __init__(self):
|
||||
self.num_dyn_shape_vars = 0
|
||||
self.num_dyn_stride_vars = 0
|
||||
self.num_device_id_vars = 0
|
||||
self.sym_int_id_mapping = {}
|
||||
self.vdevice_to_device_id_mapping = {}
|
||||
|
||||
def alloc_shape_name(self) -> str:
|
||||
"""Allocate a new dynamic shape variable name."""
|
||||
@@ -143,6 +145,25 @@ class ConverterContext:
|
||||
self.sym_int_id_mapping[sym_int_id] = var
|
||||
return var
|
||||
|
||||
def alloc_or_reuse_device_id(self, device_type: str, vdevice_id: int) -> Optional[spec.Var]:
|
||||
"""Allocate or reuse a device_id variable for a given virtual device.
|
||||
|
||||
This function returns None for CPU tensors.
|
||||
"""
|
||||
# Don't allocate device_id for CPU tensors
|
||||
if device_type == "cpu":
|
||||
return None
|
||||
|
||||
vdevice_key = (device_type, vdevice_id)
|
||||
if vdevice_key in self.vdevice_to_device_id_mapping:
|
||||
return self.vdevice_to_device_id_mapping[vdevice_key]
|
||||
|
||||
name = f"device_id{self.num_device_id_vars}"
|
||||
self.num_device_id_vars += 1
|
||||
device_id_var = spec.Var(name, "int32")
|
||||
self.vdevice_to_device_id_mapping[vdevice_key] = device_id_var
|
||||
return device_id_var
|
||||
|
||||
|
||||
def _convert_single_arg(
|
||||
arg,
|
||||
@@ -209,17 +230,28 @@ def _convert_single_arg(
|
||||
if hasattr(arg, "_tvm_ffi_tensor"):
|
||||
tvm_ffi_tensor = arg._tvm_ffi_tensor
|
||||
dtype = tvm_ffi_tensor.dtype
|
||||
device_type = tvm_ffi_tensor.device.type
|
||||
|
||||
# Allocate device_id (returns None for CPU tensors)
|
||||
vdevice_id = tvm_ffi_tensor.device.index
|
||||
device_id = ctx.alloc_or_reuse_device_id(device_type, vdevice_id)
|
||||
|
||||
tvm_ffi_cute_tensor = spec.Tensor(
|
||||
arg_name,
|
||||
shapes,
|
||||
arg._tvm_ffi_tensor.dtype,
|
||||
strides=strides,
|
||||
data_alignment=arg._assumed_align,
|
||||
device_type=tvm_ffi_tensor.device.type
|
||||
device_type=device_type,
|
||||
device_id=device_id
|
||||
)
|
||||
else:
|
||||
# for FakeTensor, strictly follow the shape and stride from the cute tensor
|
||||
device_type = "cuda" if _is_gpu_memspace(arg.memspace) else "cpu"
|
||||
# Allocate device_id (returns None for CPU tensors)
|
||||
vdevice_id = 0 # For now, use vdevice_id = 0 for all GPU tensors
|
||||
device_id = ctx.alloc_or_reuse_device_id(device_type, vdevice_id)
|
||||
|
||||
tvm_ffi_cute_tensor = spec.Tensor(
|
||||
arg_name,
|
||||
shapes,
|
||||
@@ -227,6 +259,7 @@ def _convert_single_arg(
|
||||
strides=strides,
|
||||
data_alignment=arg._assumed_align,
|
||||
device_type=device_type,
|
||||
device_id=device_id
|
||||
)
|
||||
if arg.element_type == Float4E2M1FN:
|
||||
tvm_ffi_cute_tensor = spec.create_map_tensor_dtype_f4x2_to_f4_spec(
|
||||
|
||||
@@ -515,6 +515,7 @@ class _FakeTensor(Tensor):
|
||||
when the dimension is dynamic.
|
||||
:type use_32bit_stride: bool, optional
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dtype, shape, *, stride, memspace=None, assumed_align=None):
|
||||
|
||||
@@ -55,6 +55,8 @@ class CudaDialectJitModule:
|
||||
for library in self.cuda_library:
|
||||
cuda_runtime.cudaLibraryUnload(library)
|
||||
self.cuda_library.clear()
|
||||
except Exception as e:
|
||||
pass
|
||||
finally:
|
||||
self._unloaded = True
|
||||
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cutlass_dsl import T, dsl_user_op
|
||||
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import llvm, nvvm
|
||||
from cutlass._mlir.dialects.nvvm import (
|
||||
MemOrderKind,
|
||||
MemScopeKind,
|
||||
AtomicOpKind,
|
||||
)
|
||||
from cutlass.cute.typing import Pointer, Int32
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def atomicAdd(dst_ptr: Pointer, val: Int32, loc=None, ip=None) -> Int32:
|
||||
return nvvm.atomicrmw(
|
||||
T.i32(),
|
||||
AtomicOpKind.ADD,
|
||||
dst_ptr.llvm_ptr,
|
||||
val.ir_value(loc=loc, ip=ip),
|
||||
mem_order=MemOrderKind.RELAXED,
|
||||
syncscope=MemScopeKind.SYS,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def ld_bypass(input_tensor: cute.Tensor):
|
||||
fragment = cute.make_rmem_tensor(input_tensor.layout, input_tensor.element_type)
|
||||
copy_atom_load = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
input_tensor.element_type,
|
||||
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
|
||||
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
|
||||
)
|
||||
cute.copy_atom_call(copy_atom_load, input_tensor, fragment)
|
||||
vals = fragment.load()
|
||||
return vals
|
||||
|
||||
|
||||
@cute.jit
|
||||
def spin_lock_wait(
|
||||
lock_ptr: Pointer,
|
||||
expect_count: Int32,
|
||||
mem_order: str = "relaxed",
|
||||
mem_scope: str = "gpu",
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
"""
|
||||
wait on a spin lock until the expected count is reached.
|
||||
"""
|
||||
res = 0
|
||||
while res != expect_count:
|
||||
res = nvvm.atomicrmw(
|
||||
T.i32(),
|
||||
AtomicOpKind.CAS,
|
||||
lock_ptr.llvm_ptr,
|
||||
Int32(0).ir_value(loc=loc, ip=ip),
|
||||
b=Int32(expect_count).ir_value(loc=loc, ip=ip),
|
||||
mem_order=(
|
||||
MemOrderKind.ACQUIRE if mem_order == "acquire" else MemOrderKind.RELAXED
|
||||
),
|
||||
syncscope=MemScopeKind.GPU if mem_scope == "gpu" else MemScopeKind.SYS,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def multimem_red_add_sys_release(mc_ptr: Pointer, loc=None, ip=None) -> None:
|
||||
"""
|
||||
add 1 to the multimem address
|
||||
"""
|
||||
llvm.inline_asm(
|
||||
None,
|
||||
[mc_ptr.toint().ir_value(loc=loc, ip=ip)],
|
||||
"multimem.red.release.sys.global.add.u32 [$0], 1;",
|
||||
"l",
|
||||
has_side_effects=True,
|
||||
asm_dialect=0,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def multimem_red_add_gpu_relaxed(mc_ptr: Pointer, loc=None, ip=None) -> None:
|
||||
"""
|
||||
add 1 to the multimem address
|
||||
"""
|
||||
llvm.inline_asm(
|
||||
None,
|
||||
[mc_ptr.toint().ir_value(loc=loc, ip=ip)],
|
||||
"multimem.red.relaxed.gpu.global.add.u32 [$0], 1;",
|
||||
"l",
|
||||
has_side_effects=True,
|
||||
asm_dialect=0,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None:
|
||||
"""
|
||||
arrive a spin lock when the lock_ptr is a multimem address.
|
||||
"""
|
||||
multimem_red_add_gpu_relaxed(lock_ptr, loc=loc, ip=ip)
|
||||
|
||||
|
||||
def sm_wise_inter_gpu_multimem_barrier(
|
||||
barrier: Pointer, barrier_mc: Pointer, num_ranks, loc=None, ip=None
|
||||
) -> None:
|
||||
"""
|
||||
barrier for inter-gpu sm-wise
|
||||
"""
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
bdimx, bdimy, _ = cute.arch.grid_dim()
|
||||
pid = bidx + bidy * bdimx + bidz * bdimx * bdimy
|
||||
multimem_red_add_sys_release(barrier_mc + pid, loc=loc, ip=ip)
|
||||
cute.arch.fence_proxy(cute.arch.ProxyKind.alias)
|
||||
spin_lock_wait(
|
||||
barrier + pid, num_ranks, mem_order="acquire", mem_scope="sys", loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def multimem_ld_reduce_base(
|
||||
mc_ptr: Pointer,
|
||||
*,
|
||||
ptx_string: str = "",
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[Int32, Int32, Int32, Int32]:
|
||||
# ld reduce 8xf16 elts
|
||||
mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip)
|
||||
return_struct = llvm.inline_asm(
|
||||
ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"),
|
||||
[mc_ptr_int],
|
||||
ptx_string,
|
||||
"=r,=r,=r,=r,l",
|
||||
has_side_effects=True,
|
||||
asm_dialect=0,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return_regs = [llvm.extractvalue(T.i32(), return_struct, [i]) for i in range(4)]
|
||||
return return_regs[0], return_regs[1], return_regs[2], return_regs[3]
|
||||
|
||||
|
||||
multimem_ld_reduce_8xf16 = partial(
|
||||
multimem_ld_reduce_base,
|
||||
ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];",
|
||||
)
|
||||
multimem_ld_reduce_4xf32 = partial(
|
||||
multimem_ld_reduce_base,
|
||||
ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];",
|
||||
)
|
||||
multimem_ld_reduce_8xbf16 = partial(
|
||||
multimem_ld_reduce_base,
|
||||
ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];",
|
||||
)
|
||||
multimem_ld_reduce_16xe4m3 = partial(
|
||||
multimem_ld_reduce_base,
|
||||
ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];",
|
||||
)
|
||||
multimem_ld_reduce_16xe5m2 = partial(
|
||||
multimem_ld_reduce_base,
|
||||
ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];",
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def multimem_st_4xb32(
|
||||
mc_ptr: Pointer,
|
||||
x: Int32,
|
||||
y: Int32,
|
||||
z: Int32,
|
||||
w: Int32,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
# st 4x32 bits of data
|
||||
mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip)
|
||||
llvm.inline_asm(
|
||||
T.i32(),
|
||||
[mc_ptr_int, x, y, z, w],
|
||||
"multimem.st.sys.relaxed.global.v4.f32 [$1], {$2, $3, $4, $5};",
|
||||
"=r,l,r,r,r,r",
|
||||
has_side_effects=True,
|
||||
asm_dialect=0,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
@@ -14,10 +14,12 @@ import inspect
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size
|
||||
from cutlass.cutlass_dsl import CutlassBaseDSL, Int8, Numeric, NumericMeta, dsl_user_op
|
||||
from cutlass.cutlass_dsl import CuTeDSL, Int8, Numeric, NumericMeta, dsl_user_op
|
||||
|
||||
|
||||
SMEM_CAPACITY_MAP = {
|
||||
"sm_120": (100 - 1) * 1024,
|
||||
"sm_103": (228 - 1) * 1024,
|
||||
"sm_100": (228 - 1) * 1024,
|
||||
"sm_90": (228 - 1) * 1024,
|
||||
"sm_80": (164 - 1) * 1024,
|
||||
@@ -71,7 +73,7 @@ class SmemAllocator:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def capacity_in_bytes(compute_capability: str) -> int:
|
||||
def capacity_in_bytes(compute_capability: Optional[str] = None) -> int:
|
||||
"""Get the shared memory capacity in bytes for a given compute capability.
|
||||
|
||||
Returns the maximum shared memory capacity in bytes available for the specified
|
||||
@@ -83,6 +85,9 @@ class SmemAllocator:
|
||||
:rtype: int
|
||||
:raises ValueError: If the compute capability is not supported
|
||||
"""
|
||||
if compute_capability is None:
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
compute_capability = f"sm_{arch.major}{arch.minor}"
|
||||
if compute_capability not in SMEM_CAPACITY_MAP:
|
||||
raise ValueError(f"Unsupported compute capability: {compute_capability}")
|
||||
return SMEM_CAPACITY_MAP[compute_capability]
|
||||
@@ -101,7 +106,7 @@ class SmemAllocator:
|
||||
"""
|
||||
self._base = get_dyn_smem(Int8, alignment=1024, loc=loc, ip=ip)
|
||||
self._allocated_bytes = 0
|
||||
CutlassBaseDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes)
|
||||
CuTeDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes)
|
||||
|
||||
@overload
|
||||
def allocate(
|
||||
|
||||
@@ -133,7 +133,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '4.3.1'
|
||||
this.__version__ = '4.3.2'
|
||||
|
||||
from cutlass_cppgen.backend import create_memory_pool
|
||||
from cutlass_cppgen.emit.pytorch import pytorch
|
||||
|
||||
@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
|
||||
|
||||
setup(
|
||||
name='cutlass_cppgen',
|
||||
version='4.3.1',
|
||||
version='4.3.2',
|
||||
description='CUTLASS Pythonic Interface',
|
||||
package_dir={'': '.'},
|
||||
packages=[
|
||||
|
||||
@@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='pycute',
|
||||
version='4.3.1',
|
||||
version='4.3.2',
|
||||
description='Python implementation of CuTe',
|
||||
packages=['pycute'],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user