v4.0 update. (#2371)

This commit is contained in:
Junkai-Wu
2025-06-06 14:39:20 +08:00
committed by GitHub
parent 2e2af190bd
commit 8bdbfca682
254 changed files with 29751 additions and 1980 deletions

View File

@@ -578,7 +578,7 @@ class ArithValue(ir.Value):
# Unary operators
def __invert__(self, *, loc=None, ip=None) -> "ArithValue":
return arith.xori(self, arith.const(self.type, -1))
return arith.xori(self, arith.constant(self.type, -1))
# Bitwise operations
@_dispatch_to_rhs_r_op

View File

@@ -95,7 +95,7 @@ class Executor:
unroll=bool,
unroll_full=int,
):
log().info("start [%s] stop [%s] step [%s]", start, stop, step)
log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
return self._loop_execute_range_dynamic(
func,
start,
@@ -117,7 +117,7 @@ class Executor:
used_args: list,
iter_args: list,
):
log().info("start [%s] stop [%s] step [%s]", start, stop, step)
log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
loop_results = iter_args
log().debug("iter_args [%s]", iter_args)
for i in range(start, stop, step):
@@ -374,7 +374,7 @@ def loop_selector(
unroll_full=False,
constexpr=None,
):
log().info(
log().debug(
"start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]",
start,
stop,
@@ -415,7 +415,7 @@ def loop_selector(
def if_selector(pred, used_args=[], yield_args=[]):
log().info("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
# Handle Numeric types here?
from .typing import Numeric

View File

@@ -248,39 +248,55 @@ class DSLPreprocessor(ast.NodeTransformer):
# Step 3. Return the transformed tree
return combined_body
def check_early_exit(self, tree):
def check_early_exit(self, tree, kind):
"""
Checks if a given region or scope in the provided Python code has early exits.
"""
class EarlyExitChecker(ast.NodeVisitor):
def __init__(self):
def __init__(self, kind):
self.has_early_exit = False
self.early_exit_node = None
self.early_exit_type = None
self.kind = kind
self.loop_nest_level = 0
# Early exit is not allowed in any level of dynamic control flow
def visit_Return(self, node):
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "return"
def visit_Break(self, node):
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "break"
def visit_Continue(self, node):
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "continue"
def visit_Raise(self, node):
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "raise"
checker = EarlyExitChecker()
checker.visit(tree)
def visit_Break(self, node):
# For break/continue in inner loops, we don't consider it as early exit
if self.loop_nest_level == 0 and self.kind != "if":
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "break"
def visit_Continue(self, node):
if self.loop_nest_level == 0 and self.kind != "if":
self.has_early_exit = True
self.early_exit_node = node
self.early_exit_type = "continue"
def visit_For(self, node):
self.loop_nest_level += 1
self.generic_visit(node)
self.loop_nest_level -= 1
def visit_While(self, node):
self.loop_nest_level += 1
self.generic_visit(node)
self.loop_nest_level -= 1
checker = EarlyExitChecker(kind)
checker.generic_visit(tree)
if not checker.has_early_exit:
return
raise DSLAstPreprocessorError(
@@ -591,7 +607,7 @@ class DSLPreprocessor(ast.NodeTransformer):
if self.is_supported_range_call(node):
constexpr_val = self.get_loop_constexpr(node)
# Check for early exit and raise exception
self.check_early_exit(node)
self.check_early_exit(node, "for")
start, stop, step = self.extract_range_args(node.iter)
unroll, unroll_full = self.extract_unroll_args(node.iter)
used_args, iter_args, flat_args = self.analyze_region_variables(
@@ -659,37 +675,42 @@ class DSLPreprocessor(ast.NodeTransformer):
snippet=ast.unparse(node),
)
test = ast.BoolOp(
op=ast.And(),
values=[
ast.Compare(
left=ast.Call(
func=ast.Name(id="type", ctx=ast.Load()),
args=[node.values[0]],
keywords=[],
def short_circuit_eval(value, short_circuit_value):
return ast.BoolOp(
op=ast.And(),
values=[
ast.Compare(
left=ast.Call(
func=ast.Name(id="type", ctx=ast.Load()),
args=[value],
keywords=[],
),
ops=[ast.Eq()],
comparators=[ast.Name(id="bool", ctx=ast.Load())],
),
ops=[ast.Eq()],
comparators=[ast.Name(id="bool", ctx=ast.Load())],
),
ast.Compare(
left=node.values[0],
ops=[ast.Eq()],
comparators=[short_circuit_value],
),
],
)
return ast.copy_location(
ast.IfExp(
ast.Compare(
left=value,
ops=[ast.Eq()],
comparators=[short_circuit_value],
),
],
)
lhs = node.values[0]
for i in range(1, len(node.values)):
test = short_circuit_eval(lhs, short_circuit_value)
lhs = ast.IfExp(
test=test,
body=node.values[0],
body=lhs,
orelse=ast.Call(
func=helper_func,
args=node.values,
args=[lhs, node.values[i]],
keywords=[],
),
),
node,
)
)
return ast.copy_location(lhs, node)
def visit_UnaryOp(self, node):
# Visit child nodes first
@@ -916,7 +937,7 @@ class DSLPreprocessor(ast.NodeTransformer):
return node
# Check for early exit and raise exception
self.check_early_exit(node)
self.check_early_exit(node, "while")
used_args, yield_args, flat_args = self.analyze_region_variables(
node, active_symbols
@@ -1021,7 +1042,7 @@ class DSLPreprocessor(ast.NodeTransformer):
return node
# Check for early exit and raise exception
self.check_early_exit(node)
self.check_early_exit(node, "if")
used_args, yield_args, flat_args = self.analyze_region_variables(
node, active_symbols

View File

@@ -566,7 +566,9 @@ class BaseDSL:
log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, arg_spec)
# Implicit cast to NumericMeta
if isinstance(arg_spec, t.NumericMeta):
if isinstance(arg_spec, t.NumericMeta) and not isinstance(
arg, arg_spec
):
arg = t.cast(arg, arg_spec)
ir_arg, iv_block_args = (
@@ -589,15 +591,17 @@ class BaseDSL:
self.log_additions(ir_arg)
ir_args.extend(ir_arg)
return ir_args
return ir_args, iv_block_args
fop_args = list(fop.regions[0].blocks[0].arguments)
ir_args = gen_exec_args(args, args_spec.args, args_spec.annotations, fop_args)
ir_kwargs = gen_exec_args(
ir_args, iv_block_args = gen_exec_args(
args, args_spec.args, args_spec.annotations, fop_args
)
ir_kwargs, _ = gen_exec_args(
[kwargs[arg] for arg in args_spec.kwonlyargs],
args_spec.kwonlyargs,
args_spec.annotations,
fop_args[len(ir_args) :],
fop_args[iv_block_args:],
)
ir_kwargs = {k: v for k, v in zip(args_spec.kwonlyargs, ir_kwargs)}
@@ -716,8 +720,10 @@ class BaseDSL:
assert len(args) == len(args_spec.args) and len(kwargs) == len(
args_spec.kwonlyargs
), f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args "
f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}"
), (
f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args "
f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}"
)
jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], []
default_attr = ir.DictAttr.get({})
@@ -729,7 +735,7 @@ class BaseDSL:
log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, spec_ty)
# Implicitly convert into Numeric type if possible
if isinstance(spec_ty, t.NumericMeta):
if isinstance(spec_ty, t.NumericMeta) and not isinstance(arg, spec_ty):
arg = t.cast(arg, spec_ty)
# Type safety check

View File

@@ -141,33 +141,60 @@ class JitExecutor:
to get rid of mlir context.
"""
# Process positional arguments with defaults
rectified_args = list(args)
if args_spec.defaults and len(args) < len(args_spec.args):
rectified_args.extend(args_spec.defaults[len(args) - len(args_spec.args) :])
for k, v in kwargs.items():
if k in args_spec.args:
idx = args_spec.args.index(k)
if idx < len(rectified_args):
rectified_args[idx] = v
else:
rectified_args.append(v)
# Process keyword arguments
rectified_kwargs = {k: v for k, v in kwargs.items() if k not in args_spec.args}
if args_spec.kwonlydefaults and len(rectified_kwargs) < len(
args_spec.kwonlyargs
):
rectified_kwargs.update(args_spec.kwonlydefaults)
# args/kwargs must match arg_specs
# No canonicalization of args/kwargs to avoid extra latency
if len(args) != len(args_spec.args) or len(kwargs) != len(args_spec.kwonlyargs):
if len(rectified_args) != len(args_spec.args) or len(rectified_kwargs) != len(
args_spec.kwonlyargs
):
raise DSLRuntimeError(
"input args/kwargs length does not match runtime function signature!",
context={
"input args length": len(args),
"input kwargs length": len(kwargs),
"input args length": len(rectified_args),
"input kwargs length": len(rectified_kwargs),
"function signature args length": len(args_spec.args),
"function signature kwonlyargs length": len(args_spec.kwonlyargs),
},
)
exe_args = []
input_args = [*args, *kwargs.values()]
input_arg_names = [*args_spec.args, *args_spec.kwonlyargs]
for i, arg in enumerate(input_args):
arg_type = args_spec.annotations.get(input_arg_names[i], None)
input_args = rectified_args + list(rectified_kwargs.values())
input_arg_names = args_spec.args + args_spec.kwonlyargs
for arg, arg_name in zip(input_args, input_arg_names):
# short-cut for args already converted
if hasattr(arg, "__c_pointers__"):
exe_args.extend(arg.__c_pointers__())
continue
arg_type = args_spec.annotations.get(arg_name, None)
# Implicit cast to NumericMeta
if isinstance(arg_type, t.NumericMeta):
arg = t.cast(arg, arg_type)
else:
# If not any known type, try registered adapter to do the conversion
adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
if adapter:
arg = adapter(arg)
# If not any known type, try registered adapter to do the conversion
adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
adapted_arg = adapter(arg) if adapter else arg
exe_args.extend(get_c_pointers(adapted_arg))
exe_args.extend(get_c_pointers(arg))
return exe_args

View File

@@ -457,7 +457,7 @@ class StreamAdapter:
def __init__(self, arg):
self._arg = arg
self._c_pointer = ctypes.cast(self._arg.getPtr(), ctypes.c_void_p)
self._c_pointer = self._arg.getPtr()
def __new_from_mlir_values__(self, values):
assert len(values) == 1

View File

@@ -629,7 +629,7 @@ def _binary_op_type_promote(a, b, promote_bool: bool = False):
b_type = b.dtype
# Early return for same types (except when they're bools that need promotion)
if a_type == b_type and not (promote_bool and a_type.width == 1):
if a_type == b_type and not (promote_bool and a_type is Boolean):
return a, b, a_type
# Handle floating point promotions
@@ -1315,10 +1315,7 @@ class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True)
def __invert__(self, *, loc=None, ip=None):
res_type = type(self)
# Create a constant of -1 (all bits set to 1) of the same type as value
all_ones = arith.constant(res_type.mlir_type, -1)
# XOR with -1 gives us bitwise NOT
return res_type(arith.xori(self.ir_value(), all_ones, loc=loc, ip=ip))
return res_type(self.ir_value(loc=loc, ip=ip).__invert__(loc=loc, ip=ip))
def __lshift__(self, other, *, loc=None, ip=None):
return _binary_op(operator.lshift)(self, other, loc=loc, ip=ip)
@@ -1457,18 +1454,14 @@ class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T.
- Converted using Python's bool() function
- Example: Boolean(1) -> True, Boolean(0) -> False
2. Boolean:
- Direct value assignment
- Example: Boolean(Boolean(True)) -> True
2. Numeric:
- Uses the Numeric.value to construct Boolean recursively
3. Numeric:
- Uses the __dsl_bool__ method of the Numeric type
4. MLIR Value with IntegerType:
3. MLIR Value with IntegerType:
- If width is 1: Direct assignment
- Otherwise: Compares with 0 using arith.cmpi
5. MLIR Value with FloatType:
4. MLIR Value with FloatType:
- Compares with 0.0 using arith.cmpf
- Uses unordered comparison to handle NaN values
"""
@@ -1479,19 +1472,35 @@ class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T.
value = None
if isinstance(a, (bool, int, float)):
value = bool(a)
elif isinstance(a, Boolean):
value = a.value
elif isinstance(a, Numeric):
value = a.__dsl_bool__(loc=loc, ip=ip)
Boolean.__init__(self, a.value, loc=loc, ip=ip)
return
elif isinstance(a, ArithValue):
if a.type == T.bool():
value = a
else:
value = a != arith_helper.const(0, a.type)
value = a != arith_helper.const(0, a.type, loc=loc, ip=ip)
if value is None:
raise DSLRuntimeError(f"Cannot convert {a} to Boolean")
super().__init__(value, loc=loc, ip=ip)
self._value_int8 = None
def ir_value_int8(self, *, loc=None, ip=None):
"""
Returns int8 ir value of Boolean.
When we need to store Boolean tensor element, use ir_value_int8().
:param loc: Source location information, defaults to None
:type loc: Optional[Location], optional
:param ip: Insertion point for MLIR operations, defaults to None
:type ip: Optional[InsertionPoint], optional
:return: The int8 value of this Boolean
:rtype: ir.Value
"""
if self._value_int8 is not None:
return self._value_int8
self._value_int8 = Int8(self.value, loc=loc, ip=ip).ir_value()
return self._value_int8
def __neg__(self, *, loc=None, ip=None):
"""Negation operator is not supported for boolean type.

View File

@@ -37,6 +37,7 @@ from cutlass.cutlass_dsl import (
)
from cutlass._mlir import ir
from cutlass._mlir.dialects._ods_common import get_op_result_or_op_results
from cutlass._mlir.dialects import cute as _cute_ir
from cutlass._mlir.dialects.cute import (
ScaledBasis as _ScaledBasis,
@@ -962,6 +963,9 @@ class _Pointer(Pointer):
# Cut off the MLIR type's string for making pretty_str more concise
return self.type.__str__()[6:]
def __get_mlir_types__(self):
return [self.value.type]
def __extract_mlir_values__(self):
return [self.value]
@@ -979,7 +983,7 @@ class _Pointer(Pointer):
@property
@lru_cache_ir()
def value_type(self) -> Type[Numeric]:
def dtype(self) -> Type[Numeric]:
return Numeric.from_mlir_type(self.value.type.value_type)
@property
@@ -993,7 +997,7 @@ class _Pointer(Pointer):
@property
@lru_cache_ir()
def memspace(self) -> AddressSpace:
return self.type.address_space
return AddressSpace(self.type.address_space)
# Make it behave as if it inherited from ir.Value
@property
@@ -1015,7 +1019,7 @@ class _Pointer(Pointer):
:return: The LLVM pointer representation
:rtype: ir.Value
"""
llvm_ptr_ty = llvm.PointerType.get(self.type.address_space)
llvm_ptr_ty = llvm.PointerType.get(self.memspace.value)
return builtin.unrealized_conversion_cast(
[llvm_ptr_ty], [self.value], loc=loc, ip=ip
)
@@ -1034,10 +1038,7 @@ class _Pointer(Pointer):
@dsl_user_op
def toint(self, *, loc=None, ip=None):
if self.type.address_space in (
_cute_ir.AddressSpace.gmem,
_cute_ir.AddressSpace.generic,
):
if self.memspace in (AddressSpace.gmem, AddressSpace.generic):
res_type = Int64
else:
res_type = Int32
@@ -1067,25 +1068,26 @@ class _Pointer(Pointer):
raise ValueError("Alignment must be a power of 2")
assert isinstance(self.type, _cute_ir.PtrType)
if self.type.address_space is AddressSpace.tmem:
if self.memspace is AddressSpace.tmem:
raise ValueError("aligning a TMEM pointer is not supported")
if min_align <= self.alignment:
return self
else:
# Convert pointer to integer
address_int = self.toint(loc=loc, ip=ip)
# Align the address
aligned_address = (address_int + min_align - 1) & ~(min_align - 1)
# Create and return the aligned pointer
return make_ptr(
Numeric.from_mlir_type(self.type.value_type),
aligned_address,
self.type.address_space,
assumed_align=min_align,
loc=loc,
ip=ip,
)
dtype = Numeric.from_mlir_type(self.type.value_type)
# Convert pointer to integer
address_int = self.toint(loc=loc, ip=ip)
# Align the address
aligned_address = (address_int + min_align - 1) & ~(min_align - 1)
return make_ptr(
dtype,
aligned_address,
self.memspace,
assumed_align=min_align,
loc=loc,
ip=ip,
)
@ir.register_value_caster(_cute_ir.MemRefType.get_static_typeid(), replace=True)
@@ -1138,8 +1140,34 @@ class _Tensor(Tensor):
self._dtype = dtype
if isinstance(value, ir.Value):
self.value = value
elif isinstance(value, _Tensor):
self.value = value.value
else:
raise TypeError(f"Expected ir.Value, got {type(value)}")
raise TypeError(f"Expected ir.Value or core._Tensor, got {type(value)}")
# Set iterator
iter_val = _cute_ir.get_iter(self.value)
if isinstance(iter_val, Pointer):
self._iterator = iter_val
elif isinstance(iter_val.type, _cute_ir.IntTupleType):
self._iterator = _unpack_x_tuple(iter_val)
elif isinstance(iter_val, ir.Value):
# Example: SMEM descriptor iterator, not well supported today
self._iterator = iter_val
else:
raise TypeError(f"unsupported iterator type, got {type(iter_val)}")
# Set dtype
if self._dtype is None:
if is_int_tuple(self.iterator):
self._dtype = IntTuple
elif isinstance(self.iterator, Pointer):
self._dtype = self.iterator.value_type
elif isinstance(self.type, _cute_nvgpu_ir.SmemDescViewType):
# SmemDescViewType do not need dtype
self._dtype = None
else:
raise TypeError(f"unsupported iterator type, got {type(self.iterator)}")
def __str__(self):
return f"tensor<{pretty_str(self.iterator)} o {pretty_str(self.layout)}>"
@@ -1157,7 +1185,7 @@ class _Tensor(Tensor):
), f"Expected _Tensor or ir.Value, but got {type(values[0])}"
return _Tensor(
values[0] if isinstance(values[0], ir.Value) else values[0].value,
self._dtype,
dtype=self.element_type,
)
# Cheat to let `Type(_Tensor())` to return cute.Tensor
@@ -1252,9 +1280,6 @@ class _Tensor(Tensor):
return self.element_type(data_val)
def _cvt_to_dest(self, data: Union["TensorSSA", Numeric], *, loc=None, ip=None):
if data.dtype is self.element_type:
return data.ir_value(loc=loc, ip=ip)
orig_dtype = data.dtype
# Implicit upcast to wider type
if (
@@ -1269,11 +1294,11 @@ class _Tensor(Tensor):
f"to Tensor with element type {self.element_type}"
)
val = data.ir_value(loc=loc, ip=ip)
if isinstance(data.dtype, (Int8, Boolean)) and (self.element_type is Boolean):
zero = Int8(0).ir_value(loc=loc, ip=ip)
val = arith.cmpi(arith.CmpIPredicate.ne, val, zero, loc=loc, ip=ip)
if data.dtype is Boolean and self.element_type is Boolean:
# Boolean Numeric and Boolean TensorSSA both hold i1 value, but we need int8 value store to memory
val = data.ir_value_int8()
else:
val = data.ir_value()
return val
@dsl_user_op
@@ -1340,7 +1365,7 @@ class _Tensor(Tensor):
# Implicit upcast to wider type
val = self._cvt_to_dest(data, loc=loc, ip=ip)
if val.type != self.element_type.mlir_type:
if val.type != self.type.value_type:
raise ValueError(
f"type mismatch, store {val.type} to {self.element_type}"
)
@@ -1365,16 +1390,7 @@ class _Tensor(Tensor):
@property
def iterator(self) -> Union[Pointer, IntTuple]:
res = _cute_ir.get_iter(self.value)
if isinstance(res, Pointer):
return res
elif isinstance(res.type, _cute_ir.IntTupleType):
return _unpack_x_tuple(res)
elif isinstance(res, ir.Value):
# Example: SMEM descriptor iterator, not well supported today
return res
else:
raise TypeError(f"unsupported iterator type, got {type(res)}")
return self._iterator
@property
def layout(self) -> Layout:
@@ -1405,12 +1421,7 @@ class _Tensor(Tensor):
@property
@lru_cache_ir()
def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]:
if is_integer(self.iterator) or isinstance(self.iterator, tuple):
return IntTuple
elif isinstance(self.iterator, Pointer):
return self.iterator.value_type
else:
raise TypeError(f"unsupported iterator type, got {type(self.iterator)}")
return self._dtype
@property
@lru_cache_ir()
@@ -1443,7 +1454,14 @@ class _Tensor(Tensor):
self._check_can_load_store()
res_vect = _cute_ir.memref_load_vec(self.value, row_major=True, loc=loc, ip=ip)
if self.element_type is Boolean:
assert (
res_vect.type.element_type == T.i8()
), f"Boolean tensor must be stored as i8 in memory, but got {res_vect.type.element_type}"
zeros = full_like(self, 0, Int8, loc=loc, ip=ip)
res_vect = arith.cmpi(
arith.CmpIPredicate.ne, res_vect, zeros, loc=loc, ip=ip
)
return TensorSSA(res_vect, self.shape, self.element_type)
@dsl_user_op
@@ -1532,9 +1550,7 @@ class _Tensor(Tensor):
self[None] = full(self.shape, fill_value=value, dtype=dst_type, loc=loc, ip=ip)
def _check_can_load_store(self):
if not isinstance(
self.type, _cute_ir.MemRefType
) or not self.type.address_space in (
if not isinstance(self.type, _cute_ir.MemRefType) or not self.memspace in (
AddressSpace.rmem,
AddressSpace.smem,
AddressSpace.gmem,
@@ -1734,10 +1750,6 @@ def printf(*args, loc=None, ip=None) -> None:
arg0 = arg.value if isinstance(arg, Numeric) else arg
if isinstance(arg0, ir.Value):
if isinstance(arg0.type, ir.FloatType) and (arg0.type != T.f32()):
raise TypeError(
f"cute.printf only supports 32-bit floating-point type, but got {arg0.type}"
)
return arg0
elif isinstance(arg0, bool):
return const(arg0, Boolean)
@@ -2212,11 +2224,13 @@ def group_modes(input, begin: int, end: int = -1, *, loc=None, ip=None):
shape = make_shape(2, 3, 4, 5)
grouped_shape = group_modes(shape, 0, 2) # Shape ((2, 3), 4, 5)
"""
if depth(input) == 0:
if depth(input) == 0 and is_integer(input):
return (input,)
if isinstance(input, tuple):
return (*input[:begin], (input[begin:end]), *input[end:])
return _cute_ir.group_modes(input.value, begin, end, loc=loc, ip=ip)
return _cute_ir.group_modes(
input.value if isinstance(input, Tensor) else input, begin, end, loc=loc, ip=ip
)
@overload
@@ -2315,10 +2329,13 @@ def slice_(src, coord: Coord, *, loc=None, ip=None):
else:
return ()
res_type = None
if isinstance(src, Tensor):
res_type = src.element_type
src = src.value
coord_val = _pack_coord(coord, loc=loc, ip=ip)
return _cute_ir.slice(input=src, coord=coord_val, loc=loc, ip=ip)
res = _cute_ir.slice(input=src, coord=coord_val, loc=loc, ip=ip)
return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res
@overload
@@ -2751,7 +2768,8 @@ def filter_zeros(input, *, target_profile=None, loc=None, ip=None):
"""Filter out zeros from a layout or tensor.
This function removes zero-stride dimensions from a layout or tensor.
See Section 3.3 in the CuTe Whitepaper for more details on layout operations.
Refer to https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md
for more layout algebra operations.
:param input: The input layout or tensor to filter
:type input: Layout or Tensor
@@ -2913,7 +2931,8 @@ def size(
Computes the size (number of elements) in the domain of a layout or tensor.
For layouts, this corresponds to the shape of the coordinate space.
See Section 3.2 in the CuTe Whitepaper for more details on layout domains.
See https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/01_layout.md
for more details on layout domains.
:param a: The input object whose size to compute
:type a: IntTuple, Shape, Layout, ComposedLayout or Tensor
@@ -3177,7 +3196,7 @@ def make_composed_layout(
) -> ComposedLayout:
"""Create a composed layout by composing an inner transformation with an outer layout.
As described in the CuTe whitepaper, a composed layout applies a sequence of transformations
A composed layout applies a sequence of transformations
to coordinates. The composition is defined as (inner ∘ offset ∘ outer), where the operations
are applied from right to left.
@@ -3416,12 +3435,7 @@ def recast_ptr(
value_type = ptr.type.value_type if dtype is None else dtype
swizzle = swizzle_.type.attribute if swizzle_ is not None else None
res_ty = _cute_ir.PtrType.get(
value_type,
AddressSpace(ptr.type.address_space),
ptr.alignment,
swizzle,
)
res_ty = _cute_ir.PtrType.get(value_type, ptr.memspace, ptr.alignment, swizzle)
return _cute_ir.recast_iter(res_ty, ptr.value, loc=loc, ip=ip)
@@ -3438,8 +3452,15 @@ def make_ptr(
if dtype is None or not isinstance(dtype, NumericMeta):
raise TypeError(f"expects dtype to be a type of Numeric, but got {dtype}")
if not isinstance(mem_space, AddressSpace):
raise TypeError(f"expects mem_space to be an AddressSpace, but got {mem_space}")
if isinstance(value, ir.Value) and llvm.PointerType.isinstance(value.type):
value = llvm.ptrtoint(T.i64(), value)
if not is_integer(value):
raise TypeError(f"expects integer value, but got {type(value)}")
value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value)
bytes_per_elt = max(1, dtype.width // 8)
if assumed_align is None:
@@ -3450,13 +3471,11 @@ def make_ptr(
f"{bytes_per_elt=} is not a multiple of {assumed_align=} and vice versa."
)
value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value)
aligned_ty = _cute_ir.ConstrainedIntType.get(assumed_align, type(value).width)
aligned_intptr = _cute_ir.assume(aligned_ty, value.ir_value(), loc=loc, ip=ip)
ptr_ty = _cute_ir.PtrType.get(
T.i8() if dtype is None else dtype.mlir_type, mem_space, assumed_align
)
data_ty = T.i8() if dtype is None else dtype.mlir_type
ptr_ty = _cute_ir.PtrType.get(data_ty, mem_space, assumed_align)
return _cute_ir.inttoptr(ptr_ty, aligned_intptr, loc=loc, ip=ip)
@@ -3582,7 +3601,7 @@ def make_fragment(
) -> Tensor:
if not issubclass(dtype, Numeric):
raise TypeError(f"value_type must be a type of Numeric, but got {type(dtype)}")
elem_ty = dtype.mlir_type
elem_ty = dtype.mlir_type if dtype is not Boolean else T.i8()
# Alignment for register memory is useless(?), pick-up large enough number
# to allow .128 (> 16B) load store
@@ -3691,16 +3710,12 @@ def make_fragment_like(src, dtype=None, *, loc=None, ip=None):
)
return make_fragment(new_layout, dtype, loc=loc, ip=ip)
else:
if dtype is None:
ty = src.element_type.mlir_type
else:
ty = dtype.mlir_type
dtype = src.element_type if dtype is None else dtype
ty = dtype.mlir_type if dtype is not Boolean else T.i8()
new_tensor = _cute_ir.make_fragment_like(
src.value, elem_type=ty, loc=loc, ip=ip
)
return _Tensor(
new_tensor.value, dtype if dtype is not None else src.element_type
)
return _Tensor(new_tensor.value, dtype)
else:
raise TypeError(
f"src must be a Layout or ComposedLayout or tensor, got {type(src)}"
@@ -3958,11 +3973,14 @@ def logical_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor
@dsl_user_op
def logical_divide(target, tiler: Tiler, *, loc=None, ip=None):
res_type = None
if isinstance(target, _Tensor):
res_type = target.element_type
target = target.value
if isinstance(tiler, tuple):
tiler = _pack_tile(tiler, loc=loc, ip=ip)
return _cute_ir.logical_divide(input=target, tiler=tiler, loc=loc, ip=ip)
res = _cute_ir.logical_divide(input=target, tiler=tiler, loc=loc, ip=ip)
return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res
@overload
@@ -3973,11 +3991,14 @@ def zipped_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor:
@dsl_user_op
def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None):
res_type = None
if isinstance(target, _Tensor):
res_type = target.element_type
target = target.value
if isinstance(tiler, tuple):
tiler = _pack_tile(tiler, loc=loc, ip=ip)
return _cute_ir.zipped_divide(input=target, tiler=tiler, loc=loc, ip=ip)
res = _cute_ir.zipped_divide(input=target, tiler=tiler, loc=loc, ip=ip)
return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res
@overload
@@ -3988,11 +4009,14 @@ def tiled_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor:
@dsl_user_op
def tiled_divide(target, tiler: Tiler, *, loc=None, ip=None):
res_type = None
if isinstance(target, _Tensor):
res_type = target.element_type
target = target.value
if isinstance(tiler, tuple):
tiler = _pack_tile(tiler, loc=loc, ip=ip)
return _cute_ir.tiled_divide(input=target, tiler=tiler, loc=loc, ip=ip)
res = _cute_ir.tiled_divide(input=target, tiler=tiler, loc=loc, ip=ip)
return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res
@overload
@@ -4003,11 +4027,14 @@ def flat_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: .
@dsl_user_op
def flat_divide(target, tiler: Tiler, *, loc=None, ip=None):
res_type = None
if isinstance(target, _Tensor):
res_type = target.element_type
target = target.value
if isinstance(tiler, tuple):
tiler = _pack_tile(tiler, loc=loc, ip=ip)
return _cute_ir.flat_divide(input=target, tiler=tiler, loc=loc, ip=ip)
res = _cute_ir.flat_divide(input=target, tiler=tiler, loc=loc, ip=ip)
return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res
#
@@ -4075,14 +4102,22 @@ def tile_to_shape(
def local_partition(
target: Tensor,
tiler: Union[Layout, Shape],
index,
index: Union[int, Numeric],
proj: XTuple = 1,
*,
loc=None,
ip=None,
) -> Tensor:
if isinstance(index, cutlass_arith.ArithValue):
index_val = index
else:
index_val = index.ir_value()
if index_val.type.width > 32:
raise NotImplementedError(
f"Index value should be 32-bit or smaller integer type, but got {index_val.type}"
)
return _cute_ir.local_partition(
input=target.value, tiler=dice(tiler, proj), index=index, loc=loc, ip=ip
input=target.value, tiler=dice(tiler, proj), index=index_val, loc=loc, ip=ip
)
@@ -4332,6 +4367,8 @@ class MmaAtom(Atom):
def make_fragment_A(self, input, *, loc=None, ip=None):
# input could be memref/shape/layout for tmem based fragment
if isinstance(input, _Tensor):
if self.op is not None:
self.op._verify_fragment_A(input, loc=loc, ip=ip)
input = input.value
if isinstance(input, tuple):
input = _pack_shape(input, loc=loc, ip=ip)
@@ -4343,9 +4380,12 @@ class MmaAtom(Atom):
ip=ip,
)
@dsl_user_op
def make_fragment_B(self, input, *, loc=None, ip=None):
if isinstance(input, _Tensor):
if self.op is not None:
self.op._verify_fragment_B(input, loc=loc, ip=ip)
input = input.value
return _cute_ir.mma_make_fragment(
_cute_ir.MmaOperand.B,
@@ -5193,7 +5233,7 @@ def copy(
src: Tensor,
dst: Tensor,
*,
pred: Tensor = None,
pred: Optional[Tensor] = None,
loc=None,
ip=None,
**kwargs,
@@ -5334,7 +5374,7 @@ class TensorSSA(cutlass_arith.ArithValue):
other = as_numeric(other)
# Promote types
lhs, rhs, res_type = _binary_op_type_promote(self, other, True)
lhs, rhs, res_type = _binary_op_type_promote(self, other)
# Promote scalar to vector
if not isinstance(rhs, TensorSSA):
@@ -5827,6 +5867,28 @@ class TensorSSA(cutlass_arith.ArithValue):
def ir_value(self, *, loc=None, ip=None):
return self
def ir_value_int8(self, *, loc=None, ip=None):
"""
Returns int8 ir value of Boolean tensor.
When we need to store Boolean tensor ssa, use ir_value_int8().
:param loc: Source location information, defaults to None
:type loc: Optional[Location], optional
:param ip: Insertion point for MLIR operations, defaults to None
:type ip: Optional[InsertionPoint], optional
:return: The int8 value of this Boolean
:rtype: ir.Value
"""
assert (
self.element_type is Boolean
), f"Only boolean type needs to be converted to int8, got {self.element_type}"
if not hasattr(self, "_value_int8"):
self._value_int8 = arith.extsi(
T.vector(self.type.shape[0], T.i8()), self, loc=loc, ip=ip
)
return self._value_int8
def reduce(self, op, init_val, reduction_profile: Coord, *, loc=None, ip=None):
"""
Perform reduce on selected modes with given predefined reduction op.

View File

@@ -84,6 +84,11 @@ class MmaUniversalOp(core.MmaOp):
)
return MmaUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
def _verify_fragment_A(self, input, *, loc=None, ip=None):
pass
def _verify_fragment_B(self, input, *, loc=None, ip=None):
pass
class MmaUniversalTrait(core.Trait):
pass

View File

@@ -20,7 +20,7 @@ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
from ..common import OpError
from ...core import MmaOp, Trait, _pack_shape, rank, depth
from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor
from ...typing import (
Shape,
Float8E5M2,
@@ -34,6 +34,7 @@ from ...typing import (
Uint8,
Int32,
Numeric,
AddressSpace,
)
@@ -212,6 +213,30 @@ class MmaOp(MmaOp):
+ f"\n Instruction shape MNK = {self.shape_mnk}"
)
def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
if input.memspace == AddressSpace.smem and isinstance(
input.layout.type, _cute_ir.ComposedLayoutType
):
raise OpError(
self,
f"Expected affine layout for {self._make_trait()}'s operand A, "
f"but got composed layout instead: {input.layout}"
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
)
return True
def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
if input.memspace == AddressSpace.smem and isinstance(
input.layout.type, _cute_ir.ComposedLayoutType
):
raise OpError(
self,
f"Expected affine layout for {self._make_trait()}'s operand B, "
f"but got composed layout instead: {input.layout}"
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
)
return True
class MmaTrait(Trait):
admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]

View File

@@ -16,8 +16,8 @@ import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from ..common import OpError
from ...core import MmaOp, Trait, _pack_shape
from ...typing import Shape, Float16, BFloat16, Float32, Numeric
from ...core import MmaOp, Trait, _pack_shape, _Tensor
from ...typing import Shape, Float16, BFloat16, Float32, Numeric, AddressSpace
@dataclass(frozen=True)
@@ -73,6 +73,11 @@ class MmaF16BF16Op(MmaOp):
+ f"\n Instruction shape MNK = {self.shape_mnk}"
)
def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
pass
def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
pass
class MmaF16BF16Trait(Trait):
pass

View File

@@ -20,7 +20,7 @@ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
from ..common import OpError
from ...core import MmaOp, Trait, _pack_shape, rank, depth
from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor
from ...typing import (
Shape,
Float16,
@@ -30,6 +30,7 @@ from ...typing import (
Float8E5M2,
Float8E4M3FN,
Numeric,
AddressSpace,
)
@@ -167,6 +168,30 @@ class MmaOp(MmaOp):
+ f"\n Instruction shape MNK = {self.shape_mnk}"
)
def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
if input.memspace == AddressSpace.smem and isinstance(
input.layout.type, _cute_ir.ComposedLayoutType
):
raise OpError(
self,
f"Expected affine layout for {self._make_trait()}'s operand A, "
f"but got composed layout instead: {input.layout}"
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
)
return True
def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
if input.memspace == AddressSpace.smem and isinstance(
input.layout.type, _cute_ir.ComposedLayoutType
):
raise OpError(
self,
f"Expected affine layout for {self._make_trait()}'s operand B, "
f"but got composed layout instead: {input.layout}"
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
)
return True
class MmaTrait(Trait):
admissible_fields = [Field.ACCUMULATE]

View File

@@ -124,7 +124,7 @@ class _Pointer(Pointer):
)
@property
def element_type(self) -> Type[Numeric]:
def dtype(self) -> Type[Numeric]:
return self._dtype
@property

View File

@@ -35,20 +35,7 @@ from inspect import isclass
def assert_(cond, msg=None):
if isinstance(cond, ir.Value):
if ir.VectorType.isinstance(cond.type):
assert (
cond.type.element_type == T.bool()
), f"only expects vector type with boolean elements, but got {cond.type}"
cond_val = vector.multi_reduction(
vector.CombiningKind.AND, cond, const(True), range(cond.type.rank)
)
else:
cond_val = cond
else:
cond_val = const(cond, t.Boolean)
cf.assert_(cond_val, msg if msg else "")
cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "")
def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout):

