mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
v4.0 update. (#2371)
This commit is contained in:
@@ -578,7 +578,7 @@ class ArithValue(ir.Value):
|
||||
|
||||
# Unary operators
|
||||
def __invert__(self, *, loc=None, ip=None) -> "ArithValue":
|
||||
return arith.xori(self, arith.const(self.type, -1))
|
||||
return arith.xori(self, arith.constant(self.type, -1))
|
||||
|
||||
# Bitwise operations
|
||||
@_dispatch_to_rhs_r_op
|
||||
|
||||
@@ -95,7 +95,7 @@ class Executor:
|
||||
unroll=bool,
|
||||
unroll_full=int,
|
||||
):
|
||||
log().info("start [%s] stop [%s] step [%s]", start, stop, step)
|
||||
log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
|
||||
return self._loop_execute_range_dynamic(
|
||||
func,
|
||||
start,
|
||||
@@ -117,7 +117,7 @@ class Executor:
|
||||
used_args: list,
|
||||
iter_args: list,
|
||||
):
|
||||
log().info("start [%s] stop [%s] step [%s]", start, stop, step)
|
||||
log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
|
||||
loop_results = iter_args
|
||||
log().debug("iter_args [%s]", iter_args)
|
||||
for i in range(start, stop, step):
|
||||
@@ -374,7 +374,7 @@ def loop_selector(
|
||||
unroll_full=False,
|
||||
constexpr=None,
|
||||
):
|
||||
log().info(
|
||||
log().debug(
|
||||
"start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]",
|
||||
start,
|
||||
stop,
|
||||
@@ -415,7 +415,7 @@ def loop_selector(
|
||||
|
||||
|
||||
def if_selector(pred, used_args=[], yield_args=[]):
|
||||
log().info("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
|
||||
log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
|
||||
# Handle Numeric types here?
|
||||
|
||||
from .typing import Numeric
|
||||
|
||||
@@ -248,39 +248,55 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
# Step 3. Return the transformed tree
|
||||
return combined_body
|
||||
|
||||
def check_early_exit(self, tree):
|
||||
def check_early_exit(self, tree, kind):
|
||||
"""
|
||||
Checks if a given region or scope in the provided Python code has early exits.
|
||||
"""
|
||||
|
||||
class EarlyExitChecker(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
def __init__(self, kind):
|
||||
self.has_early_exit = False
|
||||
self.early_exit_node = None
|
||||
self.early_exit_type = None
|
||||
self.kind = kind
|
||||
self.loop_nest_level = 0
|
||||
|
||||
# Early exit is not allowed in any level of dynamic control flow
|
||||
def visit_Return(self, node):
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "return"
|
||||
|
||||
def visit_Break(self, node):
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "break"
|
||||
|
||||
def visit_Continue(self, node):
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "continue"
|
||||
|
||||
def visit_Raise(self, node):
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "raise"
|
||||
|
||||
checker = EarlyExitChecker()
|
||||
checker.visit(tree)
|
||||
def visit_Break(self, node):
|
||||
# For break/continue in inner loops, we don't consider it as early exit
|
||||
if self.loop_nest_level == 0 and self.kind != "if":
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "break"
|
||||
|
||||
def visit_Continue(self, node):
|
||||
if self.loop_nest_level == 0 and self.kind != "if":
|
||||
self.has_early_exit = True
|
||||
self.early_exit_node = node
|
||||
self.early_exit_type = "continue"
|
||||
|
||||
def visit_For(self, node):
|
||||
self.loop_nest_level += 1
|
||||
self.generic_visit(node)
|
||||
self.loop_nest_level -= 1
|
||||
|
||||
def visit_While(self, node):
|
||||
self.loop_nest_level += 1
|
||||
self.generic_visit(node)
|
||||
self.loop_nest_level -= 1
|
||||
|
||||
checker = EarlyExitChecker(kind)
|
||||
checker.generic_visit(tree)
|
||||
if not checker.has_early_exit:
|
||||
return
|
||||
raise DSLAstPreprocessorError(
|
||||
@@ -591,7 +607,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
if self.is_supported_range_call(node):
|
||||
constexpr_val = self.get_loop_constexpr(node)
|
||||
# Check for early exit and raise exception
|
||||
self.check_early_exit(node)
|
||||
self.check_early_exit(node, "for")
|
||||
start, stop, step = self.extract_range_args(node.iter)
|
||||
unroll, unroll_full = self.extract_unroll_args(node.iter)
|
||||
used_args, iter_args, flat_args = self.analyze_region_variables(
|
||||
@@ -659,37 +675,42 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
snippet=ast.unparse(node),
|
||||
)
|
||||
|
||||
test = ast.BoolOp(
|
||||
op=ast.And(),
|
||||
values=[
|
||||
ast.Compare(
|
||||
left=ast.Call(
|
||||
func=ast.Name(id="type", ctx=ast.Load()),
|
||||
args=[node.values[0]],
|
||||
keywords=[],
|
||||
def short_circuit_eval(value, short_circuit_value):
|
||||
return ast.BoolOp(
|
||||
op=ast.And(),
|
||||
values=[
|
||||
ast.Compare(
|
||||
left=ast.Call(
|
||||
func=ast.Name(id="type", ctx=ast.Load()),
|
||||
args=[value],
|
||||
keywords=[],
|
||||
),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[ast.Name(id="bool", ctx=ast.Load())],
|
||||
),
|
||||
ops=[ast.Eq()],
|
||||
comparators=[ast.Name(id="bool", ctx=ast.Load())],
|
||||
),
|
||||
ast.Compare(
|
||||
left=node.values[0],
|
||||
ops=[ast.Eq()],
|
||||
comparators=[short_circuit_value],
|
||||
),
|
||||
],
|
||||
)
|
||||
return ast.copy_location(
|
||||
ast.IfExp(
|
||||
ast.Compare(
|
||||
left=value,
|
||||
ops=[ast.Eq()],
|
||||
comparators=[short_circuit_value],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
lhs = node.values[0]
|
||||
|
||||
for i in range(1, len(node.values)):
|
||||
test = short_circuit_eval(lhs, short_circuit_value)
|
||||
lhs = ast.IfExp(
|
||||
test=test,
|
||||
body=node.values[0],
|
||||
body=lhs,
|
||||
orelse=ast.Call(
|
||||
func=helper_func,
|
||||
args=node.values,
|
||||
args=[lhs, node.values[i]],
|
||||
keywords=[],
|
||||
),
|
||||
),
|
||||
node,
|
||||
)
|
||||
)
|
||||
|
||||
return ast.copy_location(lhs, node)
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
# Visit child nodes first
|
||||
@@ -916,7 +937,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
return node
|
||||
|
||||
# Check for early exit and raise exception
|
||||
self.check_early_exit(node)
|
||||
self.check_early_exit(node, "while")
|
||||
|
||||
used_args, yield_args, flat_args = self.analyze_region_variables(
|
||||
node, active_symbols
|
||||
@@ -1021,7 +1042,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
return node
|
||||
|
||||
# Check for early exit and raise exception
|
||||
self.check_early_exit(node)
|
||||
self.check_early_exit(node, "if")
|
||||
|
||||
used_args, yield_args, flat_args = self.analyze_region_variables(
|
||||
node, active_symbols
|
||||
|
||||
@@ -566,7 +566,9 @@ class BaseDSL:
|
||||
log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, arg_spec)
|
||||
|
||||
# Implicit cast to NumericMeta
|
||||
if isinstance(arg_spec, t.NumericMeta):
|
||||
if isinstance(arg_spec, t.NumericMeta) and not isinstance(
|
||||
arg, arg_spec
|
||||
):
|
||||
arg = t.cast(arg, arg_spec)
|
||||
|
||||
ir_arg, iv_block_args = (
|
||||
@@ -589,15 +591,17 @@ class BaseDSL:
|
||||
self.log_additions(ir_arg)
|
||||
ir_args.extend(ir_arg)
|
||||
|
||||
return ir_args
|
||||
return ir_args, iv_block_args
|
||||
|
||||
fop_args = list(fop.regions[0].blocks[0].arguments)
|
||||
ir_args = gen_exec_args(args, args_spec.args, args_spec.annotations, fop_args)
|
||||
ir_kwargs = gen_exec_args(
|
||||
ir_args, iv_block_args = gen_exec_args(
|
||||
args, args_spec.args, args_spec.annotations, fop_args
|
||||
)
|
||||
ir_kwargs, _ = gen_exec_args(
|
||||
[kwargs[arg] for arg in args_spec.kwonlyargs],
|
||||
args_spec.kwonlyargs,
|
||||
args_spec.annotations,
|
||||
fop_args[len(ir_args) :],
|
||||
fop_args[iv_block_args:],
|
||||
)
|
||||
ir_kwargs = {k: v for k, v in zip(args_spec.kwonlyargs, ir_kwargs)}
|
||||
|
||||
@@ -716,8 +720,10 @@ class BaseDSL:
|
||||
|
||||
assert len(args) == len(args_spec.args) and len(kwargs) == len(
|
||||
args_spec.kwonlyargs
|
||||
), f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args "
|
||||
f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}"
|
||||
), (
|
||||
f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args "
|
||||
f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}"
|
||||
)
|
||||
|
||||
jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], []
|
||||
default_attr = ir.DictAttr.get({})
|
||||
@@ -729,7 +735,7 @@ class BaseDSL:
|
||||
log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, spec_ty)
|
||||
|
||||
# Implicitly convert into Numeric type if possible
|
||||
if isinstance(spec_ty, t.NumericMeta):
|
||||
if isinstance(spec_ty, t.NumericMeta) and not isinstance(arg, spec_ty):
|
||||
arg = t.cast(arg, spec_ty)
|
||||
|
||||
# Type safety check
|
||||
|
||||
@@ -141,33 +141,60 @@ class JitExecutor:
|
||||
to get rid of mlir context.
|
||||
"""
|
||||
|
||||
# Process positional arguments with defaults
|
||||
rectified_args = list(args)
|
||||
if args_spec.defaults and len(args) < len(args_spec.args):
|
||||
rectified_args.extend(args_spec.defaults[len(args) - len(args_spec.args) :])
|
||||
for k, v in kwargs.items():
|
||||
if k in args_spec.args:
|
||||
idx = args_spec.args.index(k)
|
||||
if idx < len(rectified_args):
|
||||
rectified_args[idx] = v
|
||||
else:
|
||||
rectified_args.append(v)
|
||||
|
||||
# Process keyword arguments
|
||||
rectified_kwargs = {k: v for k, v in kwargs.items() if k not in args_spec.args}
|
||||
if args_spec.kwonlydefaults and len(rectified_kwargs) < len(
|
||||
args_spec.kwonlyargs
|
||||
):
|
||||
rectified_kwargs.update(args_spec.kwonlydefaults)
|
||||
|
||||
# args/kwargs must match arg_specs
|
||||
# No canonicalization of args/kwargs to avoid extra latency
|
||||
if len(args) != len(args_spec.args) or len(kwargs) != len(args_spec.kwonlyargs):
|
||||
if len(rectified_args) != len(args_spec.args) or len(rectified_kwargs) != len(
|
||||
args_spec.kwonlyargs
|
||||
):
|
||||
raise DSLRuntimeError(
|
||||
"input args/kwargs length does not match runtime function signature!",
|
||||
context={
|
||||
"input args length": len(args),
|
||||
"input kwargs length": len(kwargs),
|
||||
"input args length": len(rectified_args),
|
||||
"input kwargs length": len(rectified_kwargs),
|
||||
"function signature args length": len(args_spec.args),
|
||||
"function signature kwonlyargs length": len(args_spec.kwonlyargs),
|
||||
},
|
||||
)
|
||||
|
||||
exe_args = []
|
||||
input_args = [*args, *kwargs.values()]
|
||||
input_arg_names = [*args_spec.args, *args_spec.kwonlyargs]
|
||||
for i, arg in enumerate(input_args):
|
||||
arg_type = args_spec.annotations.get(input_arg_names[i], None)
|
||||
input_args = rectified_args + list(rectified_kwargs.values())
|
||||
input_arg_names = args_spec.args + args_spec.kwonlyargs
|
||||
for arg, arg_name in zip(input_args, input_arg_names):
|
||||
# short-cut for args already converted
|
||||
if hasattr(arg, "__c_pointers__"):
|
||||
exe_args.extend(arg.__c_pointers__())
|
||||
continue
|
||||
|
||||
arg_type = args_spec.annotations.get(arg_name, None)
|
||||
|
||||
# Implicit cast to NumericMeta
|
||||
if isinstance(arg_type, t.NumericMeta):
|
||||
arg = t.cast(arg, arg_type)
|
||||
else:
|
||||
# If not any known type, try registered adapter to do the conversion
|
||||
adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
|
||||
if adapter:
|
||||
arg = adapter(arg)
|
||||
|
||||
# If not any known type, try registered adapter to do the conversion
|
||||
adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg))
|
||||
adapted_arg = adapter(arg) if adapter else arg
|
||||
exe_args.extend(get_c_pointers(adapted_arg))
|
||||
exe_args.extend(get_c_pointers(arg))
|
||||
|
||||
return exe_args
|
||||
|
||||
|
||||
@@ -457,7 +457,7 @@ class StreamAdapter:
|
||||
|
||||
def __init__(self, arg):
|
||||
self._arg = arg
|
||||
self._c_pointer = ctypes.cast(self._arg.getPtr(), ctypes.c_void_p)
|
||||
self._c_pointer = self._arg.getPtr()
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
assert len(values) == 1
|
||||
|
||||
@@ -629,7 +629,7 @@ def _binary_op_type_promote(a, b, promote_bool: bool = False):
|
||||
b_type = b.dtype
|
||||
|
||||
# Early return for same types (except when they're bools that need promotion)
|
||||
if a_type == b_type and not (promote_bool and a_type.width == 1):
|
||||
if a_type == b_type and not (promote_bool and a_type is Boolean):
|
||||
return a, b, a_type
|
||||
|
||||
# Handle floating point promotions
|
||||
@@ -1315,10 +1315,7 @@ class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True)
|
||||
|
||||
def __invert__(self, *, loc=None, ip=None):
|
||||
res_type = type(self)
|
||||
# Create a constant of -1 (all bits set to 1) of the same type as value
|
||||
all_ones = arith.constant(res_type.mlir_type, -1)
|
||||
# XOR with -1 gives us bitwise NOT
|
||||
return res_type(arith.xori(self.ir_value(), all_ones, loc=loc, ip=ip))
|
||||
return res_type(self.ir_value(loc=loc, ip=ip).__invert__(loc=loc, ip=ip))
|
||||
|
||||
def __lshift__(self, other, *, loc=None, ip=None):
|
||||
return _binary_op(operator.lshift)(self, other, loc=loc, ip=ip)
|
||||
@@ -1457,18 +1454,14 @@ class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T.
|
||||
- Converted using Python's bool() function
|
||||
- Example: Boolean(1) -> True, Boolean(0) -> False
|
||||
|
||||
2. Boolean:
|
||||
- Direct value assignment
|
||||
- Example: Boolean(Boolean(True)) -> True
|
||||
2. Numeric:
|
||||
- Uses the Numeric.value to construct Boolean recursively
|
||||
|
||||
3. Numeric:
|
||||
- Uses the __dsl_bool__ method of the Numeric type
|
||||
|
||||
4. MLIR Value with IntegerType:
|
||||
3. MLIR Value with IntegerType:
|
||||
- If width is 1: Direct assignment
|
||||
- Otherwise: Compares with 0 using arith.cmpi
|
||||
|
||||
5. MLIR Value with FloatType:
|
||||
4. MLIR Value with FloatType:
|
||||
- Compares with 0.0 using arith.cmpf
|
||||
- Uses unordered comparison to handle NaN values
|
||||
"""
|
||||
@@ -1479,19 +1472,35 @@ class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T.
|
||||
value = None
|
||||
if isinstance(a, (bool, int, float)):
|
||||
value = bool(a)
|
||||
elif isinstance(a, Boolean):
|
||||
value = a.value
|
||||
elif isinstance(a, Numeric):
|
||||
value = a.__dsl_bool__(loc=loc, ip=ip)
|
||||
Boolean.__init__(self, a.value, loc=loc, ip=ip)
|
||||
return
|
||||
elif isinstance(a, ArithValue):
|
||||
if a.type == T.bool():
|
||||
value = a
|
||||
else:
|
||||
value = a != arith_helper.const(0, a.type)
|
||||
|
||||
value = a != arith_helper.const(0, a.type, loc=loc, ip=ip)
|
||||
if value is None:
|
||||
raise DSLRuntimeError(f"Cannot convert {a} to Boolean")
|
||||
super().__init__(value, loc=loc, ip=ip)
|
||||
self._value_int8 = None
|
||||
|
||||
def ir_value_int8(self, *, loc=None, ip=None):
|
||||
"""
|
||||
Returns int8 ir value of Boolean.
|
||||
When we need to store Boolean tensor element, use ir_value_int8().
|
||||
|
||||
:param loc: Source location information, defaults to None
|
||||
:type loc: Optional[Location], optional
|
||||
:param ip: Insertion point for MLIR operations, defaults to None
|
||||
:type ip: Optional[InsertionPoint], optional
|
||||
:return: The int8 value of this Boolean
|
||||
:rtype: ir.Value
|
||||
"""
|
||||
if self._value_int8 is not None:
|
||||
return self._value_int8
|
||||
self._value_int8 = Int8(self.value, loc=loc, ip=ip).ir_value()
|
||||
return self._value_int8
|
||||
|
||||
def __neg__(self, *, loc=None, ip=None):
|
||||
"""Negation operator is not supported for boolean type.
|
||||
|
||||
@@ -37,6 +37,7 @@ from cutlass.cutlass_dsl import (
|
||||
)
|
||||
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects._ods_common import get_op_result_or_op_results
|
||||
from cutlass._mlir.dialects import cute as _cute_ir
|
||||
from cutlass._mlir.dialects.cute import (
|
||||
ScaledBasis as _ScaledBasis,
|
||||
@@ -962,6 +963,9 @@ class _Pointer(Pointer):
|
||||
# Cut off the MLIR type's string for making pretty_str more concise
|
||||
return self.type.__str__()[6:]
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
return [self.value.type]
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
return [self.value]
|
||||
|
||||
@@ -979,7 +983,7 @@ class _Pointer(Pointer):
|
||||
|
||||
@property
|
||||
@lru_cache_ir()
|
||||
def value_type(self) -> Type[Numeric]:
|
||||
def dtype(self) -> Type[Numeric]:
|
||||
return Numeric.from_mlir_type(self.value.type.value_type)
|
||||
|
||||
@property
|
||||
@@ -993,7 +997,7 @@ class _Pointer(Pointer):
|
||||
@property
|
||||
@lru_cache_ir()
|
||||
def memspace(self) -> AddressSpace:
|
||||
return self.type.address_space
|
||||
return AddressSpace(self.type.address_space)
|
||||
|
||||
# Make it behave as if it inherited from ir.Value
|
||||
@property
|
||||
@@ -1015,7 +1019,7 @@ class _Pointer(Pointer):
|
||||
:return: The LLVM pointer representation
|
||||
:rtype: ir.Value
|
||||
"""
|
||||
llvm_ptr_ty = llvm.PointerType.get(self.type.address_space)
|
||||
llvm_ptr_ty = llvm.PointerType.get(self.memspace.value)
|
||||
return builtin.unrealized_conversion_cast(
|
||||
[llvm_ptr_ty], [self.value], loc=loc, ip=ip
|
||||
)
|
||||
@@ -1034,10 +1038,7 @@ class _Pointer(Pointer):
|
||||
|
||||
@dsl_user_op
|
||||
def toint(self, *, loc=None, ip=None):
|
||||
if self.type.address_space in (
|
||||
_cute_ir.AddressSpace.gmem,
|
||||
_cute_ir.AddressSpace.generic,
|
||||
):
|
||||
if self.memspace in (AddressSpace.gmem, AddressSpace.generic):
|
||||
res_type = Int64
|
||||
else:
|
||||
res_type = Int32
|
||||
@@ -1067,25 +1068,26 @@ class _Pointer(Pointer):
|
||||
raise ValueError("Alignment must be a power of 2")
|
||||
|
||||
assert isinstance(self.type, _cute_ir.PtrType)
|
||||
if self.type.address_space is AddressSpace.tmem:
|
||||
if self.memspace is AddressSpace.tmem:
|
||||
raise ValueError("aligning a TMEM pointer is not supported")
|
||||
|
||||
if min_align <= self.alignment:
|
||||
return self
|
||||
else:
|
||||
# Convert pointer to integer
|
||||
address_int = self.toint(loc=loc, ip=ip)
|
||||
# Align the address
|
||||
aligned_address = (address_int + min_align - 1) & ~(min_align - 1)
|
||||
# Create and return the aligned pointer
|
||||
return make_ptr(
|
||||
Numeric.from_mlir_type(self.type.value_type),
|
||||
aligned_address,
|
||||
self.type.address_space,
|
||||
assumed_align=min_align,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
dtype = Numeric.from_mlir_type(self.type.value_type)
|
||||
# Convert pointer to integer
|
||||
address_int = self.toint(loc=loc, ip=ip)
|
||||
# Align the address
|
||||
aligned_address = (address_int + min_align - 1) & ~(min_align - 1)
|
||||
|
||||
return make_ptr(
|
||||
dtype,
|
||||
aligned_address,
|
||||
self.memspace,
|
||||
assumed_align=min_align,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@ir.register_value_caster(_cute_ir.MemRefType.get_static_typeid(), replace=True)
|
||||
@@ -1138,8 +1140,34 @@ class _Tensor(Tensor):
|
||||
self._dtype = dtype
|
||||
if isinstance(value, ir.Value):
|
||||
self.value = value
|
||||
elif isinstance(value, _Tensor):
|
||||
self.value = value.value
|
||||
else:
|
||||
raise TypeError(f"Expected ir.Value, got {type(value)}")
|
||||
raise TypeError(f"Expected ir.Value or core._Tensor, got {type(value)}")
|
||||
|
||||
# Set iterator
|
||||
iter_val = _cute_ir.get_iter(self.value)
|
||||
if isinstance(iter_val, Pointer):
|
||||
self._iterator = iter_val
|
||||
elif isinstance(iter_val.type, _cute_ir.IntTupleType):
|
||||
self._iterator = _unpack_x_tuple(iter_val)
|
||||
elif isinstance(iter_val, ir.Value):
|
||||
# Example: SMEM descriptor iterator, not well supported today
|
||||
self._iterator = iter_val
|
||||
else:
|
||||
raise TypeError(f"unsupported iterator type, got {type(iter_val)}")
|
||||
|
||||
# Set dtype
|
||||
if self._dtype is None:
|
||||
if is_int_tuple(self.iterator):
|
||||
self._dtype = IntTuple
|
||||
elif isinstance(self.iterator, Pointer):
|
||||
self._dtype = self.iterator.value_type
|
||||
elif isinstance(self.type, _cute_nvgpu_ir.SmemDescViewType):
|
||||
# SmemDescViewType do not need dtype
|
||||
self._dtype = None
|
||||
else:
|
||||
raise TypeError(f"unsupported iterator type, got {type(self.iterator)}")
|
||||
|
||||
def __str__(self):
|
||||
return f"tensor<{pretty_str(self.iterator)} o {pretty_str(self.layout)}>"
|
||||
@@ -1157,7 +1185,7 @@ class _Tensor(Tensor):
|
||||
), f"Expected _Tensor or ir.Value, but got {type(values[0])}"
|
||||
return _Tensor(
|
||||
values[0] if isinstance(values[0], ir.Value) else values[0].value,
|
||||
self._dtype,
|
||||
dtype=self.element_type,
|
||||
)
|
||||
|
||||
# Cheat to let `Type(_Tensor())` to return cute.Tensor
|
||||
@@ -1252,9 +1280,6 @@ class _Tensor(Tensor):
|
||||
return self.element_type(data_val)
|
||||
|
||||
def _cvt_to_dest(self, data: Union["TensorSSA", Numeric], *, loc=None, ip=None):
|
||||
if data.dtype is self.element_type:
|
||||
return data.ir_value(loc=loc, ip=ip)
|
||||
|
||||
orig_dtype = data.dtype
|
||||
# Implicit upcast to wider type
|
||||
if (
|
||||
@@ -1269,11 +1294,11 @@ class _Tensor(Tensor):
|
||||
f"to Tensor with element type {self.element_type}"
|
||||
)
|
||||
|
||||
val = data.ir_value(loc=loc, ip=ip)
|
||||
if isinstance(data.dtype, (Int8, Boolean)) and (self.element_type is Boolean):
|
||||
zero = Int8(0).ir_value(loc=loc, ip=ip)
|
||||
val = arith.cmpi(arith.CmpIPredicate.ne, val, zero, loc=loc, ip=ip)
|
||||
|
||||
if data.dtype is Boolean and self.element_type is Boolean:
|
||||
# Boolean Numeric and Boolean TensorSSA both hold i1 value, but we need int8 value store to memory
|
||||
val = data.ir_value_int8()
|
||||
else:
|
||||
val = data.ir_value()
|
||||
return val
|
||||
|
||||
@dsl_user_op
|
||||
@@ -1340,7 +1365,7 @@ class _Tensor(Tensor):
|
||||
|
||||
# Implicit upcast to wider type
|
||||
val = self._cvt_to_dest(data, loc=loc, ip=ip)
|
||||
if val.type != self.element_type.mlir_type:
|
||||
if val.type != self.type.value_type:
|
||||
raise ValueError(
|
||||
f"type mismatch, store {val.type} to {self.element_type}"
|
||||
)
|
||||
@@ -1365,16 +1390,7 @@ class _Tensor(Tensor):
|
||||
|
||||
@property
|
||||
def iterator(self) -> Union[Pointer, IntTuple]:
|
||||
res = _cute_ir.get_iter(self.value)
|
||||
if isinstance(res, Pointer):
|
||||
return res
|
||||
elif isinstance(res.type, _cute_ir.IntTupleType):
|
||||
return _unpack_x_tuple(res)
|
||||
elif isinstance(res, ir.Value):
|
||||
# Example: SMEM descriptor iterator, not well supported today
|
||||
return res
|
||||
else:
|
||||
raise TypeError(f"unsupported iterator type, got {type(res)}")
|
||||
return self._iterator
|
||||
|
||||
@property
|
||||
def layout(self) -> Layout:
|
||||
@@ -1405,12 +1421,7 @@ class _Tensor(Tensor):
|
||||
@property
|
||||
@lru_cache_ir()
|
||||
def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]:
|
||||
if is_integer(self.iterator) or isinstance(self.iterator, tuple):
|
||||
return IntTuple
|
||||
elif isinstance(self.iterator, Pointer):
|
||||
return self.iterator.value_type
|
||||
else:
|
||||
raise TypeError(f"unsupported iterator type, got {type(self.iterator)}")
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
@lru_cache_ir()
|
||||
@@ -1443,7 +1454,14 @@ class _Tensor(Tensor):
|
||||
self._check_can_load_store()
|
||||
|
||||
res_vect = _cute_ir.memref_load_vec(self.value, row_major=True, loc=loc, ip=ip)
|
||||
|
||||
if self.element_type is Boolean:
|
||||
assert (
|
||||
res_vect.type.element_type == T.i8()
|
||||
), f"Boolean tensor must be stored as i8 in memory, but got {res_vect.type.element_type}"
|
||||
zeros = full_like(self, 0, Int8, loc=loc, ip=ip)
|
||||
res_vect = arith.cmpi(
|
||||
arith.CmpIPredicate.ne, res_vect, zeros, loc=loc, ip=ip
|
||||
)
|
||||
return TensorSSA(res_vect, self.shape, self.element_type)
|
||||
|
||||
@dsl_user_op
|
||||
@@ -1532,9 +1550,7 @@ class _Tensor(Tensor):
|
||||
self[None] = full(self.shape, fill_value=value, dtype=dst_type, loc=loc, ip=ip)
|
||||
|
||||
def _check_can_load_store(self):
|
||||
if not isinstance(
|
||||
self.type, _cute_ir.MemRefType
|
||||
) or not self.type.address_space in (
|
||||
if not isinstance(self.type, _cute_ir.MemRefType) or not self.memspace in (
|
||||
AddressSpace.rmem,
|
||||
AddressSpace.smem,
|
||||
AddressSpace.gmem,
|
||||
@@ -1734,10 +1750,6 @@ def printf(*args, loc=None, ip=None) -> None:
|
||||
arg0 = arg.value if isinstance(arg, Numeric) else arg
|
||||
|
||||
if isinstance(arg0, ir.Value):
|
||||
if isinstance(arg0.type, ir.FloatType) and (arg0.type != T.f32()):
|
||||
raise TypeError(
|
||||
f"cute.printf only supports 32-bit floating-point type, but got {arg0.type}"
|
||||
)
|
||||
return arg0
|
||||
elif isinstance(arg0, bool):
|
||||
return const(arg0, Boolean)
|
||||
@@ -2212,11 +2224,13 @@ def group_modes(input, begin: int, end: int = -1, *, loc=None, ip=None):
|
||||
shape = make_shape(2, 3, 4, 5)
|
||||
grouped_shape = group_modes(shape, 0, 2) # Shape ((2, 3), 4, 5)
|
||||
"""
|
||||
if depth(input) == 0:
|
||||
if depth(input) == 0 and is_integer(input):
|
||||
return (input,)
|
||||
if isinstance(input, tuple):
|
||||
return (*input[:begin], (input[begin:end]), *input[end:])
|
||||
return _cute_ir.group_modes(input.value, begin, end, loc=loc, ip=ip)
|
||||
return _cute_ir.group_modes(
|
||||
input.value if isinstance(input, Tensor) else input, begin, end, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
@@ -2315,10 +2329,13 @@ def slice_(src, coord: Coord, *, loc=None, ip=None):
|
||||
else:
|
||||
return ()
|
||||
|
||||
res_type = None
|
||||
if isinstance(src, Tensor):
|
||||
res_type = src.element_type
|
||||
src = src.value
|
||||
coord_val = _pack_coord(coord, loc=loc, ip=ip)
|
||||
return _cute_ir.slice(input=src, coord=coord_val, loc=loc, ip=ip)
|
||||
res = _cute_ir.slice(input=src, coord=coord_val, loc=loc, ip=ip)
|
||||
return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res
|
||||
|
||||
|
||||
@overload
|
||||
@@ -2751,7 +2768,8 @@ def filter_zeros(input, *, target_profile=None, loc=None, ip=None):
|
||||
"""Filter out zeros from a layout or tensor.
|
||||
|
||||
This function removes zero-stride dimensions from a layout or tensor.
|
||||
See Section 3.3 in the CuTe Whitepaper for more details on layout operations.
|
||||
Refer to https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md
|
||||
for more layout algebra operations.
|
||||
|
||||
:param input: The input layout or tensor to filter
|
||||
:type input: Layout or Tensor
|
||||
@@ -2913,7 +2931,8 @@ def size(
|
||||
|
||||
Computes the size (number of elements) in the domain of a layout or tensor.
|
||||
For layouts, this corresponds to the shape of the coordinate space.
|
||||
See Section 3.2 in the CuTe Whitepaper for more details on layout domains.
|
||||
See https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/01_layout.md
|
||||
for more details on layout domains.
|
||||
|
||||
:param a: The input object whose size to compute
|
||||
:type a: IntTuple, Shape, Layout, ComposedLayout or Tensor
|
||||
@@ -3177,7 +3196,7 @@ def make_composed_layout(
|
||||
) -> ComposedLayout:
|
||||
"""Create a composed layout by composing an inner transformation with an outer layout.
|
||||
|
||||
As described in the CuTe whitepaper, a composed layout applies a sequence of transformations
|
||||
A composed layout applies a sequence of transformations
|
||||
to coordinates. The composition is defined as (inner ∘ offset ∘ outer), where the operations
|
||||
are applied from right to left.
|
||||
|
||||
@@ -3416,12 +3435,7 @@ def recast_ptr(
|
||||
|
||||
value_type = ptr.type.value_type if dtype is None else dtype
|
||||
swizzle = swizzle_.type.attribute if swizzle_ is not None else None
|
||||
res_ty = _cute_ir.PtrType.get(
|
||||
value_type,
|
||||
AddressSpace(ptr.type.address_space),
|
||||
ptr.alignment,
|
||||
swizzle,
|
||||
)
|
||||
res_ty = _cute_ir.PtrType.get(value_type, ptr.memspace, ptr.alignment, swizzle)
|
||||
return _cute_ir.recast_iter(res_ty, ptr.value, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@@ -3438,8 +3452,15 @@ def make_ptr(
|
||||
if dtype is None or not isinstance(dtype, NumericMeta):
|
||||
raise TypeError(f"expects dtype to be a type of Numeric, but got {dtype}")
|
||||
|
||||
if not isinstance(mem_space, AddressSpace):
|
||||
raise TypeError(f"expects mem_space to be an AddressSpace, but got {mem_space}")
|
||||
|
||||
if isinstance(value, ir.Value) and llvm.PointerType.isinstance(value.type):
|
||||
value = llvm.ptrtoint(T.i64(), value)
|
||||
|
||||
if not is_integer(value):
|
||||
raise TypeError(f"expects integer value, but got {type(value)}")
|
||||
value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value)
|
||||
|
||||
bytes_per_elt = max(1, dtype.width // 8)
|
||||
if assumed_align is None:
|
||||
@@ -3450,13 +3471,11 @@ def make_ptr(
|
||||
f"{bytes_per_elt=} is not a multiple of {assumed_align=} and vice versa."
|
||||
)
|
||||
|
||||
value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value)
|
||||
aligned_ty = _cute_ir.ConstrainedIntType.get(assumed_align, type(value).width)
|
||||
aligned_intptr = _cute_ir.assume(aligned_ty, value.ir_value(), loc=loc, ip=ip)
|
||||
|
||||
ptr_ty = _cute_ir.PtrType.get(
|
||||
T.i8() if dtype is None else dtype.mlir_type, mem_space, assumed_align
|
||||
)
|
||||
data_ty = T.i8() if dtype is None else dtype.mlir_type
|
||||
ptr_ty = _cute_ir.PtrType.get(data_ty, mem_space, assumed_align)
|
||||
return _cute_ir.inttoptr(ptr_ty, aligned_intptr, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@@ -3582,7 +3601,7 @@ def make_fragment(
|
||||
) -> Tensor:
|
||||
if not issubclass(dtype, Numeric):
|
||||
raise TypeError(f"value_type must be a type of Numeric, but got {type(dtype)}")
|
||||
elem_ty = dtype.mlir_type
|
||||
elem_ty = dtype.mlir_type if dtype is not Boolean else T.i8()
|
||||
|
||||
# Alignment for register memory is useless(?), pick-up large enough number
|
||||
# to allow .128 (> 16B) load store
|
||||
@@ -3691,16 +3710,12 @@ def make_fragment_like(src, dtype=None, *, loc=None, ip=None):
|
||||
)
|
||||
return make_fragment(new_layout, dtype, loc=loc, ip=ip)
|
||||
else:
|
||||
if dtype is None:
|
||||
ty = src.element_type.mlir_type
|
||||
else:
|
||||
ty = dtype.mlir_type
|
||||
dtype = src.element_type if dtype is None else dtype
|
||||
ty = dtype.mlir_type if dtype is not Boolean else T.i8()
|
||||
new_tensor = _cute_ir.make_fragment_like(
|
||||
src.value, elem_type=ty, loc=loc, ip=ip
|
||||
)
|
||||
return _Tensor(
|
||||
new_tensor.value, dtype if dtype is not None else src.element_type
|
||||
)
|
||||
return _Tensor(new_tensor.value, dtype)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"src must be a Layout or ComposedLayout or tensor, got {type(src)}"
|
||||
@@ -3958,11 +3973,14 @@ def logical_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor
|
||||
|
||||
@dsl_user_op
|
||||
def logical_divide(target, tiler: Tiler, *, loc=None, ip=None):
|
||||
res_type = None
|
||||
if isinstance(target, _Tensor):
|
||||
res_type = target.element_type
|
||||
target = target.value
|
||||
if isinstance(tiler, tuple):
|
||||
tiler = _pack_tile(tiler, loc=loc, ip=ip)
|
||||
return _cute_ir.logical_divide(input=target, tiler=tiler, loc=loc, ip=ip)
|
||||
res = _cute_ir.logical_divide(input=target, tiler=tiler, loc=loc, ip=ip)
|
||||
return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res
|
||||
|
||||
|
||||
@overload
|
||||
@@ -3973,11 +3991,14 @@ def zipped_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor:
|
||||
|
||||
@dsl_user_op
|
||||
def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None):
|
||||
res_type = None
|
||||
if isinstance(target, _Tensor):
|
||||
res_type = target.element_type
|
||||
target = target.value
|
||||
if isinstance(tiler, tuple):
|
||||
tiler = _pack_tile(tiler, loc=loc, ip=ip)
|
||||
return _cute_ir.zipped_divide(input=target, tiler=tiler, loc=loc, ip=ip)
|
||||
res = _cute_ir.zipped_divide(input=target, tiler=tiler, loc=loc, ip=ip)
|
||||
return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res
|
||||
|
||||
|
||||
@overload
|
||||
@@ -3988,11 +4009,14 @@ def tiled_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor:
|
||||
|
||||
@dsl_user_op
|
||||
def tiled_divide(target, tiler: Tiler, *, loc=None, ip=None):
|
||||
res_type = None
|
||||
if isinstance(target, _Tensor):
|
||||
res_type = target.element_type
|
||||
target = target.value
|
||||
if isinstance(tiler, tuple):
|
||||
tiler = _pack_tile(tiler, loc=loc, ip=ip)
|
||||
return _cute_ir.tiled_divide(input=target, tiler=tiler, loc=loc, ip=ip)
|
||||
res = _cute_ir.tiled_divide(input=target, tiler=tiler, loc=loc, ip=ip)
|
||||
return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res
|
||||
|
||||
|
||||
@overload
|
||||
@@ -4003,11 +4027,14 @@ def flat_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: .
|
||||
|
||||
@dsl_user_op
|
||||
def flat_divide(target, tiler: Tiler, *, loc=None, ip=None):
|
||||
res_type = None
|
||||
if isinstance(target, _Tensor):
|
||||
res_type = target.element_type
|
||||
target = target.value
|
||||
if isinstance(tiler, tuple):
|
||||
tiler = _pack_tile(tiler, loc=loc, ip=ip)
|
||||
return _cute_ir.flat_divide(input=target, tiler=tiler, loc=loc, ip=ip)
|
||||
res = _cute_ir.flat_divide(input=target, tiler=tiler, loc=loc, ip=ip)
|
||||
return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res
|
||||
|
||||
|
||||
#
|
||||
@@ -4075,14 +4102,22 @@ def tile_to_shape(
|
||||
def local_partition(
|
||||
target: Tensor,
|
||||
tiler: Union[Layout, Shape],
|
||||
index,
|
||||
index: Union[int, Numeric],
|
||||
proj: XTuple = 1,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tensor:
|
||||
if isinstance(index, cutlass_arith.ArithValue):
|
||||
index_val = index
|
||||
else:
|
||||
index_val = index.ir_value()
|
||||
if index_val.type.width > 32:
|
||||
raise NotImplementedError(
|
||||
f"Index value should be 32-bit or smaller integer type, but got {index_val.type}"
|
||||
)
|
||||
return _cute_ir.local_partition(
|
||||
input=target.value, tiler=dice(tiler, proj), index=index, loc=loc, ip=ip
|
||||
input=target.value, tiler=dice(tiler, proj), index=index_val, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@@ -4332,6 +4367,8 @@ class MmaAtom(Atom):
|
||||
def make_fragment_A(self, input, *, loc=None, ip=None):
|
||||
# input could be memref/shape/layout for tmem based fragment
|
||||
if isinstance(input, _Tensor):
|
||||
if self.op is not None:
|
||||
self.op._verify_fragment_A(input, loc=loc, ip=ip)
|
||||
input = input.value
|
||||
if isinstance(input, tuple):
|
||||
input = _pack_shape(input, loc=loc, ip=ip)
|
||||
@@ -4343,9 +4380,12 @@ class MmaAtom(Atom):
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def make_fragment_B(self, input, *, loc=None, ip=None):
|
||||
if isinstance(input, _Tensor):
|
||||
if self.op is not None:
|
||||
self.op._verify_fragment_B(input, loc=loc, ip=ip)
|
||||
input = input.value
|
||||
return _cute_ir.mma_make_fragment(
|
||||
_cute_ir.MmaOperand.B,
|
||||
@@ -5193,7 +5233,7 @@ def copy(
|
||||
src: Tensor,
|
||||
dst: Tensor,
|
||||
*,
|
||||
pred: Tensor = None,
|
||||
pred: Optional[Tensor] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
**kwargs,
|
||||
@@ -5334,7 +5374,7 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
other = as_numeric(other)
|
||||
|
||||
# Promote types
|
||||
lhs, rhs, res_type = _binary_op_type_promote(self, other, True)
|
||||
lhs, rhs, res_type = _binary_op_type_promote(self, other)
|
||||
|
||||
# Promote scalar to vector
|
||||
if not isinstance(rhs, TensorSSA):
|
||||
@@ -5827,6 +5867,28 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
def ir_value(self, *, loc=None, ip=None):
|
||||
return self
|
||||
|
||||
def ir_value_int8(self, *, loc=None, ip=None):
|
||||
"""
|
||||
Returns int8 ir value of Boolean tensor.
|
||||
When we need to store Boolean tensor ssa, use ir_value_int8().
|
||||
|
||||
:param loc: Source location information, defaults to None
|
||||
:type loc: Optional[Location], optional
|
||||
:param ip: Insertion point for MLIR operations, defaults to None
|
||||
:type ip: Optional[InsertionPoint], optional
|
||||
:return: The int8 value of this Boolean
|
||||
:rtype: ir.Value
|
||||
"""
|
||||
assert (
|
||||
self.element_type is Boolean
|
||||
), f"Only boolean type needs to be converted to int8, got {self.element_type}"
|
||||
|
||||
if not hasattr(self, "_value_int8"):
|
||||
self._value_int8 = arith.extsi(
|
||||
T.vector(self.type.shape[0], T.i8()), self, loc=loc, ip=ip
|
||||
)
|
||||
return self._value_int8
|
||||
|
||||
def reduce(self, op, init_val, reduction_profile: Coord, *, loc=None, ip=None):
|
||||
"""
|
||||
Perform reduce on selected modes with given predefined reduction op.
|
||||
|
||||
@@ -84,6 +84,11 @@ class MmaUniversalOp(core.MmaOp):
|
||||
)
|
||||
return MmaUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip))
|
||||
|
||||
def _verify_fragment_A(self, input, *, loc=None, ip=None):
|
||||
pass
|
||||
|
||||
def _verify_fragment_B(self, input, *, loc=None, ip=None):
|
||||
pass
|
||||
|
||||
class MmaUniversalTrait(core.Trait):
|
||||
pass
|
||||
|
||||
@@ -20,7 +20,7 @@ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from ..common import OpError
|
||||
from ...core import MmaOp, Trait, _pack_shape, rank, depth
|
||||
from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor
|
||||
from ...typing import (
|
||||
Shape,
|
||||
Float8E5M2,
|
||||
@@ -34,6 +34,7 @@ from ...typing import (
|
||||
Uint8,
|
||||
Int32,
|
||||
Numeric,
|
||||
AddressSpace,
|
||||
)
|
||||
|
||||
|
||||
@@ -212,6 +213,30 @@ class MmaOp(MmaOp):
|
||||
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
||||
)
|
||||
|
||||
def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
|
||||
if input.memspace == AddressSpace.smem and isinstance(
|
||||
input.layout.type, _cute_ir.ComposedLayoutType
|
||||
):
|
||||
raise OpError(
|
||||
self,
|
||||
f"Expected affine layout for {self._make_trait()}'s operand A, "
|
||||
f"but got composed layout instead: {input.layout}"
|
||||
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
|
||||
)
|
||||
return True
|
||||
|
||||
def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
|
||||
if input.memspace == AddressSpace.smem and isinstance(
|
||||
input.layout.type, _cute_ir.ComposedLayoutType
|
||||
):
|
||||
raise OpError(
|
||||
self,
|
||||
f"Expected affine layout for {self._make_trait()}'s operand B, "
|
||||
f"but got composed layout instead: {input.layout}"
|
||||
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
class MmaTrait(Trait):
|
||||
admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]
|
||||
|
||||
@@ -16,8 +16,8 @@ import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
|
||||
from ..common import OpError
|
||||
from ...core import MmaOp, Trait, _pack_shape
|
||||
from ...typing import Shape, Float16, BFloat16, Float32, Numeric
|
||||
from ...core import MmaOp, Trait, _pack_shape, _Tensor
|
||||
from ...typing import Shape, Float16, BFloat16, Float32, Numeric, AddressSpace
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -73,6 +73,11 @@ class MmaF16BF16Op(MmaOp):
|
||||
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
||||
)
|
||||
|
||||
def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
|
||||
pass
|
||||
|
||||
def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
|
||||
pass
|
||||
|
||||
class MmaF16BF16Trait(Trait):
|
||||
pass
|
||||
|
||||
@@ -20,7 +20,7 @@ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from ..common import OpError
|
||||
from ...core import MmaOp, Trait, _pack_shape, rank, depth
|
||||
from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor
|
||||
from ...typing import (
|
||||
Shape,
|
||||
Float16,
|
||||
@@ -30,6 +30,7 @@ from ...typing import (
|
||||
Float8E5M2,
|
||||
Float8E4M3FN,
|
||||
Numeric,
|
||||
AddressSpace,
|
||||
)
|
||||
|
||||
|
||||
@@ -167,6 +168,30 @@ class MmaOp(MmaOp):
|
||||
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
||||
)
|
||||
|
||||
def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None):
|
||||
if input.memspace == AddressSpace.smem and isinstance(
|
||||
input.layout.type, _cute_ir.ComposedLayoutType
|
||||
):
|
||||
raise OpError(
|
||||
self,
|
||||
f"Expected affine layout for {self._make_trait()}'s operand A, "
|
||||
f"but got composed layout instead: {input.layout}"
|
||||
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
|
||||
)
|
||||
return True
|
||||
|
||||
def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None):
|
||||
if input.memspace == AddressSpace.smem and isinstance(
|
||||
input.layout.type, _cute_ir.ComposedLayoutType
|
||||
):
|
||||
raise OpError(
|
||||
self,
|
||||
f"Expected affine layout for {self._make_trait()}'s operand B, "
|
||||
f"but got composed layout instead: {input.layout}"
|
||||
f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr",
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
class MmaTrait(Trait):
|
||||
admissible_fields = [Field.ACCUMULATE]
|
||||
|
||||
@@ -124,7 +124,7 @@ class _Pointer(Pointer):
|
||||
)
|
||||
|
||||
@property
|
||||
def element_type(self) -> Type[Numeric]:
|
||||
def dtype(self) -> Type[Numeric]:
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
|
||||
@@ -35,20 +35,7 @@ from inspect import isclass
|
||||
|
||||
|
||||
def assert_(cond, msg=None):
|
||||
if isinstance(cond, ir.Value):
|
||||
if ir.VectorType.isinstance(cond.type):
|
||||
assert (
|
||||
cond.type.element_type == T.bool()
|
||||
), f"only expects vector type with boolean elements, but got {cond.type}"
|
||||
cond_val = vector.multi_reduction(
|
||||
vector.CombiningKind.AND, cond, const(True), range(cond.type.rank)
|
||||
)
|
||||
else:
|
||||
cond_val = cond
|
||||
else:
|
||||
cond_val = const(cond, t.Boolean)
|
||||
|
||||
cf.assert_(cond_val, msg if msg else "")
|
||||
cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "")
|
||||
|
||||
|
||||
def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout):
|
||||
|
||||
@@ -56,14 +56,23 @@ XTuple = Union[IntTuple, Shape, Stride, Coord, Tile]
|
||||
Tiler = Union[Shape, Layout, Tile]
|
||||
|
||||
|
||||
class Pointer:
|
||||
class Pointer(ABC):
|
||||
"""
|
||||
Abstract base class for CuTe jit function and runtime _Pointer
|
||||
"""
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
# Doesn't matter just return a value
|
||||
return [self]
|
||||
@property
|
||||
def value_type(self) -> Type[Numeric]:
|
||||
return self.dtype
|
||||
|
||||
@property
|
||||
def dtype(self) -> Type[Numeric]: ...
|
||||
|
||||
def __get_mlir_types__(self) -> List[ir.Type]: ...
|
||||
|
||||
def __extract_mlir_values__(self) -> List[ir.Value]: ...
|
||||
|
||||
def __new_from_mlir_values__(self, values) -> "Pointer": ...
|
||||
|
||||
|
||||
class Tensor(ABC):
|
||||
@@ -144,10 +153,13 @@ class Tensor(ABC):
|
||||
|
||||
def store(self, data: "TensorSSA", *, loc=None, ip=None): ...
|
||||
|
||||
def mark_layout_dynamic(self, leading_dim: int|None = None) -> "Tensor": ...
|
||||
def mark_layout_dynamic(self, leading_dim: int | None = None) -> "Tensor": ...
|
||||
|
||||
def mark_compact_shape_dynamic(
|
||||
self, mode: int, stride_order: tuple[int, ...]|None = None, divisibility: int = 1
|
||||
self,
|
||||
mode: int,
|
||||
stride_order: tuple[int, ...] | None = None,
|
||||
divisibility: int = 1,
|
||||
) -> "Tensor": ...
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -30,6 +30,7 @@ class Agent(enum.Enum):
|
||||
"""
|
||||
Agent indicates what is participating in the pipeline synchronization.
|
||||
"""
|
||||
|
||||
# Arbitrary grouping of N threads
|
||||
Thread = enum.auto()
|
||||
# Same as AsyncThread, but includes all threads in the block
|
||||
@@ -42,6 +43,7 @@ class CooperativeGroup:
|
||||
"""
|
||||
CooperativeGroup contains size and alignment restrictions for an Agent.
|
||||
"""
|
||||
|
||||
def __init__(self, agent: Agent, size: int = 1, alignment: int = 1):
|
||||
if agent is Agent.Thread:
|
||||
assert size > 0
|
||||
@@ -76,6 +78,7 @@ class _PipelineOp(enum.Enum):
|
||||
"""
|
||||
PipelineOp assigns an operation to an agent corresponding to a specific hardware feature.
|
||||
"""
|
||||
|
||||
# async-threads
|
||||
AsyncThread = enum.auto()
|
||||
# Blackwell (SM100a) MMA instruction
|
||||
@@ -140,12 +143,8 @@ class MbarrierArray(SyncObjectArray):
|
||||
"Error: Mbarrier tx count must be greater than 0 for TMA ops."
|
||||
)
|
||||
|
||||
# Using a tensor to store mbarrier i64 ptrs
|
||||
self.mbarrier_array = cute.make_fragment(cute.make_layout(num_stages), Int64)
|
||||
for i in range(num_stages):
|
||||
self.mbarrier_array[i] = _cute_ir.ptrtoint(
|
||||
T.i64(), (self.barrier_storage + i).value
|
||||
)
|
||||
# Store mbarrier base pointer
|
||||
self.mbarrier_base = self.barrier_storage
|
||||
|
||||
# Mbarrier initialization in constructor
|
||||
self.mbarrier_init()
|
||||
@@ -155,10 +154,11 @@ class MbarrierArray(SyncObjectArray):
|
||||
"""
|
||||
Initializes an array of mbarriers using warp 0.
|
||||
"""
|
||||
|
||||
def then_body():
|
||||
for index in range(self.num_stages):
|
||||
cute.arch.mbarrier_init_arrive_cnt(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), self.arrive_count
|
||||
self.get_barrier(index), self.arrive_count
|
||||
)
|
||||
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
@@ -166,7 +166,12 @@ class MbarrierArray(SyncObjectArray):
|
||||
|
||||
if_generate(warp_idx == 0, then_body)
|
||||
|
||||
def arrive(self, index: int, dst: int):
|
||||
def arrive(
|
||||
self,
|
||||
index: int,
|
||||
dst: int,
|
||||
cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None,
|
||||
):
|
||||
"""
|
||||
Select the arrive corresponding to this MbarrierArray's PipelineOp
|
||||
:param index: Index of the mbarrier in the array to arrive on
|
||||
@@ -175,55 +180,53 @@ class MbarrierArray(SyncObjectArray):
|
||||
- For TCGen05Mma, dst serves as a multicast mask (e.g., 0b1011 allows arrive signal to be multicast to CTAs in the cluster with rank = 0, 1, and 3).
|
||||
- For AsyncThread, dst serves as a destination cta rank (e.g., 3 means threads will arrive on the mbarrier with rank = 3 in the cluster).
|
||||
:type dst: int | None
|
||||
:param cta_group: CTA group for TCGen05Mma, defaults to None for other op types
|
||||
:type cta_group: cute.nvgpu.tcgen05.CtaGroup, optional
|
||||
"""
|
||||
if self.op_type is _PipelineOp.AsyncThread:
|
||||
self.arrive_mbarrier(index, dst)
|
||||
elif self.op_type is _PipelineOp.TCGen05Mma:
|
||||
self.arrive_tcgen05mma(index, dst)
|
||||
assert (
|
||||
cta_group is not None
|
||||
), "Error: CTA group must be provided for TCGen05Mma."
|
||||
self.arrive_tcgen05mma(index, dst, cta_group)
|
||||
elif self.op_type in [_PipelineOp.TmaLoad]:
|
||||
self.arrive_and_expect_tx(index, self.tx_count)
|
||||
else:
|
||||
print(_get_pipeline_op(self.op_type))
|
||||
assert False, "Error: MbarrierArray is not supported for this PipelineOp."
|
||||
assert False, f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}."
|
||||
|
||||
def arrive_mbarrier(self, index: int, dst_rank: int):
|
||||
if dst_rank is None:
|
||||
cute.arch.mbarrier_arrive(_mbarrier_i64_to_ptr(self.mbarrier_array[index]))
|
||||
cute.arch.mbarrier_arrive(self.get_barrier(index))
|
||||
else:
|
||||
cute.arch.mbarrier_arrive(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), dst_rank
|
||||
)
|
||||
cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank)
|
||||
|
||||
def arrive_tcgen05mma(self, index: int, mask: int):
|
||||
def arrive_tcgen05mma(
|
||||
self, index: int, mask: int, cta_group: cute.nvgpu.tcgen05.CtaGroup
|
||||
):
|
||||
if mask is None:
|
||||
with cute.arch.elect_one():
|
||||
cute.nvgpu.tcgen05.commit(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index])
|
||||
)
|
||||
cute.nvgpu.tcgen05.commit(self.get_barrier(index))
|
||||
else:
|
||||
with cute.arch.elect_one():
|
||||
cute.nvgpu.tcgen05.commit(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index]),
|
||||
self.get_barrier(index),
|
||||
mask,
|
||||
cute.nvgpu.tcgen05.CtaGroup.TWO,
|
||||
cta_group,
|
||||
)
|
||||
|
||||
def arrive_and_expect_tx(self, index: int, tx_count: int):
|
||||
with cute.arch.elect_one():
|
||||
cute.arch.mbarrier_init_tx_bytes(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), tx_count
|
||||
)
|
||||
cute.arch.mbarrier_init_tx_bytes(self.get_barrier(index), tx_count)
|
||||
|
||||
def try_wait(self, index: int, phase: int):
|
||||
return cute.arch.mbarrier_try_wait(
|
||||
_mbarrier_i64_to_ptr(self.mbarrier_array[index]), phase
|
||||
)
|
||||
return cute.arch.mbarrier_try_wait(self.get_barrier(index), phase)
|
||||
|
||||
def wait(self, index: int, phase: int):
|
||||
cute.arch.mbarrier_wait(_mbarrier_i64_to_ptr(self.mbarrier_array[index]), phase)
|
||||
cute.arch.mbarrier_wait(self.get_barrier(index), phase)
|
||||
|
||||
def get_barrier(self, index: int) -> cute.Pointer:
|
||||
return _mbarrier_i64_to_ptr(self.mbarrier_array[index])
|
||||
return self.mbarrier_base + index
|
||||
|
||||
|
||||
class TmaStoreFence(SyncObjectArray):
|
||||
@@ -390,6 +393,7 @@ class PipelineAsync:
|
||||
PipelineAsync is a generic pipeline class where both the producer and consumer are
|
||||
AsyncThreads. It also serves as a base class for specialized pipeline classes.
|
||||
"""
|
||||
|
||||
sync_object_array_full: SyncObjectArray
|
||||
sync_object_array_empty: SyncObjectArray
|
||||
num_stages: Int32
|
||||
@@ -522,6 +526,7 @@ class PipelineTmaAsync(PipelineAsync):
|
||||
"""
|
||||
PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops).
|
||||
"""
|
||||
|
||||
is_signalling_thread: bool
|
||||
|
||||
@staticmethod
|
||||
@@ -628,7 +633,6 @@ class PipelineTmaAsync(PipelineAsync):
|
||||
)
|
||||
self.sync_object_array_full.arrive(state.index, self.producer_mask)
|
||||
|
||||
|
||||
def producer_commit(self, state: PipelineState):
|
||||
"""
|
||||
TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA.
|
||||
@@ -646,12 +650,15 @@ class PipelineTmaAsync(PipelineAsync):
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineTmaUmma(PipelineAsync):
|
||||
"""
|
||||
PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops).
|
||||
"""
|
||||
|
||||
is_leader_cta: bool
|
||||
cta_group: cute.nvgpu.tcgen05.CtaGroup
|
||||
|
||||
@staticmethod
|
||||
def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout):
|
||||
@@ -748,6 +755,12 @@ class PipelineTmaUmma(PipelineAsync):
|
||||
producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
|
||||
is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
|
||||
|
||||
cta_group = (
|
||||
cute.nvgpu.tcgen05.CtaGroup.ONE
|
||||
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
|
||||
else cute.nvgpu.tcgen05.CtaGroup.TWO
|
||||
)
|
||||
|
||||
consumer_mask = producer_mask
|
||||
|
||||
pipeline_init_wait(cta_layout_vmnk)
|
||||
@@ -759,6 +772,15 @@ class PipelineTmaUmma(PipelineAsync):
|
||||
producer_mask,
|
||||
consumer_mask,
|
||||
is_leader_cta,
|
||||
cta_group,
|
||||
)
|
||||
|
||||
def consumer_release(self, state: PipelineState):
|
||||
"""
|
||||
UMMA consumer release buffer empty, cta_group needs to be provided.
|
||||
"""
|
||||
self.sync_object_array_empty.arrive(
|
||||
state.index, self.consumer_mask, self.cta_group
|
||||
)
|
||||
|
||||
def producer_acquire(
|
||||
@@ -789,6 +811,8 @@ class PipelineUmmaAsync(PipelineAsync):
|
||||
PipelineTmaUmma is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines).
|
||||
"""
|
||||
|
||||
cta_group: cute.nvgpu.tcgen05.CtaGroup
|
||||
|
||||
@staticmethod
|
||||
def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout):
|
||||
"""
|
||||
@@ -858,6 +882,12 @@ class PipelineUmmaAsync(PipelineAsync):
|
||||
else:
|
||||
consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank()
|
||||
|
||||
cta_group = (
|
||||
cute.nvgpu.tcgen05.CtaGroup.ONE
|
||||
if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
|
||||
else cute.nvgpu.tcgen05.CtaGroup.TWO
|
||||
)
|
||||
|
||||
pipeline_init_wait(cta_layout_vmnk)
|
||||
|
||||
return PipelineUmmaAsync(
|
||||
@@ -866,6 +896,15 @@ class PipelineUmmaAsync(PipelineAsync):
|
||||
num_stages,
|
||||
producer_mask,
|
||||
consumer_mask,
|
||||
cta_group,
|
||||
)
|
||||
|
||||
def producer_commit(self, state: PipelineState):
|
||||
"""
|
||||
UMMA producer commit buffer full, cta_group needs to be provided.
|
||||
"""
|
||||
self.sync_object_array_full.arrive(
|
||||
state.index, self.producer_mask, self.cta_group
|
||||
)
|
||||
|
||||
def producer_tail(self, state: PipelineState):
|
||||
|
||||
@@ -185,7 +185,7 @@ class SmemAllocator:
|
||||
and isinstance(layout.inner, cute.Swizzle)
|
||||
) and (swizzle is not None):
|
||||
raise TypeError(
|
||||
f"iterator swizzle with swizzle layout is currently not supported"
|
||||
f"Invalid tensor type: cannot be both iterator swizzle (PDSL) and swizzle layout(PISL) at the same time."
|
||||
)
|
||||
|
||||
if isinstance(layout, int):
|
||||
|
||||
@@ -527,13 +527,13 @@ def pack_from_irvalue(
|
||||
"""
|
||||
Packs MLIR values into a list of mixed values.
|
||||
"""
|
||||
log().info("===--- Values Pack (%d)", len(ir_values))
|
||||
log().debug("===--- Values Pack (%d)", len(ir_values))
|
||||
for idx, packed in enumerate(ir_values):
|
||||
log().info("[%d]: will-packed: %s", idx, ir_values)
|
||||
log().debug("[%d]: will-packed: %s", idx, ir_values)
|
||||
for idx, unpacked in indices.items():
|
||||
log().info("[%d]: indices: %s", idx, unpacked)
|
||||
log().debug("[%d]: indices: %s", idx, unpacked)
|
||||
for idx, c in enumerate(class_types):
|
||||
log().info("[%d]: obj-types: %s", idx, type(c))
|
||||
log().debug("[%d]: obj-types: %s", idx, type(c))
|
||||
|
||||
mixed_values = [None] * len(indices)
|
||||
for idx, (start, length) in sorted(indices.items()):
|
||||
@@ -552,10 +552,10 @@ def pack_from_irvalue(
|
||||
except DSLRuntimeError as e:
|
||||
mixed_values[idx] = chunk[0]
|
||||
|
||||
log().info("------------------ ")
|
||||
log().debug("------------------ ")
|
||||
for idx, packed in enumerate(mixed_values):
|
||||
log().info("[%d]: packed: %s", idx, packed)
|
||||
log().info("------------------ ")
|
||||
log().debug("[%d]: packed: %s", idx, packed)
|
||||
log().debug("------------------ ")
|
||||
return mixed_values
|
||||
|
||||
|
||||
@@ -571,9 +571,9 @@ def unpack_to_irvalue(
|
||||
class_types = []
|
||||
current_offset = 0
|
||||
|
||||
log().info("===--- Values UNPack (%d)", len(mixed_values))
|
||||
log().debug("===--- Values UNPack (%d)", len(mixed_values))
|
||||
for idx, packed in enumerate(mixed_values):
|
||||
log().info("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
|
||||
log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
|
||||
for idx, item in enumerate(mixed_values):
|
||||
class_types.append(item)
|
||||
try:
|
||||
@@ -612,16 +612,16 @@ def unpack_to_irvalue(
|
||||
),
|
||||
) from e
|
||||
|
||||
log().info("------------------ ")
|
||||
log().debug("------------------ ")
|
||||
for idx, unpacked in enumerate(unpacked_values):
|
||||
log().info("[%d]: unpacked values: %s", idx, unpacked)
|
||||
log().debug("[%d]: unpacked values: %s", idx, unpacked)
|
||||
for idx, unpacked in enumerate(ir_values):
|
||||
log().info("[%d]: unpacked ir_values: %s", idx, unpacked)
|
||||
log().debug("[%d]: unpacked ir_values: %s", idx, unpacked)
|
||||
for idx, unpacked in indices.items():
|
||||
log().info("[%d]: indices: %s", idx, unpacked)
|
||||
log().debug("[%d]: indices: %s", idx, unpacked)
|
||||
for idx, unpacked in enumerate(class_types):
|
||||
log().info("[%d]: initial-class-types: %s", idx, unpacked)
|
||||
log().info("------------------ ")
|
||||
log().debug("[%d]: initial-class-types: %s", idx, unpacked)
|
||||
log().debug("------------------ ")
|
||||
|
||||
return ir_values, unpacked_values, indices, class_types
|
||||
|
||||
@@ -1302,7 +1302,6 @@ class WhileLoopContext:
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.ipoint_op.__exit__(exc_type, exc_value, traceback)
|
||||
return True
|
||||
|
||||
@property
|
||||
def results(self):
|
||||
|
||||
@@ -331,11 +331,7 @@ def _if_execute_dynamic(
|
||||
# Assume final result types match the dynamic yields
|
||||
result_types = [arg.type for arg in dyn_yield_ops]
|
||||
|
||||
pred_ = t.as_numeric(pred)
|
||||
|
||||
if not isinstance(pred_, Boolean):
|
||||
# Convert to Boolean through comparison
|
||||
pred_ = pred_ == True
|
||||
pred_ = Boolean(pred)
|
||||
|
||||
try:
|
||||
if_op = scf.IfOp(
|
||||
|
||||
Reference in New Issue
Block a user