View File

@@ -56,14 +56,23 @@ XTuple = Union[IntTuple, Shape, Stride, Coord, Tile]
Tiler = Union[Shape, Layout, Tile]
class Pointer:
class Pointer(ABC):
"""
Abstract base class for CuTe jit function and runtime _Pointer
"""
def __extract_mlir_values__(self):
# Doesn't matter just return a value
return [self]
@property
def value_type(self) -> Type[Numeric]:
return self.dtype
@property
def dtype(self) -> Type[Numeric]: ...
def __get_mlir_types__(self) -> List[ir.Type]: ...
def __extract_mlir_values__(self) -> List[ir.Value]: ...
def __new_from_mlir_values__(self, values) -> "Pointer": ...
class Tensor(ABC):
@@ -144,10 +153,13 @@ class Tensor(ABC):
def store(self, data: "TensorSSA", *, loc=None, ip=None): ...
def mark_layout_dynamic(self, leading_dim: int|None = None) -> "Tensor": ...
def mark_layout_dynamic(self, leading_dim: int | None = None) -> "Tensor": ...
def mark_compact_shape_dynamic(
self, mode: int, stride_order: tuple[int, ...]|None = None, divisibility: int = 1
self,
mode: int,
stride_order: tuple[int, ...] | None = None,
divisibility: int = 1,
) -> "Tensor": ...
@abstractmethod

View File

@@ -30,6 +30,7 @@ class Agent(enum.Enum):
"""
Agent indicates what is participating in the pipeline synchronization.
"""
# Arbitrary grouping of N threads
Thread = enum.auto()
# Same as AsyncThread, but includes all threads in the block
@@ -42,6 +43,7 @@ class CooperativeGroup:
"""
CooperativeGroup contains size and alignment restrictions for an Agent.
"""
def __init__(self, agent: Agent, size: int = 1, alignment: int = 1):
if agent is Agent.Thread:
assert size > 0
@@ -76,6 +78,7 @@ class _PipelineOp(enum.Enum):
"""
PipelineOp assigns an operation to an agent corresponding to a specific hardware feature.
"""
# async-threads
AsyncThread = enum.auto()
# Blackwell (SM100a) MMA instruction
@@ -140,12 +143,8 @@ class MbarrierArray(SyncObjectArray):
"Error: Mbarrier tx count must be greater than 0 for TMA ops."
)
# Using a tensor to store mbarrier i64 ptrs
self.mbarrier_array = cute.make_fragment(cute.make_layout(num_stages), Int64)
for i in range(num_stages):
self.mbarrier_array[i] = _cute_ir.ptrtoint(
T.i64(), (self.barrier_storage + i).value
)
# Store mbarrier base pointer
self.mbarrier_base = self.barrier_storage
# Mbarrier initialization in constructor
self.mbarrier_init()
@@ -155,10 +154,11 @@ class MbarrierArray(SyncObjectArray):
"""
Initializes an array of mbarriers using warp 0.
"""
def then_body():
for index in range(self.num_stages):
cute.arch.mbarrier_init_arrive_cnt(
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), self.arrive_count
self.get_barrier(index), self.arrive_count
)
warp_idx = cute.arch.warp_idx()
@@ -166,7 +166,12 @@ class MbarrierArray(SyncObjectArray):
if_generate(warp_idx == 0, then_body)
def arrive(self, index: int, dst: int):
def arrive(
self,
index: int,
dst: int,
cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None,
):
"""
Select the arrive corresponding to this MbarrierArray's PipelineOp
:param index: Index of the mbarrier in the array to arrive on
@@ -175,55 +180,53 @@ class MbarrierArray(SyncObjectArray):
- For TCGen05Mma, dst serves as a multicast mask (e.g., 0b1011 allows arrive signal to be multicast to CTAs in the cluster with rank = 0, 1, and 3).
- For AsyncThread, dst serves as a destination cta rank (e.g., 3 means threads will arrive on the mbarrier with rank = 3 in the cluster).
:type dst: int | None
:param cta_group: CTA group for TCGen05Mma, defaults to None for other op types
:type cta_group: cute.nvgpu.tcgen05.CtaGroup, optional
"""
if self.op_type is _PipelineOp.AsyncThread:
self.arrive_mbarrier(index, dst)
elif self.op_type is _PipelineOp.TCGen05Mma:
self.arrive_tcgen05mma(index, dst)
assert (
cta_group is not None
), "Error: CTA group must be provided for TCGen05Mma."
self.arrive_tcgen05mma(index, dst, cta_group)
elif self.op_type in [_PipelineOp.TmaLoad]:
self.arrive_and_expect_tx(index, self.tx_count)
else:
print(_get_pipeline_op(self.op_type))
assert False, "Error: MbarrierArray is not supported for this PipelineOp."
assert False, f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}."
def arrive_mbarrier(self, index: int, dst_rank: int):
if dst_rank is None:
cute.arch.mbarrier_arrive(_mbarrier_i64_to_ptr(self.mbarrier_array[index]))
cute.arch.mbarrier_arrive(self.get_barrier(index))
else:
cute.arch.mbarrier_arrive(
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), dst_rank
)
cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank)
def arrive_tcgen05mma(self, index: int, mask: int):
def arrive_tcgen05mma(
self, index: int, mask: int, cta_group: cute.nvgpu.tcgen05.CtaGroup
):
if mask is None:
with cute.arch.elect_one():
cute.nvgpu.tcgen05.commit(
_mbarrier_i64_to_ptr(self.mbarrier_array[index])
)
cute.nvgpu.tcgen05.commit(self.get_barrier(index))
else:
with cute.arch.elect_one():
cute.nvgpu.tcgen05.commit(
_mbarrier_i64_to_ptr(self.mbarrier_array[index]),
self.get_barrier(index),
mask,
cute.nvgpu.tcgen05.CtaGroup.TWO,
cta_group,
)
def arrive_and_expect_tx(self, index: int, tx_count: int):
with cute.arch.elect_one():
cute.arch.mbarrier_init_tx_bytes(
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), tx_count
)
cute.arch.mbarrier_init_tx_bytes(self.get_barrier(index), tx_count)
def try_wait(self, index: int, phase: int):
return cute.arch.mbarrier_try_wait(
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), phase
)
return cute.arch.mbarrier_try_wait(self.get_barrier(index), phase)
def wait(self, index: int, phase: int):
cute.arch.mbarrier_wait(_mbarrier_i64_to_ptr(self.mbarrier_array[index]), phase)
cute.arch.mbarrier_wait(self.get_barrier(index), phase)
def get_barrier(self, index: int) -> cute.Pointer:
return _mbarrier_i64_to_ptr(self.mbarrier_array[index])
return self.mbarrier_base + index
class TmaStoreFence(SyncObjectArray):
@@ -390,6 +393,7 @@ class PipelineAsync:
PipelineAsync is a generic pipeline class where both the producer and consumer are
AsyncThreads. It also serves as a base class for specialized pipeline classes.
"""
sync_object_array_full: SyncObjectArray
sync_object_array_empty: SyncObjectArray
num_stages: Int32
@@ -522,6 +526,7 @@ class PipelineTmaAsync(PipelineAsync):
"""
PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops).
"""
is_signalling_thread: bool
@staticmethod
@@ -628,7 +633,6 @@ class PipelineTmaAsync(PipelineAsync):
)
self.sync_object_array_full.arrive(state.index, self.producer_mask)
def producer_commit(self, state: PipelineState):
"""
TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA.
@@ -646,12 +650,15 @@ class PipelineTmaAsync(PipelineAsync):
),
)
@dataclass(frozen=True)
class PipelineTmaUmma(PipelineAsync):
"""
PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops).
"""
is_leader_cta: bool
cta_group: cute.nvgpu.tcgen05.CtaGroup
@staticmethod
def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout):
@@ -748,6 +755,12 @@ class PipelineTmaUmma(PipelineAsync):
producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
cta_group = (
cute.nvgpu.tcgen05.CtaGroup.ONE
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
else cute.nvgpu.tcgen05.CtaGroup.TWO
)
consumer_mask = producer_mask
pipeline_init_wait(cta_layout_vmnk)
@@ -759,6 +772,15 @@ class PipelineTmaUmma(PipelineAsync):
producer_mask,
consumer_mask,
is_leader_cta,
cta_group,
)
def consumer_release(self, state: PipelineState):
"""
UMMA consumer release buffer empty, cta_group needs to be provided.
"""
self.sync_object_array_empty.arrive(
state.index, self.consumer_mask, self.cta_group
)
def producer_acquire(
@@ -789,6 +811,8 @@ class PipelineUmmaAsync(PipelineAsync):
PipelineTmaUmma is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines).
"""
cta_group: cute.nvgpu.tcgen05.CtaGroup
@staticmethod
def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout):
"""
@@ -858,6 +882,12 @@ class PipelineUmmaAsync(PipelineAsync):
else:
consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank()
cta_group = (
cute.nvgpu.tcgen05.CtaGroup.ONE
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
else cute.nvgpu.tcgen05.CtaGroup.TWO
)
pipeline_init_wait(cta_layout_vmnk)
return PipelineUmmaAsync(
@@ -866,6 +896,15 @@ class PipelineUmmaAsync(PipelineAsync):
num_stages,
producer_mask,
consumer_mask,
cta_group,
)
def producer_commit(self, state: PipelineState):
"""
UMMA producer commit buffer full, cta_group needs to be provided.
"""
self.sync_object_array_full.arrive(
state.index, self.producer_mask, self.cta_group
)
def producer_tail(self, state: PipelineState):

View File

@@ -185,7 +185,7 @@ class SmemAllocator:
and isinstance(layout.inner, cute.Swizzle)
) and (swizzle is not None):
raise TypeError(
f"iterator swizzle with swizzle layout is currently not supported"
f"Invalid tensor type: cannot be both iterator swizzle (PDSL) and swizzle layout(PISL) at the same time."
)
if isinstance(layout, int):

View File

@@ -527,13 +527,13 @@ def pack_from_irvalue(
"""
Packs MLIR values into a list of mixed values.
"""
log().info("===--- Values Pack (%d)", len(ir_values))
log().debug("===--- Values Pack (%d)", len(ir_values))
for idx, packed in enumerate(ir_values):
log().info("[%d]: will-packed: %s", idx, ir_values)
log().debug("[%d]: will-packed: %s", idx, ir_values)
for idx, unpacked in indices.items():
log().info("[%d]: indices: %s", idx, unpacked)
log().debug("[%d]: indices: %s", idx, unpacked)
for idx, c in enumerate(class_types):
log().info("[%d]: obj-types: %s", idx, type(c))
log().debug("[%d]: obj-types: %s", idx, type(c))
mixed_values = [None] * len(indices)
for idx, (start, length) in sorted(indices.items()):
@@ -552,10 +552,10 @@ def pack_from_irvalue(
except DSLRuntimeError as e:
mixed_values[idx] = chunk[0]
log().info("------------------ ")
log().debug("------------------ ")
for idx, packed in enumerate(mixed_values):
log().info("[%d]: packed: %s", idx, packed)
log().info("------------------ ")
log().debug("[%d]: packed: %s", idx, packed)
log().debug("------------------ ")
return mixed_values
@@ -571,9 +571,9 @@ def unpack_to_irvalue(
class_types = []
current_offset = 0
log().info("===--- Values UNPack (%d)", len(mixed_values))
log().debug("===--- Values UNPack (%d)", len(mixed_values))
for idx, packed in enumerate(mixed_values):
log().info("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
for idx, item in enumerate(mixed_values):
class_types.append(item)
try:
@@ -612,16 +612,16 @@ def unpack_to_irvalue(
),
) from e
log().info("------------------ ")
log().debug("------------------ ")
for idx, unpacked in enumerate(unpacked_values):
log().info("[%d]: unpacked values: %s", idx, unpacked)
log().debug("[%d]: unpacked values: %s", idx, unpacked)
for idx, unpacked in enumerate(ir_values):
log().info("[%d]: unpacked ir_values: %s", idx, unpacked)
log().debug("[%d]: unpacked ir_values: %s", idx, unpacked)
for idx, unpacked in indices.items():
log().info("[%d]: indices: %s", idx, unpacked)
log().debug("[%d]: indices: %s", idx, unpacked)
for idx, unpacked in enumerate(class_types):
log().info("[%d]: initial-class-types: %s", idx, unpacked)
log().info("------------------ ")
log().debug("[%d]: initial-class-types: %s", idx, unpacked)
log().debug("------------------ ")
return ir_values, unpacked_values, indices, class_types
@@ -1302,7 +1302,6 @@ class WhileLoopContext:
def __exit__(self, exc_type, exc_value, traceback):
self.ipoint_op.__exit__(exc_type, exc_value, traceback)
return True
@property
def results(self):

View File

@@ -331,11 +331,7 @@ def _if_execute_dynamic(
# Assume final result types match the dynamic yields
result_types = [arg.type for arg in dyn_yield_ops]
pred_ = t.as_numeric(pred)
if not isinstance(pred_, Boolean):
# Convert to Boolean through comparison
pred_ = pred_ == True
pred_ = Boolean(pred)
try:
if_op = scf.IfOp(