v4.2 tag release. (#2638)

This commit is contained in:
Junkai-Wu
2025-09-16 00:21:53 +08:00
committed by GitHub
parent 56f0718a97
commit 6a35b4d22f
161 changed files with 14056 additions and 3793 deletions

View File

@@ -207,7 +207,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
if dst_width == src_width:
return a
elif src_signed and not dst_signed:
elif src_signed != False and not dst_signed:
# Signed -> Unsigned
if dst_width > src_width:
return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
@@ -216,7 +216,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
elif src_signed == dst_signed:
# Same signedness
if dst_width > src_width:
if src_signed and src_width > 1:
if src_signed != False and src_width > 1:
return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip)
else:
return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
@@ -479,7 +479,7 @@ class ArithValue(ir.Value):
if self.is_float:
q = arith.divf(self, other, loc=loc, ip=ip)
return math.floor(q, loc=loc, ip=ip)
elif self.signed:
elif self.signed != False:
return arith.floordivsi(self, other, loc=loc, ip=ip)
else:
return arith.divui(self, other, loc=loc, ip=ip)
@@ -489,7 +489,7 @@ class ArithValue(ir.Value):
def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.remf(self, other, loc=loc, ip=ip)
elif self.signed:
elif self.signed != False:
return arith.remsi(self, other, loc=loc, ip=ip)
else:
return arith.remui(self, other, loc=loc, ip=ip)
@@ -524,7 +524,7 @@ class ArithValue(ir.Value):
def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip)
elif self.signed:
elif self.signed != False:
return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip)
@@ -534,7 +534,7 @@ class ArithValue(ir.Value):
def __le__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip)
elif self.signed:
elif self.signed != False:
return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip)
@@ -561,7 +561,7 @@ class ArithValue(ir.Value):
def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip)
elif self.signed:
elif self.signed != False:
return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip)
@@ -571,7 +571,7 @@ class ArithValue(ir.Value):
def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.is_float:
return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip)
elif self.signed:
elif self.signed != False:
return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip)
else:
return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip)
@@ -599,7 +599,7 @@ class ArithValue(ir.Value):
@_dispatch_to_rhs_r_op
@_binary_op
def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
if self.signed:
if self.signed != False:
return arith.shrsi(self, other, loc=loc, ip=ip)
else:
return arith.shrui(self, other, loc=loc, ip=ip)
@@ -633,7 +633,7 @@ class ArithValue(ir.Value):
return super().__hash__()
def __str__(self):
return super().__str__().replace(ir.Value.__name__, ArithValue.__name__)
return "?"
def __repr__(self):
return self.__str__()
@@ -657,7 +657,7 @@ def _min(lhs, rhs, *, loc=None, ip=None):
rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)
if arith._is_integer_like_type(lhs.type):
if lhs.signed:
if lhs.signed != False:
return arith.minsi(lhs, rhs, loc=loc, ip=ip)
else:
return arith.minui(lhs, rhs, loc=loc, ip=ip)
@@ -683,7 +683,7 @@ def _max(lhs, rhs, *, loc=None, ip=None):
rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)
if arith._is_integer_like_type(lhs.type):
if lhs.signed:
if lhs.signed != False:
return arith.maxsi(lhs, rhs, loc=loc, ip=ip)
else:
return arith.maxui(lhs, rhs, loc=loc, ip=ip)

View File

@@ -17,12 +17,16 @@ The preprocessor read through python's ast and changes the input code.
from typing import Callable, Iterator, Optional, overload
from typing_extensions import deprecated
import warnings
import inspect
from types import BuiltinFunctionType
from functools import lru_cache
from .utils.logger import log
from .common import *
from ._mlir_helpers.arith import ArithValue
class Executor:
"""
The Executor class handles dynamic and compile-time (constexpr) execution
@@ -45,9 +49,11 @@ class Executor:
self._compare_executor = None
self._any_executor = None
self._all_executor = None
self._builtin_redirector = None
def set_functions(
self,
*,
is_dynamic_expression: Callable,
loop_execute_range_dynamic: Callable,
if_dynamic: Callable,
@@ -55,6 +61,7 @@ class Executor:
compare_executor: Callable,
any_executor: Callable = None,
all_executor: Callable = None,
builtin_redirector: Callable = None,
):
self._is_dynamic_expression = is_dynamic_expression
self._loop_execute_range_dynamic = loop_execute_range_dynamic
@@ -63,6 +70,7 @@ class Executor:
self._compare_executor = compare_executor
self._any_executor = any_executor
self._all_executor = all_executor
self._builtin_redirector = builtin_redirector
@staticmethod
def convert_to_list(x):
@@ -90,42 +98,18 @@ class Executor:
return res[0]
return res
@staticmethod
def for_constexpr(
func: Callable,
start: int,
stop: int,
step: int,
used_args: list,
iter_args: list,
):
# Execute a compile-time (constexpr) loop entirely in Python: `func` is
# called once per index in range(start, stop, step), threading the
# loop-carried values (`iter_args` / `loop_results`) through each call.
log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
# Seed the loop-carried state with the initial iter_args.
loop_results = iter_args
log().debug("iter_args [%s]", iter_args)
for i in range(start, stop, step):
log().debug("i [%s] iter_args [%s]", i, iter_args)
# The body receives the index, the read-only used_args, then the
# current loop-carried values, and returns the next carried values.
loop_results = func(i, *used_args, *loop_results)
log().debug("loop_results [%s]", loop_results)
# Normalize the body's return to a list so it can be re-splatted on
# the next iteration: None -> [], scalar -> [scalar].
# NOTE(review): indentation was lost in this view — presumably this
# normalization sits INSIDE the for-loop (otherwise *loop_results
# could fail mid-loop); confirm against the original file.
if loop_results is None:
loop_results = []
if not isinstance(loop_results, list):
loop_results = [loop_results]
log().debug("done loop_results [%s]", loop_results)
# Collapse [x] -> x (and [] -> None) per converge_ret_val's convention.
return Executor.converge_ret_val(loop_results)
def for_execute(
self,
func,
start,
stop,
step,
used_args=[],
iter_args=[],
iter_arg_names=[],
write_args=[],
full_write_args_count=0,
write_args_names=[],
unroll=-1,
unroll_full=False,
pipelining=None,
prefetch_stages=None,
):
assert (
self._loop_execute_range_dynamic
@@ -137,12 +121,12 @@ class Executor:
start,
stop,
step,
used_args,
iter_args,
iter_arg_names,
write_args,
full_write_args_count,
write_args_names,
unroll,
unroll_full,
pipelining,
prefetch_stages,
)
def if_execute(
@@ -150,15 +134,20 @@ class Executor:
pred,
then_block: Callable,
else_block: Optional[Callable] = None,
used_args=[],
yield_args=[],
yield_arg_names=[],
write_args=[],
full_write_args_count=0,
write_args_names=[],
):
assert self._if_dynamic, "Functions must be set before execution."
# MLIR generation
return self._if_dynamic(
pred, then_block, else_block, used_args, yield_args, yield_arg_names
pred,
then_block,
else_block,
write_args,
full_write_args_count,
write_args_names,
)
def while_execute(
@@ -166,9 +155,9 @@ class Executor:
pred,
while_before_block: Callable,
while_after_block: Callable,
used_args=[],
yield_args=[],
yield_arg_names=[],
write_args=[],
full_write_args_count=0,
write_args_names=[],
):
assert self._while_dynamic, "Functions must be set before execution."
@@ -176,9 +165,9 @@ class Executor:
return self._while_dynamic(
while_before_block,
while_after_block,
used_args,
yield_args,
yield_arg_names,
write_args,
full_write_args_count,
write_args_names,
)
@@ -194,23 +183,24 @@ def loop_selector(
stop,
step,
*,
used_args=[],
iter_args=[],
iter_arg_names=[],
write_args=[],
full_write_args_count=0,
write_args_names=[],
unroll=-1,
unroll_full=False,
pipelining=None,
prefetch_stages=None,
):
log().debug(
"start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]",
"start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s]",
start,
stop,
step,
used_args,
iter_args,
write_args,
full_write_args_count,
write_args_names,
unroll,
unroll_full,
pipelining,
prefetch_stages,
)
from .typing import Integer, Numeric
@@ -230,19 +220,19 @@ def loop_selector(
start,
stop,
step,
used_args,
iter_args,
iter_arg_names,
write_args,
full_write_args_count,
write_args_names,
unroll,
unroll_full,
pipelining,
prefetch_stages,
)
return ir_loop
def if_selector(pred, used_args=[], yield_args=[]):
log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
def if_selector(pred, write_args=[]):
log().debug("pred [%s] write_args [%s]", pred, write_args)
# Handle Numeric types here?
from .typing import Numeric
@@ -251,14 +241,14 @@ def if_selector(pred, used_args=[], yield_args=[]):
pred = pred.value
def ir_loop(func):
return func(pred, *used_args, *yield_args)
return func(pred, *write_args)
return ir_loop
def while_selector(pred, used_args=[], yield_args=[]):
def while_selector(pred, write_args=[]):
def ir_while_loop(func):
return func(pred, *used_args, *yield_args)
return func(pred, *write_args)
return ir_while_loop
@@ -267,17 +257,17 @@ def while_executor(
pred,
while_before_block: Callable,
while_after_block: Callable,
used_args=[],
yield_args=[],
yield_arg_names=[],
write_args=[],
full_write_args_count=0,
write_args_names=[],
):
return executor.while_execute(
pred,
while_before_block,
while_after_block,
used_args,
yield_args,
yield_arg_names,
write_args,
full_write_args_count,
write_args_names,
)
@@ -285,12 +275,17 @@ def if_executor(
pred,
then_block: Callable,
else_block: Optional[Callable] = None,
used_args=[],
yield_args=[],
yield_arg_names=[],
write_args=[],
full_write_args_count=0,
write_args_names=[],
):
return executor.if_execute(
pred, then_block, else_block, used_args, yield_args, yield_arg_names
pred,
then_block,
else_block,
write_args,
full_write_args_count,
write_args_names,
)
@@ -313,14 +308,17 @@ class range:
- unroll: Number of iterations to unroll (0 or 1 = no unrolling)
- unroll_full: Whether to fully unroll the loop
- pipelining: Compiler generated pipeline configuration
- prefetch_stages: Number of prefetch stages to generate
"""
@overload
def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None):
def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None):
pass
@overload
def __new__(cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None):
def __new__(
cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None
):
pass
def __new__(cls, *args, **kwargs):
@@ -340,6 +338,7 @@ def range_dynamic(*args, **kwargs):
def range_constexpr(*args):
raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.")
# =============================================================================
# If expressions
# =============================================================================
@@ -405,7 +404,7 @@ def assert_executor(test, msg=None):
else:
raise DSLRuntimeError(
"Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
suggestion = "Please replace with runtime assert."
suggestion="Please replace with runtime assert.",
)
@@ -413,10 +412,11 @@ def bool_cast(value):
if executor._is_dynamic_expression(value):
raise DSLRuntimeError(
"Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
suggestion = "Please explicitly convert to boolean with expressions like comparision."
suggestion="Please explicitly convert to boolean with expressions like comparision.",
)
return bool(value)
def compare_executor(left, comparators, ops):
"""
Executes comparison operations with a left operand and a list of comparators.
@@ -470,6 +470,19 @@ def all_executor(iterable):
# =============================================================================
# Control flow checks
# =============================================================================
class DSLOptimizationWarning(Warning):
    """
    This warning is used to warn the user about the optimization related issues in DSL.

    Instances carry the human-readable message both as ``self.message`` and in
    ``self.args`` so the warning behaves like a standard ``Warning`` subclass.
    """

    def __init__(self, message):
        # Keep the attribute for callers that read `.message` directly.
        self.message = message
        # Fix: forward the message to Warning.__init__ so that `args` is
        # populated — the previous no-arg call left `args == ()`, which
        # breaks pickling/repr conventions for exception subclasses.
        super().__init__(message)

    def __str__(self):
        # Render only the message (identical to the previous behavior).
        return self.message
def range_value_check(*args):
"""
Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
@@ -495,7 +508,7 @@ def range_value_check(*args):
if range_length >= 64:
warnings.warn(
f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
category=UserWarning,
category=DSLOptimizationWarning,
stacklevel=2,
)
@@ -519,7 +532,50 @@ def range_perf_warning(filename, lineno, *args):
"This loop is no longer unrolled and may cause performance regression. "
"Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
),
category=UserWarning,
category=DSLOptimizationWarning,
filename=filename,
lineno=lineno,
)
@lru_cache(maxsize=1)
def _get_self_module():
    """
    Return the module object that defines this helper.

    The result is cached (maxsize=1) since the owning module can never
    change after import.
    """
    owning_module = inspect.getmodule(_get_self_module)
    return owning_module
def cf_symbol_check(symbol):
"""
Check if the symbol is control flow symbol from current module.

Raises DSLRuntimeError when `symbol` does not come from this DSL module,
i.e. when the user has shadowed/overridden a DSL control-flow name.
"""
failed = False
name = symbol.__name__
# NOTE(review): `name` is assigned (and re-assigned to "range" below) but
# never read — the error message uses `symbol.__name__` instead. Possibly
# the message was meant to use `name`; confirm intent before changing.
self_module = _get_self_module()
if inspect.ismodule(symbol):
# A module was passed where a control-flow symbol was expected
# (presumably the user referenced the package's `range` module path).
name = "range"
# The DSL module's dotted name must start with the given module's name
# for the symbol to be considered "ours".
if not self_module.__name__.startswith(symbol.__name__):
failed = True
else:
# For functions/classes: they must be defined in this very module.
owning_module = inspect.getmodule(symbol)
if owning_module != self_module:
failed = True
if failed:
raise DSLRuntimeError(
f"Incorrect {symbol.__name__} is used.",
suggestion=f"Please avoid overriding `{symbol.__name__}` from DSL package.",
)
def redirect_builtin_function(fcn):
    """
    Redirect a call to a Python built-in function to its DSL replacement.

    Non-built-ins are returned unchanged; built-ins are passed through the
    executor's registered redirector when one is installed.
    """
    # Guard clause: anything that is not a built-in is left untouched.
    if not isinstance(fcn, BuiltinFunctionType):
        return fcn
    redirector = executor._builtin_redirector
    # Without a registered redirector the built-in is kept as-is.
    return redirector(fcn) if redirector else fcn

File diff suppressed because it is too large Load Diff

View File

@@ -139,8 +139,7 @@ def dump_cache_to_path(
dsl_name, jit_cache, cache_limit, path=default_generated_ir_path
):
log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
if not os.path.exists(path):
os.makedirs(path)
os.makedirs(path, exist_ok=True)
original_path = os.getcwd()
try:
os.chdir(path)

View File

@@ -205,6 +205,8 @@ class CompileOptions:
self._parser.add_argument(
"--enable-device-assertions", action="store_true", default=False
)
self._parser.add_argument("--link-libraries", type=str, default="")
try:
self._options = self._parser.parse_args(options.split())
except SystemExit as e:

View File

@@ -32,13 +32,14 @@ import hashlib
from functools import lru_cache, wraps
from collections import namedtuple
from abc import ABC, abstractmethod
from typing import Any, Union, Tuple, get_origin, get_args
from types import FunctionType
from typing import Any, Union, Tuple, get_origin, get_args, List
from types import FunctionType, SimpleNamespace
import warnings
from . import typing as t
from .env_manager import EnvironmentVarManager
from .compiler import CompileOptions
from .ast_helpers import DSLOptimizationWarning
# =============================================================================
# CUDA Python
@@ -56,7 +57,7 @@ from .utils.timer import timer
from .utils.logger import setup_log, log
from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe
from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry
from .runtime.tensor_descriptor import TensorDescriptor
from .ast_preprocessor import DSLPreprocessor
from .common import *
from .typing import (
@@ -73,12 +74,6 @@ from .._mlir import runtime as rt
from .._mlir.extras import types as T
from .._mlir.dialects import arith, math, func
# =============================================================================
# cutlass.dlpack_runtime
# =============================================================================
from .runtime.dlpack_runtime import dlpack_to_tensor_desc, mark_layout_dynamic
# =============================================================================
# Global Variables
# =============================================================================
@@ -177,6 +172,7 @@ def is_dynamic_expression(value):
return True
return False
def extract_mlir_values(obj):
"""
Given the `obj`, recursively go through it to extract all contained IR values as list of MLIR values
@@ -186,6 +182,10 @@ def extract_mlir_values(obj):
res = obj.__extract_mlir_values__()
elif isinstance(obj, (tuple, list)):
res = sum((extract_mlir_values(x) for x in obj), [])
elif isinstance(obj, SimpleNamespace):
res = []
for k, v in obj.__dict__.items():
res.extend(extract_mlir_values(v))
# Can't call is_dynamic_expression as _is_dynamic_expression depends on extract_mlir_values
elif isinstance(obj, set):
raise DSLRuntimeError(
@@ -215,6 +215,13 @@ def new_from_mlir_values(obj, values):
values = values[n_items:]
obj_ty = type(obj)
return obj_ty(res)
elif isinstance(obj, SimpleNamespace):
res = SimpleNamespace()
for k, v in obj.__dict__.items():
n_items = len(get_mlir_types(v))
res.__dict__[k] = new_from_mlir_values(v, values[:n_items])
values = values[n_items:]
return res
elif isinstance(obj, set):
raise DSLRuntimeError(
"Sets are not supported in new_from_mlir_values to ensure order preservation",
@@ -249,8 +256,6 @@ class DSLCallable:
Methods:
__call__(*args, **kwargs): Calls the wrapped function and clears it.
get_arg_spec(): Returns the argument specification of the function.
get_signature(): Returns the signature of the function.
"""
def __init__(self, func):
@@ -266,23 +271,23 @@ class DSLCallable:
assert self.func is not None, "DSLCallable is already called"
return self.func
@property
def __signature__(self):
return inspect.signature(self.__func__)
@property
def __name__(self):
return self.__func__.__name__
def get_arg_spec(self):
return inspect.getfullargspec(self.__func__)
def get_signature(self):
return inspect.signature(self.__func__)
class BaseDSL:
gpu_module = None
def __init__(
self,
*,
name: str,
dsl_package_name: List[str],
compiler_provider: Any,
pass_sm_arch_name: str,
device_compilation_only=False,
@@ -293,6 +298,7 @@ class BaseDSL:
Parameters:
- name (str): Name of DSL, used for environment variables and logging.
- package_name (str): Name of the package, used for the preprocessor.
- compiler_provider (MLIR dialect): Provider for compiler.
- pass_sm_arch_name (str): The keyword name of the SM.
- device_compilation_only (bool) : Only device code, and call it via cuda driver
@@ -330,6 +336,9 @@ class BaseDSL:
self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}"
# set warning
if not self.envar.enable_optimization_warnings:
# By default, optimization warnings are disabled
warnings.filterwarnings("ignore", category=DSLOptimizationWarning)
if self.envar.warnings_as_errors:
warnings.filterwarnings("error")
if self.envar.warnings_ignore:
@@ -355,7 +364,7 @@ class BaseDSL:
self.compile_options = CompileOptions()
if preprocess:
self.preprocessor = DSLPreprocessor()
self.preprocessor = DSLPreprocessor(dsl_package_name)
log().info(f"Initializing {name} DSL")
log().debug(f"Logger initialized for {self.name}")
@@ -656,7 +665,7 @@ class BaseDSL:
return ir_args, ir_kwargs
@abstractmethod
def _generate_mlir_type_for_tensor_descriptor(self, tensor: TensorDescriptor):
def _generate_mlir_type_for_tensor_descriptor(self, tensor):
"""
Generate MLIR type for the tensor descriptor.
"""
@@ -671,13 +680,6 @@ class BaseDSL:
"""
pass
@abstractmethod
def _get_module_globals(self):
"""
Get the module's globals.
"""
pass
def _get_globals(self):
"""
Combines global and local variables from the current context and the
@@ -690,43 +692,21 @@ class BaseDSL:
AST preprocessor generates a new python code, so the resulting globals
dictionary is used to execute the python code.
"""
all_globals = self._get_module_globals().copy()
all_globals = {}
if self.frame:
all_globals.update(self.frame.f_globals)
all_globals.update(self.frame.f_locals)
return all_globals
@abstractmethod
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
return isinstance(
maybe_tensor_descriptor, TensorDescriptor
) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
pass
@abstractmethod
def _handle_tensor_descriptor(
self, maybe_tensor, arg_name: str, need_gpu_memory: bool
) -> TensorDescriptor:
if self._is_tensor_descriptor(maybe_tensor):
tensor = (
maybe_tensor
if isinstance(maybe_tensor, TensorDescriptor)
else TensorDescriptor(maybe_tensor)
)
if need_gpu_memory and not tensor.is_in_device:
log().info(
"FAIL name=[%s] tensor=[%s] in_gpu=[%s]",
arg_name,
tensor,
tensor.is_in_device,
)
raise DSLRuntimeError(
f'Tensor "{arg_name}" is tensor "{tensor}" '
"is not in the GPU memory. "
)
return tensor
raise DSLRuntimeError(
f"Argument {arg_name} could not be transformed into a TensorDescriptor."
)
) -> Any:
pass
def _validate_arg(self, arg, arg_index, arg_name, arg_spec):
"""
@@ -882,10 +862,11 @@ class BaseDSL:
cluster: list = None
grid: list = field(default_factory=lambda: [1, 1, 1])
block: list = field(default_factory=lambda: [1, 1, 1])
smem: int = 0
smem: int = None
async_deps: list = field(default_factory=list)
has_cluster: bool = False
min_blocks_per_mp: int = 0
auto_smem: bool = False
def __post_init__(self):
if len(self.grid) != 3:
@@ -893,6 +874,10 @@ class BaseDSL:
if len(self.block) != 3:
raise DSLRuntimeError(f"Expect 3d block!")
if self.smem is None:
self.smem = 0
self.auto_smem = True
self.has_cluster = self.cluster is not None
if self.cluster is None:
self.cluster = [None, None, None]
@@ -1116,8 +1101,6 @@ class BaseDSL:
try:
result = funcBody(*ir_args, **ir_kwargs)
func.ReturnOp([])
except DSLAstPreprocessorError as pp_error:
raise pp_error
except NameError as name_error:
raise DSLRuntimeError(
f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥",
@@ -1127,11 +1110,6 @@ class BaseDSL:
except DSLRuntimeError as dsl_error:
# Throw it's already a DSL error
raise dsl_error
except Exception as general_e:
# Transform internal error to a DSL error
raise DSLRuntimeError(
f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥"
) from general_e
return module, result
# Build IR module
@@ -1328,10 +1306,8 @@ class BaseDSL:
raise DSLRuntimeError("Function body is not set.")
# Pass the actual function object to inspect.signature to get the signature.
if isinstance(self.funcBody, DSLCallable):
sig = self.funcBody.get_signature()
else:
sig = inspect.signature(self.funcBody)
sig = inspect.signature(self.funcBody)
function_name = self.funcBody.__name__
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
@@ -1382,10 +1358,7 @@ class BaseDSL:
# Check the number of arguments
sig = self._check_arg_count(*args, **kwargs)
if isinstance(funcBody, DSLCallable):
args_spec = funcBody.get_arg_spec()
else:
args_spec = inspect.getfullargspec(funcBody)
args_spec = inspect.getfullargspec(funcBody)
# Canonicalize the input arguments
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
@@ -1447,7 +1420,7 @@ class BaseDSL:
return cuda_helpers.stream_create()
def _execute_cuda(
self, fname_cubin, kernel_name, grid_size, block_size, stream=None
self, fname_cubin, kernel_name, grid_size, block_size, smem_size, stream=None
):
"""
Executes a specified CUDA kernel from a cubin file, handling module loading,
@@ -1471,7 +1444,7 @@ class BaseDSL:
grid_size,
block_size,
stream,
smem_size=16000,
smem_size=smem_size,
kernel_args=self.exe_args,
)
@@ -1480,7 +1453,13 @@ class BaseDSL:
cuda_helpers.stream_sync(stream)
def _execute_by_cuda_driver(
self, kernel_generator, generate_cubin, grid_size, block_size, stream=None
self,
kernel_generator,
generate_cubin,
grid_size,
block_size,
smem_size,
stream=None,
):
"""
This function builds IR and execute the module using cuda driver.
@@ -1511,10 +1490,9 @@ class BaseDSL:
fname_cubin = generate_cubin(module, kernel_name)
# Execute a cuda kernel from cubin
if block_size is None:
# The TileIR driver should set this automatically.
block_size = self.block_size
self._execute_cuda(fname_cubin, kernel_name, grid_size, block_size, stream)
self._execute_cuda(
fname_cubin, kernel_name, grid_size, block_size, smem_size, stream
)
return ret
@@ -1587,10 +1565,7 @@ class BaseDSL:
kernelGenHelper = dkwargs.get("kernelGenHelper", None)
kernel_name = funcBody.__name__
if isinstance(funcBody, DSLCallable):
args_spec = funcBody.get_arg_spec()
else:
args_spec = inspect.getfullargspec(funcBody)
args_spec = inspect.getfullargspec(funcBody)
self.funcBody = funcBody
# Give each kernel a unique name. (The same kernel may be

View File

@@ -58,6 +58,11 @@ def get_int_env_var(var_name, default_value=0):
return int(value) if value and value.isdigit() else default_value
@lru_cache(maxsize=None)
def has_env_var(var_name):
    """Return True when the environment variable `var_name` is set.

    Results are memoized for the process lifetime, so changes made to the
    environment after the first query for a given name are not observed.
    """
    # Membership in os.environ is exactly `os.getenv(var_name) is not None`.
    return var_name in os.environ
def detect_gpu_arch(prefix):
"""
Attempts to detect the machine's GPU architecture.
@@ -256,6 +261,7 @@ class EnvironmentVarManager:
- [DSL_NAME]_ARCH: GPU architecture (default: "sm_100")
- [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False)
- [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False)
- [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False)
- [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False)
- [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False)
- [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000)
@@ -267,7 +273,6 @@ class EnvironmentVarManager:
self.prefix = prefix # change if needed
# Printing options
self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
self.print_after_preprocessor = get_bool_env_var(
f"{prefix}_PRINT_AFTER_PREPROCESSOR", False
)
@@ -275,15 +280,29 @@ class EnvironmentVarManager:
self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True)
# File options
self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False)
# Logging options
self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False)
# Other options
if (
has_env_var(f"{prefix}_LOG_LEVEL")
and not self.log_to_console
and not self.log_to_file
):
log().warning(
f"Log level was set, but neither logging to file ({prefix}_LOG_TO_FILE) nor logging to console ({prefix}_LOG_TO_CONSOLE) is enabled!"
)
self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1)
# Other options
self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False)
self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix))
self.warnings_as_errors = get_bool_env_var(
f"{prefix}_WARNINGS_AS_ERRORS", False
)
self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False)
self.enable_optimization_warnings = get_bool_env_var(
f"{prefix}_ENABLE_OPTIMIZATION_WARNINGS", False
)
self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False)
self.disable_file_caching = get_bool_env_var(
f"{prefix}_DISABLE_FILE_CACHING", False

View File

@@ -14,16 +14,12 @@ This module provides a runtime utility functions that are needed for
the DSL.
"""
from . import device_tensor
from . import dlpack_types
from . import cuda
from . import tensor_descriptor
from . import jit_arg_adapters
__all__ = [
"device_tensor",
"dlpack_types",
"cuda",
"tensor_descriptor",
"jit_arg_adapters",
]

View File

@@ -309,7 +309,7 @@ def get_kernel_function(module, kernel_name):
return kernel
def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size=0, kernel_args=None):
def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None):
"""
Launches the CUDA kernel.
"""

View File

@@ -183,6 +183,13 @@ class TensorDescriptor:
"""
return self.device_type == _dpack.DLDeviceType.kDLGPU
@staticmethod
def is_compatible(maybe_tensor_descriptor) -> bool:
"""Check if the object is a TensorDescriptor or can be converted to one."""
# True either for an existing TensorDescriptor instance or for any object
# that can_transformed_to_dlpack accepts (presumably DLPack-exporting
# tensors such as torch/numpy arrays — confirm against that helper).
return isinstance(
maybe_tensor_descriptor, TensorDescriptor
) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
def from_tensor(tensor) -> TensorDescriptor:
"""Create a TensorDescriptor from a tensor object."""
@@ -192,10 +199,3 @@ def from_tensor(tensor) -> TensorDescriptor:
def to_tensor(tensor_descriptor: TensorDescriptor):
"""Return tensor object from tensor descriptor."""
return tensor_descriptor.tensor
def is_tensor_descriptor(maybe_tensor_descriptor) -> bool:
    """Check if the object is a TensorDescriptor.

    Also accepts objects that TensorDescriptor reports as convertible
    via DLPack.
    """
    # Early return for the direct-instance case.
    if isinstance(maybe_tensor_descriptor, TensorDescriptor):
        return True
    # Otherwise fall back to the DLPack-convertibility check.
    return TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)

View File

@@ -126,6 +126,7 @@ from .core import (
basic_copy_if,
autovec_copy,
copy,
copy_atom_call,
gemm,
# Wrapper classes
ComposedLayout,
@@ -290,6 +291,7 @@ __all__ = [
"basic_copy_if",
"autovec_copy",
"copy",
"copy_atom_call",
"gemm",
# Tensor creation
"full",

View File

@@ -315,3 +315,35 @@ def mbarrier_arrive(
loc=loc,
ip=ip,
)
@dsl_user_op
def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> None:
"""
Arrives on an mbarrier for async load **without incrementing** the arrival count
(`cp.async.mbarrier.arrive.shared ..., noinc=1`).
Used in the warp-specialized kernel when the non-TMA load warp (producer) is not the same
as the math/epilogue warp (consumer).
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
:raises: whatever check_value_in raises when the current arch is unsupported
"""
# Gate on the currently-configured GPU architecture; only the listed
# SM archs support this instruction per this check.
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(
arch,
[
"sm_90",
"sm_90a",
"sm_100a",
"sm_100f",
],
"arch",
)
# noinc=True selects the `noinc=1` form of cp.async.mbarrier.arrive.shared,
# i.e. arrive without bumping the expected-arrival count.
mbar_llvm_ptr = mbar_ptr.llvm_ptr
nvvm.cp_async_mbarrier_arrive_shared(
mbar_llvm_ptr,
noinc=True,
loc=loc,
ip=ip,
)

View File

@@ -11,6 +11,7 @@
from functools import partial
from typing import Optional, Tuple, Union, Callable
from typing_extensions import deprecated
from cutlass.cutlass_dsl import T, dsl_user_op
@@ -642,6 +643,9 @@ def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None):
@dsl_user_op
@deprecated(
"cute.arch.exp2 is deprecated, use cute.math.exp2 with `fastmath=True` instead"
)
def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
return Float32(
llvm.inline_asm(
@@ -656,15 +660,19 @@ def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
)
# TODO: add `fastmath` flag for this op
@dsl_user_op
@deprecated(
"cute.arch.exp is deprecated, use cute.math.exp with `fastmath=True` instead"
)
def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
LOG2_E = 1.4426950408889634
return exp2(a * LOG2_E, loc=loc, ip=ip)
# TODO: add `fastmath` flag for this op
@dsl_user_op
@deprecated(
"cute.arch.exp_packed_f32x2 is deprecated, use cute.arch.mul_packed_f32x2 and cute.math.exp2 with `fastmath=True` instead"
)
def exp_packed_f32x2(
a: Tuple[Float32, Float32], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:

View File

@@ -31,7 +31,6 @@ from typing import (
Optional,
)
from enum import Enum, auto
from typing_extensions import deprecated
from cutlass.cutlass_dsl import (
const,
@@ -1662,7 +1661,9 @@ class _Tensor(Tensor):
@dsl_user_op
def print_tensor(tensor: Tensor, *, verbose: bool = False, loc=None, ip=None):
def print_tensor(
tensor: Union[Tensor, "TensorSSA"], *, verbose: bool = False, loc=None, ip=None
):
"""Print content of the tensor in human readable format.
Outputs the tensor data in a structured format showing both metadata
@@ -1693,6 +1694,11 @@ def print_tensor(tensor: Tensor, *, verbose: bool = False, loc=None, ip=None):
[ 0.9159, 0.7577, 0.6918, 0.0754, 0.0591],
[ 0.6551, 0.1626, 0.1189, 0.0292, 0.8655]])
"""
if isinstance(tensor, TensorSSA):
tmp = make_fragment(tensor.shape, tensor.dtype)
tmp.store(tensor)
tensor = tmp
if not isinstance(tensor.type, _cute_ir.MemRefType):
raise NotImplementedError(
f"printing {tensor} is not supported because it doesn't support trivial dereferencing. "
@@ -1769,7 +1775,7 @@ def is_static(x: Union[ir.Type, ir.Value, XTuple]) -> bool:
return False
elif is_dynamic_expression(x):
return _cute_ir.is_static(x.type)
elif isinstance(x, int) or x is None:
elif isinstance(x, (bool, int, float)) or x is None:
return True
elif isinstance(x, ScaledBasis):
return x.is_static()
@@ -2241,7 +2247,7 @@ def is_weakly_congruent(
* X is a non-tuple value, OR
* X and Y are both tuples of the same rank AND all corresponding elements are weakly congruent.
Weak congruence allows scalar values to match with tuples, making it useful
Weak congruence allows scalar values to match with tuples, making it useful
for determining whether an object has a hierarchical structure "up to" another.
:param a: First object to compare
@@ -2921,33 +2927,46 @@ def flatten_to_tuple(a: Union[IntTuple, Coord, Shape, Stride]) -> tuple:
return tuple(chain.from_iterable(tuple(flatten_to_tuple(x) for x in a)))
def flatten(a: Union[IntTuple, Coord, Shape, Stride, Layout, Tensor]) -> tuple:
@overload
def flatten(a: Union[IntTuple, Coord, Shape, Stride]) -> IntTuple: ...
@overload
def flatten(a: Tensor) -> Tensor: ...
@overload
def flatten(a: Layout) -> Layout: ...
def flatten(a):
"""Flattens a CuTe data structure into a simpler form.
For tuples, this function flattens the structure into a single-level tuple.
For non-tuple types, it returns the input unchanged.
For layouts, it returns a new layout with flattened shape and stride.
For tensors, it returns a new tensor with flattened layout.
For other types, it returns the input unchanged.
:param a: The structure to flatten
:type a: Union[IntTuple, Coord, Shape, Stride, Layout, Tensor]
:return: The flattened structure
:rtype: Union[tuple, Any]
:raises NotImplementedError: If input is a Layout or Tensor
**Examples:**
.. code-block:: python
flatten((1, 2, 3)) # Returns (1, 2, 3)
flatten(((1, 2), (3, 4))) # Returns (1, 2, 3, 4)
flatten(5) # Returns 5
"""
if isinstance(a, (Layout, Tensor)):
raise NotImplementedError("flatten layout and tensor is not supported")
flatten((1, 2, 3)) # Returns (1, 2, 3)
flatten(((1, 2), (3, 4))) # Returns (1, 2, 3, 4)
flatten(5) # Returns 5
flatten(Layout(shape, stride)) # Returns Layout(flatten(shape), flatten(stride))
flatten(Tensor(layout)) # Returns Tensor(flatten(layout))
if not isinstance(a, tuple):
return a
else:
"""
if isinstance(a, Tensor):
return make_tensor(a.iterator, flatten(a.layout))
elif isinstance(a, Layout):
return make_layout(flatten(a.shape), stride=flatten(a.stride))
elif isinstance(a, tuple):
return flatten_to_tuple(a)
else:
return a
def unflatten(
@@ -4120,14 +4139,14 @@ def complement(
@dsl_user_op
def right_inverse(input: Layout, *, loc=None, ip=None) -> Layout:
if not isinstance(input, Layout):
raise TypeError(f"expects input of type Layout, but got {type(Layout)}")
raise TypeError(f"expects input of type Layout, but got {type(input)}")
return _cute_ir.right_inverse(input=input, loc=loc, ip=ip)
@dsl_user_op
def left_inverse(input: Layout, *, loc=None, ip=None) -> Layout:
if not isinstance(input, Layout):
raise TypeError(f"expects input of type Layout, but got {type(Layout)}")
raise TypeError(f"expects input of type Layout, but got {type(input)}")
return _cute_ir.left_inverse(input=input, loc=loc, ip=ip)
@@ -5156,7 +5175,6 @@ def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
return TiledCopy(atom.op, trait)
@deprecated("Use make_tiled_copy_tv instead")
def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
"""Create a tiled type given a TV partitioner and tiler.
@@ -5434,6 +5452,14 @@ def gemm(
For MMA Atoms that require single-threaded execution, the gemm op automatically handles thread
election internally. Manual thread selection is not required in such cases.
Following dispatch rules are supported:
- Dispatch [1]: (V) x (V) => (V) => (V,1,1) x (V,1,1) => (V,1,1)
- Dispatch [2]: (M) x (N) => (M,N) => (1,M,1) x (1,N,1) => (1,M,N)
- Dispatch [3]: (M,K) x (N,K) => (M,N) => (1,M,K) x (1,N,K) => (1,M,N)
- Dispatch [4]: (V,M) x (V,N) => (V,M,N) => (V,M,1) x (V,N,1) => (V,M,N)
- Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
:param atom: MMA atom
:type atom: MmaAtom
:param d: Destination tensor
@@ -5454,6 +5480,27 @@ def gemm(
:rtype: None
"""
a_rank = rank(a.shape)
b_rank = rank(b.shape)
c_rank = rank(c.shape)
d_rank = rank(d.shape)
if a_rank != b_rank:
raise ValueError("`a` and `b` must have the same rank")
if c_rank != d_rank:
raise ValueError("`c` and `d` must have the same rank")
if a_rank == 1:
if c_rank > 2:
raise ValueError("`c` must have rank <= 2 when `a` has rank 1")
elif a_rank == 2:
if c_rank not in (2, 3):
raise ValueError("`c` must have rank 2 or 3 when `a` has rank 2")
elif a_rank == 3:
if c_rank != 3:
raise ValueError("`c` must have rank 3 when `a` has rank 3")
value = atom._unpack(loc=loc, ip=ip, **kwargs)
return _cute_ir.gemm(value, d.value, a.value, b.value, c.value, loc=loc, ip=ip)
@@ -5645,6 +5692,76 @@ def copy(
return _cute_ir.copy(value, src.value, dst.value, pred=pred, loc=loc, ip=ip)
@dsl_user_op
def copy_atom_call(
atom: CopyAtom,
src: Tensor,
dst: Tensor,
*,
pred: Optional[Tensor] = None,
loc=None,
ip=None,
**kwargs,
) -> None:
"""
Execute a single copy atom operation.
The copy_atom_call operation executes a copy atom with the given operands.
Following src/dst layout of atom are valid:
* ((atom_v))
* (atom_v)
Note: The format ((atom_v, rest_v)) is NOT valid for copy_atom_call since it would
require multiple atom operations, which contradicts the definition of a single copy atom call.
Examples:
.. code-block:: python
# Call a copy atom operation
cute.copy_atom_call(copy_atom, src_tensor, dst_tensor)
An additional predication tensor can be provided. If the partitioned tensors have the following
logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile
consistent with ``(ATOM_REST,REST_M,...)``.
"""
if isinstance(src.type, _cute_ir.MemRefType) and isinstance(
dst.type, _cute_ir.MemRefType
):
if src.element_type.width != dst.element_type.width:
raise TypeError(
"`copy_atom_call` currently only supports equal source and destination "
"element type bit width"
)
value = atom._unpack(loc=loc, ip=ip, **kwargs)
if isinstance(pred, Tensor):
pred = pred.value
return _cute_ir.copy_atom_call(
value, src.value, dst.value, pred=pred, loc=loc, ip=ip
)
def prefetch(atom: CopyAtom, src: Tensor, *, loc=None, ip=None) -> None:
"""
The Prefetch algorithm.
The "prefetch" expects source tensors to be partitioned according to the provided Copy Atom.
Prefetch is used for loading tensors from global memory to L2.
Prefetch accepts Copy Atom but not all are allowed. Currently, only support for tma load tensor prefetch.
.. code-block:: python
cute.prefetch(tma_atom, src)
For Copy Atoms that require single-threaded execution, the copy op automatically handles thread
election internally. Manual thread selection is not required in such cases.
"""
dummy_tma_bar_ptr = make_ptr(Int64, 0, AddressSpace.smem, loc=loc, ip=ip)
value = atom._unpack(loc=loc, ip=ip, tma_bar_ptr=dummy_tma_bar_ptr)
return _cute_ir.prefetch(value, src.value, loc=loc, ip=ip)
####################################################################################################
#
# TensorSSA class (experimental)
@@ -5657,6 +5774,11 @@ class ReductionOp(Enum):
MUL = auto()
MAX = auto()
MIN = auto()
INC = auto()
DEC = auto()
AND = auto()
OR = auto()
XOR = auto()
def __str__(self):
return self.name.lower()
@@ -5697,6 +5819,7 @@ class TensorSSA(cutlass_arith.ArithValue):
self._shape = shape
self._dtype = dtype
self._layout = None
@property
def dtype(self) -> Type[Numeric]:
@@ -5776,13 +5899,26 @@ class TensorSSA(cutlass_arith.ArithValue):
):
res_type = Boolean
if lhs.shape != rhs.shape:
raise ValueError(
f"lhs and rhs must have the same shape type, but got {lhs.shape} and {rhs.shape}"
)
assert isinstance(rhs, TensorSSA), f"rhs must be TensorSSA but got {rhs}"
if not isinstance(rhs, TensorSSA):
raise TypeError(f"rhs must be TensorSSA but got {rhs}")
def _broadcast(s, t):
if s == 1:
return t
elif t == 1:
return s
elif s == t:
return s
else:
raise ValueError(f"cannot broadcast {s} and {t}")
max_rank = max(rank(lhs.shape), rank(rhs.shape))
lhs_shape = append(lhs.shape, 1, up_to_rank=max_rank)
rhs_shape = append(rhs.shape, 1, up_to_rank=max_rank)
res_shape = transform_leaf(_broadcast, lhs_shape, rhs_shape)
# broadcast to the same shape
lhs = lhs.broadcast_to(res_shape)
rhs = rhs.broadcast_to(res_shape)
if (
op in (operator.add, operator.sub)
@@ -5807,6 +5943,38 @@ class TensorSSA(cutlass_arith.ArithValue):
return res
def broadcast_to(self, target_shape: Shape, *, loc=None, ip=None) -> "TensorSSA":
"""
Broadcast the tensor to the target shape.
"""
# pad source shape to the same rank
shape = append(self.shape, 1, up_to_rank=rank(target_shape))
if shape == target_shape:
return self
def _check_broadcast(s, t):
if s != t and s != 1:
raise ValueError(
f"src_shape and target_shape must be the same when src_shape is not 1, but got {s} and {t}"
)
transform_leaf(_check_broadcast, shape, target_shape)
# reshape to flatten N-D vector
flat_shp = flatten_to_tuple(shape)
temp_ty = ir.VectorType.get(list(flat_shp), self.dtype.mlir_type)
temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip)
# broadcast to result N-D vector
flat_tgt_shp = flatten_to_tuple(target_shape)
temp_tgt_ty = ir.VectorType.get(list(flat_tgt_shp), self.dtype.mlir_type)
temp_tgt_vect = vector.broadcast(temp_tgt_ty, temp_vect, loc=loc, ip=ip)
res_1d_ty = ir.VectorType.get([size(target_shape)], self.dtype.mlir_type) # type: ignore
res_1d_vect = vector.shape_cast(res_1d_ty, temp_tgt_vect, loc=loc, ip=ip)
return TensorSSA(res_1d_vect, target_shape, self.dtype)
def __pow__(self, other, *, loc=None, ip=None) -> "TensorSSA":
"""
Returns the results of tensor^other.
@@ -6093,6 +6261,16 @@ class TensorSSA(cutlass_arith.ArithValue):
"""
return self._apply_op(operator.and_, other, flip=True, loc=loc, ip=ip)
def __neg__(self, *, loc=None, ip=None) -> "TensorSSA":
"""
Returns the negation of the tensor.
:return: The element-wise negation of the tensor
:rtype: TensorSSA
"""
return self._apply_op(operator.sub, 0, flip=True, loc=loc, ip=ip)
def _flatten_shape_and_coord(self, crd, *, loc=None, ip=None):
# Coalesce and flatten source layout at terminal of coordinate
# (N_0,(N_1,...), ...) -> (N_0,N_1,N_2,...)
@@ -6158,17 +6336,13 @@ class TensorSSA(cutlass_arith.ArithValue):
if crd is None:
return self
if not has_underscore(crd) or depth(crd) == 0:
idx = crd2idx(crd, make_layout(self._shape))
if is_static(idx):
res = vector.extract(
self, dynamic_position=[], static_position=[idx], loc=loc, ip=ip
)
else:
res = vector.extract(
self, dynamic_position=[crd], static_position=[], loc=loc, ip=ip
)
return self.dtype(res)
if not has_underscore(crd):
if self._layout is None:
self._layout = make_layout(self._shape, loc=loc, ip=ip)
idx = crd2idx(crd, self._layout, loc=loc, ip=ip)
idx_val = as_numeric(idx).ir_value(loc=loc, ip=ip)
res_val = vector.extractelement(self, position=idx_val, loc=loc, ip=ip)
return self.dtype(res_val)
if not is_static(crd):
raise ValueError("dynamic coordinate is not supported")
@@ -6274,7 +6448,7 @@ class TensorSSA(cutlass_arith.ArithValue):
:type op: operator
:param init_val: The initial value for the reduction
:type init_val: numeric
:param reduction_profile: Specifies which dimensions to reduce. Dimensions marked with '_' are kept.
:param reduction_profile: Specifies which dimensions to reduce. Dimensions marked with `None` are kept.
:type reduction_profile: Coord
:return: The reduced tensor
@@ -6289,9 +6463,9 @@ class TensorSSA(cutlass_arith.ArithValue):
reduce(f32 o (4, 5))
=> f32
reduce(f32 o (4, (5, 4)), reduction_profile=(_, 1))
reduce(f32 o (4, (5, 4)), reduction_profile=(None, 1))
=> f32 o (4,)
reduce(f32 o (4, (5, 4)), reduction_profile=(_, (_, 1)))
reduce(f32 o (4, (5, 4)), reduction_profile=(None, (None, 1)))
=> f32 o (4, (5,))
"""
# short-cut to no-op
@@ -6354,21 +6528,6 @@ class TensorSSA(cutlass_arith.ArithValue):
return self._build_result(res_vect, res_shp, loc=loc, ip=ip)
def _get_attr_for_type(ty, value):
if isinstance(ty, ir.IntegerType):
return ir.IntegerAttr.get(ty, value.to(int))
elif isinstance(ty, ir.FloatType):
return ir.FloatAttr.get(ty, value.to(float))
else:
raise TypeError(f"unsupported type: {ty}")
def _splat(res_ty, fill_value):
elem_attr = _get_attr_for_type(res_ty.element_type, fill_value)
vect_attr = ir.DenseElementsAttr.get_splat(res_ty, elem_attr)
return arith.constant(res_ty, vect_attr)
@dsl_user_op
def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> TensorSSA:
"""
@@ -6389,9 +6548,14 @@ def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> Tenso
if isinstance(fill_value, (ir.Value, int, float, bool)):
fill_value = dtype(fill_value)
elif isinstance(fill_value, Numeric):
fill_value = fill_value.to(dtype, loc=loc, ip=ip)
else:
raise ValueError(f"Expected fill_value be numeric type, but got {fill_value}")
res_mlir_type = T.vector(size, dtype.mlir_type)
return TensorSSA(_splat(res_mlir_type, fill_value), shape, dtype)
res_ty = T.vector(size, dtype.mlir_type)
res_val = vector.splat(res_ty, fill_value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
return TensorSSA(res_val, shape, dtype)
def full_like(
@@ -6547,7 +6711,7 @@ class struct:
**Usage:**
.. code-block::
.. code-block:: python
# Supports base_dsl scalar int/float elements, array and nested struct:
@cute.struct
@@ -6661,7 +6825,8 @@ class struct:
Initializes a new memory range.
:param dtype: The data type.
:param size: The size of the memory range in bytes.
:param size: Size of the memory range in bytes. A size of **0** is accepted, but in that
case the range can only be used for its address (e.g. as a partition marker).
:param base: The base address of the memory range.
"""
self._dtype = dtype
@@ -6673,9 +6838,9 @@ class struct:
Returns start pointer to the data in this memory range.
:return: A pointer to the start of the memory range.
:raises AssertionError: If the size of the memory range is not greater than zero.
:raises AssertionError: If the size of the memory range is negative.
"""
assert self._size > 0
assert self._size >= 0
return recast_ptr(self._base, dtype=self._dtype)
def get_tensor(self, layout, swizzle=None, dtype=None):
@@ -6716,31 +6881,48 @@ class struct:
:param v: The object to align. Must be a struct, MemRange, or a scalar type.
:param align: The alignment value to set.
:return: A copy of the object with the specified alignment.
:raises TypeError: If the object is not a struct, MemRange, or a scalar type.
:ivar _dtype: The data type to be aligned.
:ivar _align: The alignment of the data type.
"""
_dtype = None
_align = None
def __new__(cls, name, bases, dct):
return super().__new__(cls, name, bases, dct)
def __getitem__(cls, params) -> Any:
if len(params) == 2:
obj, align = params
dtype, align = params
assert align > 0
else:
raise TypeError("Invalid struct.Align Arguments")
# make a copy of type and mark alignment
if struct._is_scalar_type(obj) or isinstance(
obj, (struct, struct._MemRangeMeta)
if not struct._is_scalar_type(dtype) and not isinstance(
dtype, (struct, struct._MemRangeMeta)
):
new_obj = py_copy.copy(obj)
setattr(new_obj, "_struct_alignment_", align)
return new_obj
else:
raise TypeError(
"align only can be applied to sturct/MemRange/base_dsl scalar"
"align only can be applied to struct/MemRange/base_dsl scalar"
)
# Create new class with alignment
new_cls = type(
f"struct.Align[{dtype.__name__}, {align}]",
(struct.Align,),
{"_dtype": dtype, "_align": align},
)
return new_cls
@property
def dtype(cls):
return cls._dtype
@property
def align(cls):
return cls._align
class Align(metaclass=_AlignMeta):
"""
Aligns the given type by `Align[T, alignment]`.
@@ -6768,6 +6950,7 @@ class struct:
:raises TypeError: If the struct is empty.
"""
self._cls = cls
self.__name__ = f"struct::{cls.__name__}"
# Get the class annotations
self._annotations = cls.__annotations__
# Create a dictionary to store the offsets
@@ -6780,12 +6963,10 @@ class struct:
raise TypeError("Empty struct is not supported!")
for name, object in self._annotations.items():
# get alignment of object
def alignof(object, default: int = 1):
return getattr(object, "_struct_alignment_", default)
# alignment for the next offset
def align_offset(offset, align):
return (offset + (align - 1)) & ~(align - 1)
sub_align = 1
if isinstance(object, struct._AlignMeta):
sub_align = object.align
object = object.dtype
# switch addition order to support dynamic size
def add_offset(val):
@@ -6793,35 +6974,37 @@ class struct:
# size of scalar
if struct._is_scalar_type(object):
dtype_size = object.width // 8
sub_align = alignof(object, dtype_size)
offset = align_offset(offset, sub_align)
dtype_size = max(1, object.width // 8)
sub_align = max(dtype_size, sub_align)
offset = self.align_offset(offset, sub_align)
self._offsets[name] = offset
offset = add_offset(dtype_size)
# size of array is size_in_bytes, alignment is elem_size
elif isinstance(object, struct._MemRangeMeta):
if object.size == 0:
continue # skip empty array
sub_align = alignof(object, max(1, object.elem_width // 8))
offset = align_offset(offset, sub_align)
# Allow empty array as a free marker-only struct member.
# Use max(sub_align, ) because we might have in the future some
# object.elem_width less than 8, such as fp4, bit and others,
# and align_offset() does not support an alignment of 0.
sub_align = max(object.elem_width // 8, sub_align)
offset = self.align_offset(offset, sub_align)
self._offsets[name] = offset
offset = add_offset(object.size_in_bytes)
# size of struct
elif isinstance(object, struct):
sub_align = max(object.__alignof__(), alignof(object))
offset = align_offset(offset, sub_align)
sub_align = max(object.__alignof__(), sub_align)
offset = self.align_offset(offset, sub_align)
self._offsets[name] = offset
offset = add_offset(object.__sizeof__())
else:
raise TypeError(
f"Struct element only support sturct/array/base_dsl scalar, "
f"Struct element only support struct/array/base_dsl scalar, "
f"but got {object}"
)
# Total aligment determined by the strictest requirement
alignment = max(alignment, sub_align)
# Total size determined by alignment
self._align_of = alignment
self._size_of = align_offset(offset, alignment)
self._size_of = self.align_offset(offset, alignment)
# create the __init__ method for decorated struct
def __call__(self, base: Any) -> None:
@@ -6840,6 +7023,8 @@ class struct:
setattr(cls, "_base", base)
for name, off in self._offsets.items():
obj = self._annotations[name]
if isinstance(obj, struct._AlignMeta):
obj = obj.dtype
if struct._is_scalar_type(obj):
new_obj = recast_ptr(base + off, dtype=obj)
setattr(cls, name, new_obj)
@@ -6851,7 +7036,7 @@ class struct:
setattr(cls, name, new_obj)
else:
raise TypeError(
f"Struct element only support sturct/array/base_dsl scalar, "
f"Struct element only support struct/array/base_dsl scalar, "
f"but got {obj}"
)
return cls
@@ -6872,3 +7057,14 @@ class struct:
# get alignment
def __alignof__(self) -> int:
return self._align_of
# util func for aligning offset
@staticmethod
def align_offset(offset, align):
"""
Return the round-up offset up to the next multiple of align.
"""
assert align > 0 and not (
align & (align - 1)
), "align should be a strictly positive power of 2."
return (offset + (align - 1)) & ~(align - 1)

View File

@@ -10,16 +10,53 @@
# is strictly prohibited.
from .core import TensorSSA
from .typing import Numeric
from cutlass._mlir.dialects import math, arith
from typing import Callable, Union
def acos(a: TensorSSA) -> TensorSSA:
def _math_op(func: Callable, fastmath: bool, *args, **kwargs):
"""Dispatch the function to either a TensorSSA or a Numeric(Float).
:param func: The function to dispatch
:param args: The input tensor or scalar
:param kwargs: The input tensor or scalar
"""
arg_type = type(args[0])
for arg in args:
if not isinstance(arg, TensorSSA) and (
not isinstance(arg, Numeric) or not type(arg).is_float
):
raise TypeError(
f"Expected a TensorSSA or Numeric(Float), but got {type(arg)}"
)
if not isinstance(arg, arg_type):
raise TypeError(
f"Expected all inputs to be of type {arg_type}, but got {type(arg)}"
)
fastmath_flag = arith.FastMathFlags.fast if fastmath else arith.FastMathFlags.none
if isinstance(args[0], TensorSSA):
return TensorSSA(
func(*args, fastmath=fastmath_flag), args[0].shape, args[0].dtype
)
else:
args = [a.ir_value() for a in args]
return func(*args, fastmath=fastmath_flag)
def acos(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise arc cosine of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the arc cosine of each element in input tensor
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -29,16 +66,20 @@ def acos(a: TensorSSA) -> TensorSSA:
y = x.load() # Load values
z = acos(y) # Compute arc cosine
"""
return TensorSSA(math.acos(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.acos, fastmath, a)
def asin(a: TensorSSA) -> TensorSSA:
def asin(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise arc sine of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the arc sine of each element in input tensor
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -48,18 +89,20 @@ def asin(a: TensorSSA) -> TensorSSA:
y = x.load() # Load values
z = asin(y) # Compute arc sine
"""
return TensorSSA(math.asin(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.asin, fastmath, a)
def atan(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
def atan(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise arc tangent of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the arc tangent of each element in input tensor
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -70,23 +113,25 @@ def atan(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
z = atan(y) # Compute arc tangent
"""
raise NotImplementedError("atan is not implemented")
return TensorSSA(math.atan(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.atan, fastmath, a)
def atan2(a: TensorSSA, b: TensorSSA, fastmath: bool = False) -> TensorSSA:
def atan2(
a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise arc tangent of two tensors.
Computes atan2(a, b) element-wise. The function atan2(a, b) is the angle in radians
between the positive x-axis and the point given by the coordinates (b, a).
:param a: First input tensor (y-coordinates)
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param b: Second input tensor (x-coordinates)
:type b: TensorSSA
:type b: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the arc tangent of a/b element-wise
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -96,20 +141,20 @@ def atan2(a: TensorSSA, b: TensorSSA, fastmath: bool = False) -> TensorSSA:
x = cute.make_fragment(ptr2, layout).load() # x coordinates
theta = atan2(y, x) # Compute angles
"""
return TensorSSA(
math.atan2(a, b, fastmath=arith.FastMathFlags.none), a.shape, a.dtype
)
return _math_op(math.atan2, fastmath, a, b)
def cos(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
def cos(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise cosine of the input tensor.
:param a: Input tensor (in radians)
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the cosine of each element
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -119,21 +164,23 @@ def cos(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
y = x.load() # Load values
z = cos(y) # Compute cosine
"""
return TensorSSA(math.cos(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.cos, fastmath, a)
def erf(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
def erf(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise error function of the input tensor.
The error function is defined as:
erf(x) = 2/√π ∫[0 to x] exp(-t²) dt
:param a: Input tensor
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the error function value for each element
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -143,18 +190,43 @@ def erf(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
y = x.load() # Load values
z = erf(y) # Compute error function
"""
return TensorSSA(math.erf(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.erf, fastmath, a)
def exp2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
def exp(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise exponential of the input tensor.
:param a: Input tensor
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the exponential of each element
:rtype: Union[TensorSSA, Numeric]
Example:
.. code-block::
x = cute.make_fragment(layout) # Create tensor
y = x.load() # Load values
z = exp(y) # Compute exponential
"""
return _math_op(math.exp, fastmath, a)
def exp2(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise base-2 exponential of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing 2 raised to the power of each element
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -164,18 +236,20 @@ def exp2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
y = x.load() # Load values
z = exp2(y) # Compute 2^x
"""
return TensorSSA(math.exp2(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.exp2, fastmath, a)
def log(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
def log(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise natural logarithm of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the natural logarithm of each element
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -185,18 +259,20 @@ def log(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
y = x.load() # Load values
z = log(y) # Compute natural logarithm
"""
return TensorSSA(math.log(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.log, fastmath, a)
def log2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
def log2(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise base-2 logarithm of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the base-2 logarithm of each element
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -206,18 +282,20 @@ def log2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
y = x.load() # Load values
z = log2(y) # Compute log base 2
"""
return TensorSSA(math.log2(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.log2, fastmath, a)
def log10(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
def log10(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise base-10 logarithm of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the base-10 logarithm of each element
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -227,20 +305,22 @@ def log10(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
y = x.load() # Load values
z = log10(y) # Compute log base 10
"""
return TensorSSA(math.log10(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.log10, fastmath, a)
def rsqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
def rsqrt(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise reciprocal square root of the input tensor.
Computes 1/√x element-wise.
:param a: Input tensor
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the reciprocal square root of each element
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -250,18 +330,20 @@ def rsqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
y = x.load() # Load values
z = rsqrt(y) # Compute 1/√x
"""
return TensorSSA(math.rsqrt(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.rsqrt, fastmath, a)
def sin(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
def sin(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise sine of the input tensor.
:param a: Input tensor (in radians)
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the sine of each element
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -271,18 +353,20 @@ def sin(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
y = x.load() # Load values
z = sin(y) # Compute sine
"""
return TensorSSA(math.sin(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.sin, fastmath, a)
def sqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
def sqrt(
a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
"""Compute element-wise square root of the input tensor.
:param a: Input tensor
:type a: TensorSSA
:type a: Union[TensorSSA, Numeric]
:param fastmath: Enable fast math optimizations, defaults to False
:type fastmath: bool, optional
:return: Tensor containing the square root of each element
:rtype: TensorSSA
:rtype: Union[TensorSSA, Numeric]
Example:
@@ -292,16 +376,20 @@ def sqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
y = x.load() # Load values
z = sqrt(y) # Compute square root
"""
return TensorSSA(math.sqrt(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
return _math_op(math.sqrt, fastmath, a)
def tan(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Compute the element-wise tangent of the input.

    :param a: Input tensor or scalar (in radians)
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Enable fast math optimizations, defaults to False
    :type fastmath: bool, optional
    :return: Tensor or scalar containing the tangent of each element
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block:: python

        y = x.load()  # Load values
        z = tan(y)  # Compute tangent
    """
    # Delegate to the shared helper so tensor/scalar dispatch and the
    # fastmath flag are handled uniformly across all math ops.
    return _math_op(math.tan, fastmath, a)
def tanh(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Compute the element-wise hyperbolic tangent of the input.

    :param a: Input tensor or scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Enable fast math optimizations, defaults to False
    :type fastmath: bool, optional
    :return: Tensor or scalar containing the hyperbolic tangent of each element
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block:: python

        y = x.load()  # Load values
        z = tanh(y)  # Compute hyperbolic tangent
    """
    # Delegate to the shared helper so tensor/scalar dispatch and the
    # fastmath flag are handled uniformly across all math ops.
    return _math_op(math.tanh, fastmath, a)
__all__ = [
@@ -342,6 +432,7 @@ __all__ = [
"atan2",
"cos",
"erf",
"exp",
"exp2",
"log",
"log10",

View File

@@ -8,7 +8,7 @@
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import enum
from dataclasses import dataclass
from typing import Type, Optional
@@ -101,6 +101,42 @@ class MmaUniversalTrait(core.Trait):
####################################################################################################
class MemoryOrder(enum.Enum):
    """Memory ordering semantics for memory operations.

    Each member wraps the corresponding IR ``MemOrderKind`` value so user
    code can select an ordering without importing the IR dialect directly.
    """

    WEAK = _cute_ir.MemOrderKind.WEAK
    RELAXED = _cute_ir.MemOrderKind.RELAXED
    ACQUIRE = _cute_ir.MemOrderKind.ACQUIRE
    RELEASE = _cute_ir.MemOrderKind.RELEASE
    ACQ_REL = _cute_ir.MemOrderKind.ACQ_REL
    SC = _cute_ir.MemOrderKind.SC
    MMIO = _cute_ir.MemOrderKind.MMIO
    CONSTANT = _cute_ir.MemOrderKind.CONSTANT
    VOLATILE = _cute_ir.MemOrderKind.VOLATILE

    def __str__(self) -> str:
        # e.g. "MemoryOrder.WEAK"
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        # e.g. "<MemoryOrder.WEAK>"
        return f"<{self.__class__.__name__}.{self.name}>"

    def _to_ir(self) -> _cute_ir.MemOrderKind:
        """Return the underlying IR kind (stored directly as the member value)."""
        return self.value
class MemoryScope(enum.Enum):
    """Synchronization scope for memory operations.

    Each member wraps the corresponding IR ``MemScopeKind`` value, ordered
    from narrowest (CTA) to widest (SYS) scope.
    """

    CTA = _cute_ir.MemScopeKind.CTA
    CLUSTER = _cute_ir.MemScopeKind.CLUSTER
    GPU = _cute_ir.MemScopeKind.GPU
    SYS = _cute_ir.MemScopeKind.SYS

    def __str__(self) -> str:
        # e.g. "MemoryScope.CTA"
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        # e.g. "<MemoryScope.CTA>"
        return f"<{self.__class__.__name__}.{self.name}>"

    def _to_ir(self) -> _cute_ir.MemScopeKind:
        """Return the underlying IR kind (stored directly as the member value)."""
        return self.value
@dataclass(frozen=True)
class CopyUniversalOp(core.CopyOp):
"""
@@ -133,13 +169,18 @@ class CopyUniversalOp(core.CopyOp):
**kwargs,
) -> "CopyUniversalTrait":
num_bits_per_copy = kwargs.get("num_bits_per_copy", 0)
memory_order = kwargs.get("memory_order", MemoryOrder.WEAK)
memory_scope = kwargs.get("memory_scope", MemoryScope.CTA)
if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0):
raise ValueError(
"expects a 'num_bits_per_copy' kw argument of type int that is non-negative "
f"when creating a copy Atom for {self.__class__.__name__}"
)
ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get(
copy_internal_type.mlir_type, num_bits_per_copy
copy_internal_type.mlir_type,
num_bits_per_copy,
memory_order._to_ir(),
memory_scope._to_ir(),
)
return CopyUniversalTrait(_cute_ir.atom(ty, loc=loc, ip=ip))

View File

@@ -23,6 +23,7 @@ __all__ = [
"CopyBulkTensorTileG2SOp",
"CopyBulkTensorTileG2SMulticastOp",
"CopyBulkTensorTileS2GOp",
"CopyReduceBulkTensorTileS2GOp",
#
# helpers.py
#

View File

@@ -19,7 +19,7 @@ import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
from ...core import CopyOp, Trait
from ...core import CopyOp, Trait, ReductionOp
from ...typing import Int16, Pointer, Integer, Numeric
from ..common import OpError
from ..tcgen05.mma import CtaGroup
@@ -80,6 +80,12 @@ class CopyG2SOp(CopyOp):
**kwargs,
) -> "CopyG2STrait":
num_bits_per_copy = kwargs.get("num_bits_per_copy", None)
# Verify that the user provided enum values
if not isinstance(self.cache_mode, LoadCacheMode):
raise OpError(
self,
"expects the 'cache_mode' Op parameter to be a LoadCacheMode instance",
)
if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0):
raise ValueError(
"expects a 'num_bits_per_copy' kw argument of type int that is positive "
@@ -330,7 +336,7 @@ class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait):
@dataclass(frozen=True)
class CopyBulkTensorTileS2GOp(CopyOp):
"""
Bulk tensor asynchrnous SMEM to GMEM Copy Operation using the TMA unit.
Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
This Operation uses TMA in the ``.tile`` mode.
@@ -379,3 +385,87 @@ class CopyBulkTensorTileS2GTrait(Trait):
exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
)
return exec_value
@dataclass(frozen=True)
class CopyReduceBulkTensorTileS2GOp(CopyOp):
    """
    Bulk tensor asynchronous SMEM to GMEM Reduction Operation using the TMA unit.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk>`__.

    This Operation uses TMA in the ``.tile`` mode.
    """

    # Reduction applied when combining SMEM data into GMEM.
    reduction_kind: ReductionOp = ReductionOp.ADD

    # Architectures on which this op is available.
    admissible_archs = [
        "sm_90",
        "sm_90a",
        "sm_100a",
        "sm_100f",
    ]

    # BUG FIX: was spelled `__post__init__`, which dataclasses never invoke,
    # so the architecture check below silently never ran.
    def __post_init__(self):
        # Arch verification.
        # NOTE(review): the original used `CuTeDSL.__get_dsl()`, which inside
        # this class would name-mangle to `_CopyReduceBulkTensorTileS2GOp__get_dsl`
        # and fail; `_get_dsl` is assumed to be the intended accessor — confirm.
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )

    def __str__(self) -> str:
        return "cp.async SMEM -> GMEM bulk tensor reduction Operation"

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "CopyReduceBulkTensorTileS2GTrait":
        """TMA atoms cannot be built directly; raise with the supported path."""
        raise NotImplementedError(
            "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
        )

    def _to_ir(self) -> _cute_nvgpu_ir.ReductionKind:
        """Map this op's ``ReductionOp`` to the IR-level ``ReductionKind``."""
        if self.reduction_kind == ReductionOp.ADD:
            return _cute_nvgpu_ir.ReductionKind.ADD
        elif self.reduction_kind == ReductionOp.MIN:
            return _cute_nvgpu_ir.ReductionKind.MIN
        elif self.reduction_kind == ReductionOp.MAX:
            return _cute_nvgpu_ir.ReductionKind.MAX
        elif self.reduction_kind == ReductionOp.INC:
            return _cute_nvgpu_ir.ReductionKind.INC
        elif self.reduction_kind == ReductionOp.DEC:
            return _cute_nvgpu_ir.ReductionKind.DEC
        elif self.reduction_kind == ReductionOp.AND:
            return _cute_nvgpu_ir.ReductionKind.AND
        elif self.reduction_kind == ReductionOp.OR:
            return _cute_nvgpu_ir.ReductionKind.OR
        elif self.reduction_kind == ReductionOp.XOR:
            return _cute_nvgpu_ir.ReductionKind.XOR
        else:
            assert False, "unrecognized self.reduction_kind"
class CopyReduceBulkTensorTileS2GTrait(Trait):
    """Trait for the SMEM -> GMEM bulk tensor TMA reduction copy atom."""

    def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None):
        """
        Custom implementation of unpack for non-executable TMAs.

        :param tma_desc_ptr: Optional pointer to a TMA descriptor to bind
            into the executable atom
        :type tma_desc_ptr: Optional[Pointer]
        :return: The executable atom value
        """
        # Materialize an executable atom from the non-executable TMA atom.
        exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
        if isinstance(tma_desc_ptr, Pointer):
            # Late-bind the user-provided TMA descriptor pointer into the atom.
            attr_str = (
                f"#cute_nvgpu.atom_copy_field_tmareduce<{TMA_DESC_PTR_FIELD_NAME}>"
            )
            attr = ir.Attribute.parse(attr_str)
            exec_value = _cute_nvgpu_ir.atom_set_value(
                exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
            )
        return exec_value
# Public API of this module.
__all__ = [
    "LoadCacheMode",
    "CopyG2SOp",
    "CopyBulkTensorTileG2SOp",
    "CopyBulkTensorTileG2SMulticastOp",
    "CopyBulkTensorTileS2GOp",
    "CopyReduceBulkTensorTileS2GOp",
]

View File

@@ -22,9 +22,11 @@ from .copy import (
CopyBulkTensorTileG2SOp,
CopyBulkTensorTileG2SMulticastOp,
CopyBulkTensorTileS2GOp,
CopyReduceBulkTensorTileS2GOp,
CopyBulkTensorTileG2SNonExecTrait,
CopyBulkTensorTileG2SMulticastNonExecTrait,
CopyBulkTensorTileS2GTrait,
CopyReduceBulkTensorTileS2GTrait,
)
@@ -34,6 +36,7 @@ def make_tiled_tma_atom(
CopyBulkTensorTileG2SOp,
CopyBulkTensorTileG2SMulticastOp,
CopyBulkTensorTileS2GOp,
CopyReduceBulkTensorTileS2GOp,
],
gmem_tensor: Tensor,
smem_layout: Union[Layout, core.ComposedLayout],
@@ -67,7 +70,7 @@ def make_tiled_tma_atom(
similarly to any other CuTe tensors using the algebra.
:param op: The Copy Operation to construct an Atom for
:type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp]
:type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, CopyReduceBulkTensorTileS2GOp]
:param gmem_tensor: The GMEM tensor involved in the Copy
:type gmem_tensor: Tensor
:param smem_layout: The SMEM layout to construct the Copy Atom for
@@ -141,6 +144,17 @@ def make_tiled_tma_atom(
ip=ip,
)
return core.CopyAtom(op, CopyBulkTensorTileS2GTrait(res[0])), res[1]
elif isinstance(op, CopyReduceBulkTensorTileS2GOp):
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_reduce(
gmem_tensor.value,
smem_layout,
cta_v_map,
op._to_ir(),
internal_type=internal_type,
loc=loc,
ip=ip,
)
return core.CopyAtom(op, CopyReduceBulkTensorTileS2GTrait(res[0])), res[1]
else:
raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}")

View File

@@ -21,7 +21,7 @@ from cutlass._mlir import ir
import cutlass._mlir.dialects.cute as _cute_ir
from cutlass.base_dsl.dsl import is_dynamic_expression
from cutlass.cutlass_dsl import TensorFormat, JitArgAdapterRegistry
from cutlass.cutlass_dsl import JitArgAdapterRegistry
# Local modules imports
from .typing import (
@@ -82,42 +82,36 @@ class _Pointer(Pointer):
self._dtype = dtype
self._addr_space = mem_space
is_in_device = mem_space == _cute_ir.AddressSpace.gmem
if assumed_align is None:
if is_in_device:
self._assumed_align = 32
else:
self._assumed_align = dtype.width // 8
self._assumed_align = dtype.width // 8
else:
self._assumed_align = assumed_align
class PtrDescriptor(ctypes.Structure):
"""A ctype descriptor for CuTe memref ptr"""
_fields_ = [("ptr", ctypes.c_void_p)]
def __str__(self):
return f"0x{self.ptr:016x}"
self._desc = PtrDescriptor(int(self._pointer))
self._c_pointer = ctypes.cast(ctypes.pointer(self._desc), ctypes.c_void_p)
self._c_pointer = None
assert (
self._desc.ptr % self._assumed_align == 0
int(self._pointer) % self._assumed_align == 0
), f"pointer must be {self._assumed_align} bytes aligned"
def size_in_bytes(self) -> int:
self._desc = ctypes.c_void_p(int(self._pointer))
return ctypes.sizeof(self._desc)
def __get_mlir_types__(self):
return [self.mlir_type]
def __c_pointers__(self):
if self._c_pointer is None:
self._desc = ctypes.c_void_p(int(self._pointer))
self._c_pointer = ctypes.addressof(self._desc)
return [self._c_pointer]
def __new_from_mlir_values__(self, values):
assert len(values) == 1
return values[0]
def __extract_mlir_values__(self):
return [self._c_pointer]
# Move mlir Type out of __init__ to decouple with mlir Context
@property
def mlir_type(self) -> ir.Type:
@@ -145,7 +139,7 @@ class _Pointer(Pointer):
return False
def __str__(self) -> str:
return f"Ptr<0x{self._desc.ptr:016x}@{self._addr_space}>"
return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>"
def __repr__(self):
return self.__str__()

View File

@@ -31,9 +31,12 @@ from .helpers import (
from .sm90 import (
PipelineAsync,
PipelineCpAsync,
PipelineTmaAsync,
PipelineTmaMultiConsumersAsync,
PipelineTmaStore,
PipelineProducer,
PipelineConsumer,
)
from .sm100 import (
@@ -53,10 +56,13 @@ __all__ = [
"PipelineUserType",
"PipelineState",
"PipelineAsync",
"PipelineCpAsync",
"PipelineTmaAsync",
"PipelineTmaUmma",
"PipelineTmaMultiConsumersAsync",
"PipelineAsyncUmma",
"PipelineUmmaAsync",
"PipelineTmaStore",
"PipelineProducer",
"PipelineConsumer",
]

View File

@@ -89,6 +89,8 @@ class PipelineOp(enum.Enum):
TmaStore = enum.auto()
# Composite of multiple PipelineOps
Composite = enum.auto()
# Async load without TMA
AsyncLoad = enum.auto()
def _get_pipeline_op(type_str):
@@ -226,6 +228,8 @@ class MbarrierArray(SyncObject):
self.arrive_tcgen05mma(index, dst, cta_group)
elif self.op_type in [PipelineOp.TmaLoad]:
self.arrive_and_expect_tx(index, self.tx_count)
elif self.op_type is PipelineOp.AsyncLoad:
self.arrive_cp_async_mbarrier(index)
else:
assert (
False
@@ -237,6 +241,9 @@ class MbarrierArray(SyncObject):
else:
cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank)
    def arrive_cp_async_mbarrier(self, index: int):
        """Make outstanding cp.async operations signal the stage's mbarrier.

        Uses the non-incrementing arrive variant (per the ``_noinc`` intrinsic),
        so the barrier's arrival count is not bumped by this call.

        :param index: Pipeline stage index whose barrier is signaled
        :type index: int
        """
        cute.arch.cp_async_mbarrier_arrive_noinc(self.get_barrier(index))
def arrive_tcgen05mma(
self, index: int, mask: Optional[int], cta_group: cute.nvgpu.tcgen05.CtaGroup
) -> None:

View File

@@ -19,6 +19,7 @@ import cutlass.cute as cute
from cutlass.cutlass_dsl import Boolean, if_generate
from cutlass.pipeline import (
Agent,
CooperativeGroup,
PipelineOp,
PipelineState,
@@ -106,9 +107,9 @@ class PipelineTmaUmma(PipelineAsync):
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:param producer_group: `CooperativeGroup` for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:param consumer_group: `CooperativeGroup` for the consumer agent
:type consumer_group: CooperativeGroup
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
:type tx_count: int
@@ -258,9 +259,9 @@ class PipelineAsyncUmma(PipelineAsync):
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:param producer_group: `CooperativeGroup` for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:param consumer_group: `CooperativeGroup` for the consumer agent
:type consumer_group: CooperativeGroup
:param cta_layout_vmnk: Layout of the cluster shape
:type cta_layout_vmnk: cute.Layout | None
@@ -368,9 +369,9 @@ class PipelineUmmaAsync(PipelineAsync):
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:param producer_group: `CooperativeGroup` for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:param consumer_group: `CooperativeGroup` for the consumer agent
:type consumer_group: CooperativeGroup
:param cta_layout_vmnk: Layout of the cluster shape
:type cta_layout_vmnk: cute.Layout | None

View File

@@ -10,15 +10,18 @@
# is strictly prohibited.
import enum
from typing import Type, Tuple
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union
import warnings
import cutlass
import cutlass.cute as cute
from cutlass.cutlass_dsl import Boolean, Int32, if_generate
from cutlass.pipeline import (
Agent,
CooperativeGroup,
PipelineOp,
SyncObject,
@@ -91,6 +94,30 @@ class PipelineAsync:
- D: Data ready (producer has written data to buffer)
- R: Consumer reading (consumer is consuming data from buffer)
**Example:**
.. code-block:: python
# Create pipeline with 5 stages
pipeline = PipelineAsync.create(
num_stages=5, # number of pipeline stages
producer_group=producer_warp,
consumer_group=consumer_warp,
barrier_storage=smem_ptr, # smem pointer for array of mbarriers in shared memory
)
producer, consumer = pipeline.make_participants()
# Producer side
for i in range(num_iterations):
handle = producer.acquire_and_advance() # Wait for buffer to be empty & Move index to next stage
# Write data to pipeline buffer
handle.commit() # Signal buffer is full
# Consumer side
for i in range(num_iterations):
handle = consumer.wait_and_advance() # Wait for buffer to be full & Move index to next stage
# Read data from pipeline buffer
handle.release() # Signal buffer is empty
"""
sync_object_full: SyncObject
@@ -114,6 +141,7 @@ class PipelineAsync:
PipelineOp.TmaLoad,
PipelineOp.TCGen05Mma,
PipelineOp.Composite,
PipelineOp.AsyncLoad,
]:
return MbarrierArray(
barrier_storage=barrier_storage,
@@ -232,6 +260,74 @@ class PipelineAsync:
state.advance()
self.producer_acquire(state)
    # Util methods to manage producer and consumer
    def make_producer(self):
        """Create a `PipelineProducer` bound to this pipeline with a fresh producer-side state."""
        state = make_pipeline_state(PipelineUserType.Producer, self.num_stages)
        return PipelineProducer(self, state, self.sync_object_full.cg)
    def make_consumer(self):
        """Create a `PipelineConsumer` bound to this pipeline with a fresh consumer-side state."""
        state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages)
        return PipelineConsumer(self, state, self.sync_object_empty.cg)
    def make_participants(self):
        """Create a matched (producer, consumer) pair for this pipeline."""
        return self.make_producer(), self.make_consumer()
@dataclass(frozen=True)
class PipelineCpAsync(PipelineAsync):
    """
    PipelineCpAsync is used for CpAsync producers and AsyncThread consumers (e.g. Hopper non-TMA mainloops).
    """

    @staticmethod
    def create(
        barrier_storage: cute.Pointer,
        num_stages: Int32,
        producer_group: CooperativeGroup,
        consumer_group: CooperativeGroup,
        producer_mask: Int32 = None,
        consumer_mask: Int32 = None,
    ):
        """
        This helper function computes any necessary attributes and returns an instance of PipelineCpAsync.

        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
        :type barrier_storage: cute.Pointer
        :param num_stages: Number of buffer stages for this pipeline
        :type num_stages: Int32
        :param producer_group: `CooperativeGroup` for the producer agent
        :type producer_group: CooperativeGroup
        :param consumer_group: `CooperativeGroup` for the consumer agent
        :type consumer_group: CooperativeGroup
        :param producer_mask: Mask for signaling arrives for the producer agent
        :type producer_mask: Int32 | None
        :param consumer_mask: Mask for signaling arrives for the consumer agent
        :type consumer_mask: Int32 | None
        """
        producer_type = PipelineOp.AsyncLoad
        consumer_type = PipelineOp.AsyncThread
        producer = (producer_type, producer_group)
        consumer = (consumer_type, consumer_group)
        # Full barriers occupy the first `num_stages` mbarrier slots,
        # empty barriers the next `num_stages` slots.
        sync_object_array_full = PipelineCpAsync._make_sync_object(
            barrier_storage.align(min_align=8), num_stages, producer
        )
        sync_object_array_empty = PipelineCpAsync._make_sync_object(
            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
        )
        pipeline_init_wait()
        return PipelineCpAsync(
            sync_object_array_full,
            sync_object_array_empty,
            num_stages,
            producer_mask,
            consumer_mask,
        )
@dataclass(frozen=True)
class PipelineTmaAsync(PipelineAsync):
"""
@@ -294,9 +390,9 @@ class PipelineTmaAsync(PipelineAsync):
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:param producer_group: `CooperativeGroup` for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group: CooperativeGroup for the consumer agent
:param consumer_group: `CooperativeGroup` for the consumer agent
:type consumer_group: CooperativeGroup
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
:type tx_count: int
@@ -404,11 +500,11 @@ class PipelineTmaMultiConsumersAsync(PipelineAsync):
:type barrier_storage: cute.Pointer
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:param producer_group: `CooperativeGroup` for the producer agent
:type producer_group: CooperativeGroup
:param consumer_group_umma: CooperativeGroup for the UMMA consumer agent
:param consumer_group_umma: `CooperativeGroup` for the UMMA consumer agent
:type consumer_group_umma: CooperativeGroup
:param consumer_group_async: CooperativeGroup for the AsyncThread consumer agent
:param consumer_group_async: `CooperativeGroup` for the AsyncThread consumer agent
:type consumer_group_async: CooperativeGroup
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
:type tx_count: int
@@ -529,9 +625,10 @@ class PipelineTmaStore(PipelineAsync):
This helper function computes any necessary attributes and returns an instance of PipelineTmaStore.
:param num_stages: Number of buffer stages for this pipeline
:type num_stages: Int32
:param producer_group: CooperativeGroup for the producer agent
:param producer_group: `CooperativeGroup` for the producer agent
:type producer_group: CooperativeGroup
"""
producer_type = PipelineOp.TmaStore
producer = (producer_type, producer_group)
@@ -556,3 +653,333 @@ class PipelineTmaStore(PipelineAsync):
self.sync_object_full.tail()
#################################################################
# Utilities to help user of pipeline to simplify the workflow
#################################################################
class ImmutableResourceHandle:
    """Read-only handle to one pipeline stage, snapshotting state at acquire/wait time.

    Instances are created by `PipelineProducer.acquire` / `PipelineConsumer.wait`
    with a cloned `PipelineState`, so the handle's stage cannot change after
    creation even when the owning producer/consumer advances.
    """

    # NOTE: double-underscore attributes are intentionally name-mangled;
    # subclasses access the state via `_ImmutableResourceHandle__immutable_state`.
    __origin: PipelineAsync
    __immutable_state: PipelineState

    def __init__(self, origin: PipelineAsync, immutable_state: PipelineState):
        self.__origin = origin
        self.__immutable_state = immutable_state

    @property
    def index(self):
        """Get the index of the current pipeline stage."""
        return self.__immutable_state.index

    @property
    def count(self):
        """Get the count of how many handles this producer has committed.

        This is useful for tracking the number of blocks that have been loaded from gmem.
        """
        return self.__immutable_state.count

    def get_origin(self):
        """Get the original pipeline this resource handle belongs to."""
        return self.__origin

    def __extract_mlir_values__(self):
        """Extract MLIR values from the current state.

        :return: List of MLIR values representing the current state
        :rtype: list
        """
        # TODO: need to handle pipeline as well
        return self.__immutable_state.__extract_mlir_values__()

    def __new_from_mlir_values__(self, values):
        """Create a new handle of the same class from MLIR values.

        :param values: MLIR values to initialize the state
        :type values: Any
        :return: New handle with state initialized from values
        """
        return self.__class__(
            self.__origin, self.__immutable_state.__new_from_mlir_values__(values)
        )
class PipelineProducer:
    """A class representing a producer in an asynchronous pipeline.

    The Producer class manages the producer side of an asynchronous pipeline, handling
    synchronization and state management for producing data. It provides methods for
    acquiring, committing, and advancing through pipeline stages.

    :ivar __pipeline: The asynchronous pipeline this producer belongs to
    :type __pipeline: PipelineAsync
    :ivar __state: The current state of the producer in the pipeline
    :type __state: PipelineState
    :ivar __group: The cooperative group this producer operates in
    :type __group: CooperativeGroup

    **Examples:**

    .. code-block:: python

        pipeline = PipelineAsync.create(...)
        producer = pipeline.make_producer()

        for i in range(iterations):
            handle = producer.acquire_and_advance()  # Wait for buffer to be empty
            # Produce data
            producer.commit(handle)  # Signal data is ready
            # An alternative way to do this is:
            # handle.commit()  # Signal data is ready
    """

    __pipeline: PipelineAsync
    __state: PipelineState
    __group: CooperativeGroup

    class ImmutableResourceHandle(ImmutableResourceHandle):
        @property
        def barrier(self):
            """Get the barrier pointer for the current pipeline stage.

            :return: Pointer to the barrier for the current stage
            :rtype: cute.Pointer
            """
            return self.get_origin().producer_get_barrier(
                self._ImmutableResourceHandle__immutable_state
            )

        def commit(self):
            """Signal that data production is complete for the current stage.

            This allows consumers to start processing the data.
            """
            self.get_origin().producer_commit(
                self._ImmutableResourceHandle__immutable_state
            )

    def __init__(self, pipeline, state, group: CooperativeGroup):
        """Initialize a new Producer instance.

        :param pipeline: The pipeline this producer belongs to
        :type pipeline: PipelineAsync
        :param state: Initial pipeline state
        :type state: PipelineState
        :param group: The cooperative group for synchronization
        :type group: CooperativeGroup
        """
        self.__pipeline = pipeline
        self.__state = state
        self.__group = group

    def acquire(
        self,
        try_acquire_token: Optional[Boolean] = None,
    ) -> ImmutableResourceHandle:
        """Wait for the current buffer to be empty before producing data.

        This is a blocking operation.

        :param try_acquire_token: Optional token to try to acquire the buffer
        :type try_acquire_token: Optional[Boolean]
        :return: A handle to the acquired stage for committing the data
        :rtype: ImmutableResourceHandle
        """
        self.__pipeline.producer_acquire(self.__state, try_acquire_token)
        # Clone the state so the handle stays pinned to this stage even
        # after the producer advances.
        handle = PipelineProducer.ImmutableResourceHandle(
            self.__pipeline, self.__state.clone()
        )
        return handle

    def advance(self):
        """Move to the next pipeline stage."""
        self.__state.advance()

    def acquire_and_advance(
        self, try_acquire_token: Optional[Boolean] = None
    ) -> ImmutableResourceHandle:
        """Wait for the current buffer to be empty, then advance to the next stage.

        This is a blocking operation.

        :param try_acquire_token: Optional token to try to acquire the buffer
        :type try_acquire_token: Optional[Boolean]
        :return: A handle to the acquired stage for committing the data
        :rtype: ImmutableResourceHandle
        """
        handle = self.acquire(try_acquire_token)
        self.advance()
        return handle

    def try_acquire(self) -> Boolean:
        """Try to acquire the current buffer without blocking.

        :return: True if acquisition was successful, False otherwise
        :rtype: Boolean
        """
        return self.__pipeline.producer_try_acquire(self.__state)

    def commit(self, handle: Optional[ImmutableResourceHandle] = None):
        """Signal that data production is complete for the current stage.

        This allows consumers to start processing the data.

        :param handle: Optional handle (from `acquire`) identifying the stage to commit;
            when omitted, the producer's current stage is committed
        :type handle: Optional[ImmutableResourceHandle]
        """
        if handle is not None:
            # BUG FIX: a handle's origin is the *pipeline* it was created for
            # (see `acquire`), not this producer object, so compare against
            # the pipeline — the previous `is self` check always failed.
            assert (
                handle.get_origin() is self.__pipeline
            ), "ResourceHandle does not belong to this PipelineProducer's pipeline"
            handle.commit()
        else:
            self.__pipeline.producer_commit(self.__state)

    def tail(self):
        """Ensure all used buffers are properly synchronized before producer exit.

        This should be called before the producer finishes to avoid dangling signals.
        """
        self.__pipeline.producer_tail(self.__state)

    def __extract_mlir_values__(self):
        """Extract MLIR values from the current state.

        :return: List of MLIR values representing the current state
        :rtype: list
        """
        # TODO: need to handle pipeline as well
        return self.__state.__extract_mlir_values__()

    def __new_from_mlir_values__(self, values):
        """Create a new Producer instance from MLIR values.

        :param values: MLIR values to initialize the state
        :type values: Any
        :return: New Producer instance with state initialized from values
        :rtype: PipelineProducer
        """
        return PipelineProducer(
            self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group
        )
class PipelineConsumer:
    """A class representing a consumer in an asynchronous pipeline.

    The Consumer class manages the consumer side of an asynchronous pipeline, handling
    synchronization and state management for consuming data. It provides methods for
    waiting, releasing, and advancing through pipeline stages.

    :ivar __pipeline: The asynchronous pipeline this consumer belongs to
    :type __pipeline: PipelineAsync
    :ivar __state: The current state of the consumer in the pipeline
    :type __state: PipelineState
    :ivar __group: The cooperative group this consumer operates in
    :type __group: CooperativeGroup

    **Examples:**

    .. code-block:: python

        pipeline = PipelineAsync.create(...)
        consumer = pipeline.make_consumer()

        for i in range(iterations):
            handle = consumer.wait_and_advance()  # Wait for data to be ready
            # Consume data
            consumer.release(handle)  # Signal buffer is empty
            # An alternative way to do this is:
            # handle.release()  # Signal buffer is empty
    """

    __pipeline: PipelineAsync
    __state: PipelineState
    __group: CooperativeGroup

    class ImmutableResourceHandle(ImmutableResourceHandle):
        def release(self):
            """Signal that data consumption is complete for the current stage.

            This allows producers to start producing new data.
            """
            self.get_origin().consumer_release(
                self._ImmutableResourceHandle__immutable_state
            )

    def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup):
        """Initialize a new Consumer instance.

        :param pipeline: The pipeline this consumer belongs to
        :type pipeline: PipelineAsync
        :param state: Initial pipeline state
        :type state: PipelineState
        :param group: The cooperative group for synchronization
        :type group: CooperativeGroup
        """
        self.__pipeline = pipeline
        self.__group = group
        self.__state = state

    def wait(self, try_wait_token: Optional[Boolean] = None) -> ImmutableResourceHandle:
        """Wait for data to be ready in the current buffer.

        This is a blocking operation.

        :param try_wait_token: Optional token to try to wait for the buffer
        :type try_wait_token: Optional[Boolean]
        :return: A handle to the waited-on stage for releasing the data
        :rtype: ImmutableResourceHandle
        """
        self.__pipeline.consumer_wait(self.__state, try_wait_token)
        # Clone the state so the handle stays pinned to this stage even
        # after the consumer advances.
        handle = PipelineConsumer.ImmutableResourceHandle(
            self.__pipeline, self.__state.clone()
        )
        return handle

    def advance(self):
        """Move to the next pipeline stage."""
        self.__state.advance()

    def wait_and_advance(
        self, try_wait_token: Optional[Boolean] = None
    ) -> ImmutableResourceHandle:
        """Wait for data to be ready in the current buffer, then advance to the next stage.

        This is a blocking operation.

        :param try_wait_token: Optional token to try to wait for the buffer
        :type try_wait_token: Optional[Boolean]
        :return: A handle to the waited-on stage for releasing the data
        :rtype: ImmutableResourceHandle
        """
        handle = self.wait(try_wait_token)
        self.advance()
        return handle

    def try_wait(self) -> Boolean:
        """Try to check if data is ready without blocking.

        :return: True if data is ready, False otherwise
        :rtype: Boolean
        """
        return self.__pipeline.consumer_try_wait(self.__state)

    def release(self, handle: Optional[ImmutableResourceHandle] = None):
        """Signal that data consumption is complete for the current stage.

        This allows producers to start producing new data.

        :param handle: Optional handle (from `wait`) identifying the stage to release;
            when omitted, the consumer's current stage is released
        :type handle: Optional[ImmutableResourceHandle]
        """
        if handle is not None:
            # BUG FIX: a handle's origin is the *pipeline* it was created for
            # (see `wait`), not this consumer object, so compare against
            # the pipeline — the previous `is self` check always failed.
            assert (
                handle.get_origin() is self.__pipeline
            ), "ResourceHandle does not belong to this PipelineConsumer's pipeline"
            handle.release()
        else:
            self.__pipeline.consumer_release(self.__state)

    def __extract_mlir_values__(self):
        """Extract MLIR values from the current state.

        :return: List of MLIR values representing the current state
        :rtype: list
        """
        return self.__state.__extract_mlir_values__()

    def __new_from_mlir_values__(self, values):
        """Create a new Consumer instance from MLIR values.

        :param values: MLIR values to initialize the state
        :type values: Any
        :return: New Consumer instance with state initialized from values
        :rtype: PipelineConsumer
        """
        # TODO: need to call pipeline.__new_from_mlir_values__ recursively
        return PipelineConsumer(
            self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group
        )

View File

@@ -9,6 +9,8 @@
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
import ctypes
from math import prod
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Type, Union
@@ -54,6 +56,25 @@ def dtype(ty: Type[Numeric]):
return torch_dtype
def as_tensor(pointer, shape, torch_type):
"""Convert a pointer to a torch tensor"""
if torch_type.itemsize == 1:
cytype = ctypes.c_uint8
elif torch_type.itemsize == 2:
cytype = ctypes.c_uint16
elif torch_type.itemsize == 4:
cytype = ctypes.c_uint32
elif torch_type.itemsize == 8:
cytype = ctypes.c_uint64
else:
raise ValueError(f"Unsupported torch dtype: {torch_type}")
cpointer = ctypes.cast(pointer, ctypes.POINTER(cytype))
arr = (cpointer._type_ * prod(shape)).from_address(
ctypes.addressof(cpointer.contents)
)
return torch.frombuffer(arr, dtype=torch_type).view(*shape)
@dataclass
class ScalarInitConfig:
"""Configuration for scalar initialization"""
@@ -128,7 +149,7 @@ def create_and_permute_torch_tensor(
if not isinstance(init_config, GaussianInitConfig):
raise ValueError("init_config must be GaussianInitConfig()")
f32_torch_tensor = init_torch_tensor.normal_(init_config.mean, init_config.std)
f32_torch_tensor = f32_torch_tensor * (1 << init_config.scale)
f32_torch_tensor = f32_torch_tensor * init_config.scale
else:
raise ValueError(f"Invalid init type: {init_type}")

View File

@@ -64,6 +64,18 @@ from .smem_capacity import (
get_smem_capacity_in_bytes,
)
from .distributed_helpers import (
spin_lock_wait,
spin_lock_multimem_arrive,
multimem_ld_reduce_8xf16,
multimem_ld_reduce_4xf32,
multimem_ld_reduce_8xbf16,
multimem_ld_reduce_16xe4m3,
multimem_ld_reduce_16xe5m2,
multimem_st_4xb32,
sm_wise_inter_gpu_multimem_barrier,
)
__all__ = [
"get_smem_capacity_in_bytes",
"SmemAllocator",

View File

@@ -0,0 +1,179 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from functools import partial
from typing import Tuple
import cutlass.cute as cute
from cutlass.cutlass_dsl import T, dsl_user_op, while_generate
from cutlass._mlir import ir
from cutlass._mlir.dialects import arith, llvm, nvvm, scf
from cutlass._mlir.dialects.nvvm import (
MemOrderKind,
MemScopeKind,
AtomicOpKind,
)
from cutlass.cute.typing import Pointer, Int32, Boolean
@dsl_user_op
def atomicAdd(dst_ptr: Pointer, val: Int32, loc=None, ip=None) -> Int32:
    """Atomically add ``val`` to the 32-bit integer at ``dst_ptr``.

    Uses a relaxed, system-scope NVVM atomicrmw.  Returns the op's result
    (per LLVM atomicrmw semantics, presumably the value read before the add
    — confirm against the NVVM dialect definition).
    """
    return nvvm.atomicrmw(
        T.i32(),
        AtomicOpKind.ADD,
        dst_ptr.llvm_ptr,
        val.ir_value(loc=loc, ip=ip),
        mem_order=MemOrderKind.RELAXED,
        syncscope=MemScopeKind.SYS,
        loc=loc,
        ip=ip,
    )
@cute.jit
def ld_bypass(input_tensor: cute.Tensor):
    """Load ``input_tensor`` into a register fragment with a volatile,
    system-scope copy and return the loaded values.

    The VOLATILE memory order forces the load to reach memory on each call
    (presumably bypassing caches, hence the name), so repeated calls observe
    updates made by other GPUs/ranks.
    """
    fragment = cute.make_fragment(input_tensor.layout, input_tensor.element_type)
    copy_atom_load = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        input_tensor.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    cute.copy(copy_atom_load, input_tensor, fragment)
    vals = fragment.load()
    return vals
@cute.jit
def spin_lock_wait(lock_ptr: Pointer, expect_count: Int32, mem_order : str = "relaxed", mem_scope : str = "gpu", loc=None, ip=None) -> None:
    """
    wait on a spin lock until the expected count is reached.

    Spins on an atomic CAS of the 32-bit word at ``lock_ptr`` until the value
    returned by the CAS equals ``expect_count``.

    :param lock_ptr: pointer to the 32-bit lock/counter word
    :param expect_count: value to wait for
    :param mem_order: "acquire" maps to MemOrderKind.ACQUIRE; any other value
        (including the default "relaxed") maps to MemOrderKind.RELAXED
    :param mem_scope: "gpu" maps to MemScopeKind.GPU; any other value
        (e.g. "sys") maps to MemScopeKind.SYS
    """
    res = 0
    while res != expect_count:
        # NOTE(review): operand roles (compare value vs. swap value) follow the
        # nvvm.atomicrmw CAS convention — the loop terminates once the returned
        # value equals expect_count; confirm against the dialect definition.
        res = nvvm.atomicrmw(
            T.i32(),
            AtomicOpKind.CAS,
            lock_ptr.llvm_ptr,
            Int32(0).ir_value(loc=loc, ip=ip),
            b=Int32(expect_count).ir_value(loc=loc, ip=ip),
            mem_order=MemOrderKind.ACQUIRE if mem_order == "acquire" else MemOrderKind.RELAXED,
            syncscope=MemScopeKind.GPU if mem_scope == "gpu" else MemScopeKind.SYS
        )
@dsl_user_op
def multimem_red_add_sys_release(mc_ptr: Pointer, loc=None, ip=None) -> None:
    """
    add 1 to the multimem address

    Emits ``multimem.red.release.sys.global.add.u32 [mc_ptr], 1;`` — a
    system-scope, release-ordered reduction on the u32 at the multicast
    address.
    """
    llvm.inline_asm(
        None,
        # $0: 64-bit multicast address ("l" constraint).
        [mc_ptr.toint().ir_value()],
        "multimem.red.release.sys.global.add.u32 [$0], 1;",
        "l",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )
@dsl_user_op
def multimem_red_add_gpu_relaxed(mc_ptr: Pointer, loc=None, ip=None) -> None:
    """
    add 1 to the multimem address

    Emits ``multimem.red.relaxed.gpu.global.add.u32 [mc_ptr], 1;`` — a
    gpu-scope, relaxed reduction on the u32 at the multicast address.
    """
    llvm.inline_asm(
        None,
        # $0: 64-bit multicast address ("l" constraint).
        [mc_ptr.toint().ir_value()],
        "multimem.red.relaxed.gpu.global.add.u32 [$0], 1;",
        "l",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )
def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None:
    """
    arrive a spin lock when the lock_ptr is a multimem address.

    Increments the counter via a gpu-scope, relaxed multimem reduction;
    intended to pair with ``spin_lock_wait`` polling the unicast address.
    """
    multimem_red_add_gpu_relaxed(lock_ptr, loc=loc, ip=ip)
def sm_wise_inter_gpu_multimem_barrier(barrier : Pointer, barrier_mc : Pointer, num_ranks, loc=None, ip=None) -> None :
    """
    barrier for inter-gpu sm-wise

    Each CTA (slot ``pid`` = linearized block index) arrives by incrementing
    ``barrier_mc + pid`` with a system-scope release multimem reduction, then
    spins on the local counter ``barrier + pid`` until ``num_ranks`` arrivals
    are observed with acquire ordering.
    """
    bidx, bidy, bidz = cute.arch.block_idx()
    bdimx, bdimy, _ = cute.arch.grid_dim()
    # Linearize the 3-D block index into a per-CTA barrier slot.
    pid = bidx + bidy * bdimx + bidz * bdimx * bdimy
    multimem_red_add_sys_release(barrier_mc + pid, loc=loc, ip=ip)
    # Order the multimem (alias proxy) write relative to the polling loads.
    cute.arch.fence_proxy(cute.arch.ProxyKind.alias)
    spin_lock_wait(barrier + pid, num_ranks, mem_order="acquire", mem_scope="sys", loc=loc, ip=ip)
@dsl_user_op
def multimem_ld_reduce_base(
    mc_ptr: Pointer,
    *,
    ptx_string: str = "",
    loc=None,
    ip=None,
) -> Tuple[Int32, Int32, Int32, Int32]:
    """Issue the multimem ld_reduce instruction given in ``ptx_string``.

    The PTX must write four 32-bit registers ($0-$3) and read the 64-bit
    multicast address as $4 (matching constraints "=r,=r,=r,=r,l").  The four
    packed result registers are returned as i32 values; callers reinterpret
    them according to the element type of the chosen specialization.
    """
    mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value()
    return_struct = llvm.inline_asm(
        ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"),
        [mc_ptr_int],
        ptx_string,
        "=r,=r,=r,=r,l",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )
    # Unpack the four i32 results from the returned LLVM struct.
    return_regs = [llvm.extractvalue(T.i32(), return_struct, [i]) for i in range(4)]
    return return_regs[0], return_regs[1], return_regs[2], return_regs[3]
# Specializations of multimem_ld_reduce_base: each loads-and-adds a 128-bit
# vector across ranks and returns it packed in four i32 registers.
# The f16/bf16 variants accumulate in f32 (acc::f32); the fp8 (e4m3/e5m2)
# variants accumulate in f16 (acc::f16); the f32 variant needs no acc widening.
multimem_ld_reduce_8xf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];")
multimem_ld_reduce_4xf32 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];")
multimem_ld_reduce_8xbf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];")
multimem_ld_reduce_16xe4m3 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];")
multimem_ld_reduce_16xe5m2 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];")
@dsl_user_op
def multimem_st_4xb32(
    mc_ptr: Pointer,
    x: Int32,
    y: Int32,
    z: Int32,
    w: Int32,
    *,
    loc=None,
    ip=None,
) -> None:
    """Store four 32-bit values to the multicast address ``mc_ptr``.

    Emits ``multimem.st.sys.relaxed.global.v4.f32``, which writes the vector
    through the multimem address (system scope, relaxed ordering).
    """
    # st 4x32 bits of data
    mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value()
    # NOTE(review): the "=r" output ($0) is a dummy required by the constraint
    # string; the instruction only uses $1 (address) and $2-$5 (data).
    llvm.inline_asm(
        T.i32(),
        [mc_ptr_int, x, y, z, w],
        "multimem.st.sys.relaxed.global.v4.f32 [$1], {$2, $3, $4, $5};",
        "=r,l,r,r,r,r",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )

View File

@@ -34,18 +34,6 @@ class LayoutEnum(Enum):
else warpgroup.OperandMajorMode.MN
)
def is_k_major_a(self):
    # K-major (contiguous-K) operand A corresponds to a row-major layout.
    return self == LayoutEnum.ROW_MAJOR
def is_m_major_a(self):
    # M-major (contiguous-M) operand A corresponds to a column-major layout.
    return self == LayoutEnum.COL_MAJOR
def is_k_major_b(self):
    # K-major (contiguous-K) operand B corresponds to a column-major layout.
    return self == LayoutEnum.COL_MAJOR
def is_n_major_b(self):
    # N-major (contiguous-N) operand B corresponds to a row-major layout.
    return self == LayoutEnum.ROW_MAJOR
def is_n_major_c(self):
    # N-major (contiguous-N) output C corresponds to a row-major layout.
    return self == LayoutEnum.ROW_MAJOR

View File

@@ -11,7 +11,7 @@
from typing import Type, Union, overload
from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta
from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta, CutlassBaseDSL
import cutlass.cute as cute
from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size
@@ -40,14 +40,17 @@ class SmemAllocator:
"""
self._base = get_dyn_smem(Int8, alignment=1024)
self._allocated_bytes = 0
CutlassBaseDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes)
@overload
def allocate(self, size_or_type: int, byte_alignment: int): ...
def allocate(self, size_or_type: int, byte_alignment: int) -> cute.Pointer: ...
@overload
def allocate(self, size_or_type: cute.struct, byte_alignment: int): ...
def allocate(
self, size_or_type: cute.struct, byte_alignment: int
) -> cute.Pointer: ...
def allocate(self, size_or_type, byte_alignment: int = 1) -> int:
def allocate(self, size_or_type, byte_alignment: int = 1) -> cute.Pointer:
"""Allocate a block of memory with specified size and alignment.
This method adjusts the base pointer to ensure proper alignment and updates

View File

@@ -382,3 +382,5 @@ class StaticPersistentTileScheduler:
@property
def num_tiles_executed(self) -> Int32:
return self._num_tiles_executed

View File

@@ -17,6 +17,7 @@ from ..base_dsl.ast_helpers import (
if_executor,
while_selector,
while_executor,
range,
range_constexpr,
range_dynamic,
const_expr,
@@ -28,6 +29,8 @@ from ..base_dsl.ast_helpers import (
all_executor,
range_value_check,
range_perf_warning,
cf_symbol_check,
redirect_builtin_function,
)
from ..base_dsl import *
@@ -38,5 +41,4 @@ from ..base_dsl._mlir_helpers.op import dsl_user_op
from ..base_dsl.runtime import *
from ..base_dsl.runtime import cuda as cuda_helpers
from ..base_dsl.compiler import compile
from ..base_dsl.runtime.dlpack_runtime import *
from ..base_dsl.runtime.jit_arg_adapters import *

View File

@@ -15,12 +15,14 @@ regarding to that dialect.
"""
# Local module imports
from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef
from inspect import isclass
from itertools import chain
from types import GenericAlias, SimpleNamespace, UnionType
from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef, Any
import functools
import pkgutil
from dataclasses import is_dataclass
from dataclasses import is_dataclass, fields
from collections.abc import Sequence
import builtins
from ..base_dsl import *
from ..base_dsl import compiler
@@ -51,20 +53,15 @@ from ..base_dsl.ast_helpers import (
while_selector,
while_executor,
assert_executor,
const_expr,
dynamic_expr,
bool_cast,
compare_executor,
any_executor,
all_executor,
range_value_check,
range_perf_warning,
)
from ..base_dsl.runtime.dlpack_runtime import (
get_cute_tensor_c_pointer,
get_tensor_desc_shape_all,
get_tensor_desc_stride_all,
get_tensor_desc_element_type,
get_tensor_desc_is_in_device,
get_tensor_desc_assumed_align,
cf_symbol_check,
)
from .cutlass_ast_decorators import (
@@ -73,6 +70,16 @@ from .cutlass_ast_decorators import (
_while_execute_dynamic,
)
from .tree_utils import (
is_constexpr_field,
tree_flatten,
tree_unflatten,
PyTreeDef,
is_frozen_dataclass,
DSLTreeFlattenError,
)
from ..base_dsl.runtime.jit_arg_adapters import JitArgAdapterRegistry
# =============================================================================
# Cutlass DSL Base Abstract Class
@@ -125,6 +132,46 @@ def is_cute_algebra_type(arg_spec):
return False
def _get_c_pointers_cutlass(obj):
    """
    This is an extended version of `get_c_pointers` that supports dataclasses, SimpleNamespace, and dict.
    """

    def _flatten(children):
        # Recurse over an iterable of children, concatenating their pointers
        # in order.
        return list(chain.from_iterable(_get_c_pointers_cutlass(c) for c in children))

    # An object that knows its own pointers takes precedence.
    if hasattr(obj, "__c_pointers__"):
        return obj.__c_pointers__()
    if isinstance(obj, (tuple, list)):
        return _flatten(obj)
    if isinstance(obj, SimpleNamespace):
        return _flatten(obj.__dict__.values())
    if isinstance(obj, dict):
        return _flatten(obj.values())
    if is_dataclass(obj):
        # Constexpr fields carry no runtime pointers; skip them.
        return _flatten(
            getattr(obj, f.name) for f in fields(obj) if not is_constexpr_field(f)
        )
    if isinstance(obj, set):
        raise DSLRuntimeError(
            "Sets are not supported in get_c_pointers to ensure order preservation",
            context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
            suggestion="Consider using a list or tuple instead",
        )
    # Try get adapter
    adapter = JitArgAdapterRegistry.get_registered_adapter(type(obj))
    if adapter is not None:
        return _get_c_pointers_cutlass(adapter(obj))
    return []
class CutlassBaseDSL(BaseDSL):
"""This abstract class provides a DSL for Cutlass."""
@@ -137,16 +184,25 @@ class CutlassBaseDSL(BaseDSL):
preprocess: bool = False,
):
super().__init__(
name,
compiler_provider,
pass_sm_arch_name,
device_compilation_only,
preprocess,
name=name,
dsl_package_name=["cutlass"],
compiler_provider=compiler_provider,
pass_sm_arch_name=pass_sm_arch_name,
device_compilation_only=device_compilation_only,
preprocess=preprocess,
)
self._smem_usage_tracker: tuple = None
# this method is not useful for cutlass_dsl, so we only provide a dummy implementation.
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
return False
# this method is not useful for cutlass_dsl, so we only provide a dummy implementation.
def _handle_tensor_descriptor(
self, maybe_tensor, arg_name: str, need_gpu_memory: bool
) -> Any:
return False
def _build_gpu_module(self, attrs):
self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels"))
with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])):
@@ -229,8 +285,43 @@ class CutlassBaseDSL(BaseDSL):
return version_hash
@staticmethod
def track_smem_allocator(allocator, callback):
    """
    Tracks shared memory usage for kernel functions.
    Find and set allocator to its parent dsl object.

    :param allocator: the SmemAllocator instance to register
    :param callback: callable taking the allocator and returning its
        allocated byte count (queried later via _get_smem_usage)
    """
    # Walk up the Python call stack looking for a frame whose `self` is a
    # CutlassBaseDSL instance (the DSL currently building a kernel).
    frame = inspect.currentframe().f_back
    while frame:
        obj = frame.f_locals.get("self", None)
        if obj and isinstance(obj, CutlassBaseDSL):
            obj._set_smem_tracking(allocator, callback)
            return
        frame = frame.f_back
    # No enclosing DSL found: smem usage simply will not be tracked.
    warnings.warn("Cannot find parent dsl for allocator!", UserWarning)
def _set_smem_tracking(self, allocator, callback):
    # Registers an allocator and callback for current dsl
    # (overwrites any previously registered tracker).
    self._smem_usage_tracker = (allocator, callback)
def _reset_smem_tracking(self):
    # Clear an allocator and callback for current dsl
    # (called at the start of each kernel generation).
    self._smem_usage_tracker = None
def _get_smem_usage(self) -> int:
    # Treat final allocated bytes of allocator as smem usage
    # Returns 0 when no allocator has been registered for this DSL.
    if not self._smem_usage_tracker:
        return 0
    allocator, callback = self._smem_usage_tracker
    return callback(allocator)
def _kernel_helper(self, funcBody, *args, **kwargs):
class _CutlassIrKernelGenHelper(BaseDSL._KernelGenHelper):
def __init__(self, dsl: CutlassBaseDSL):
super().__init__()
self.dsl = dsl
self.dsl._reset_smem_tracking()
def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None):
super().generate_func_op(arg_types, arg_attrs, kernel_name)
self.func_op = func.FuncOp(
@@ -272,6 +363,17 @@ class CutlassBaseDSL(BaseDSL):
if cfg.has_cluster:
cfg.cluster = [to_index(size) for size in cfg.cluster]
smem_usage = self.dsl._get_smem_usage()
if any(not isinstance(x, int) for x in [cfg.smem, smem_usage]):
pass # cannot compare dynamic value inside kernel to launch op in py
elif cfg.auto_smem:
cfg.smem = smem_usage
elif smem_usage > cfg.smem:
warnings.warn(
f"Potential error: specified kernel launch smem bytes "
f"({cfg.smem}) is smaller than kernel usage ({smem_usage})!",
UserWarning,
)
cfg.smem = const(cfg.smem)
if not isinstance(cfg.async_deps, (list, tuple)):
@@ -295,12 +397,13 @@ class CutlassBaseDSL(BaseDSL):
return token if is_async else None
return KernelLauncher(
self, _CutlassIrKernelGenHelper, funcBody, *args, **kwargs
self,
lambda: _CutlassIrKernelGenHelper(self),
funcBody,
*args,
**kwargs,
)
def _get_module_globals(self):
return globals()
def _preprocess_launch_config_args(self, args, kwargs):
"""Helper to preprocess args and kwargs for LaunchConfig"""
if "stream" in kwargs:
@@ -316,7 +419,10 @@ class CutlassBaseDSL(BaseDSL):
Validates if the arg is really of the annotated type.
"""
if is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None):
if (
is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None)
or arg_annotation is Any
):
pass
else:
origin = get_origin(arg_annotation)
@@ -329,11 +435,12 @@ class CutlassBaseDSL(BaseDSL):
f"expects argument #{arg_index+1} ({arg_name}) to be Type[{expected_base}], but got {arg}"
)
# Handle Union types and generic types
elif origin is Union:
elif origin is Union or isinstance(arg_annotation, UnionType):
# For Union types, check if arg matches any of the allowed types
allowed_types = get_args(arg_annotation)
if not any(
(isinstance(ty, type) and isinstance(arg, ty))
(ty is Any)
or (isinstance(ty, type) and isinstance(arg, ty))
or (get_origin(ty) is tuple and isinstance(arg, tuple))
for ty in allowed_types
):
@@ -381,6 +488,26 @@ class CutlassBaseDSL(BaseDSL):
jit_exec_arg.extend(get_c_pointers(arg) if is_host else dyn_vals)
else:
jit_exec_arg = jit_arg_type = jit_arg_attr = None
elif not hasattr(arg, "__extract_mlir_values__") and not hasattr(
arg, "__new_from_mlir_values__"
):
# Try tree_flatten
try:
dyn_vals, _ = tree_flatten(arg)
except DSLTreeFlattenError:
# If fails, just return the original arg
return jit_exec_arg, jit_arg_type, jit_arg_attr
if dyn_vals:
jit_arg_type.extend([v.type for v in dyn_vals])
jit_arg_attr.extend([default_attr] * len(dyn_vals))
jit_exec_arg.extend(
_get_c_pointers_cutlass(arg) if is_host else dyn_vals
)
else:
# If tree flatten yields empty list, treat it as a constexpr thing
# Like a dataclass with all fields are constexpr, or an empty tuple or list
jit_exec_arg = jit_arg_type = jit_arg_attr = None
return jit_exec_arg, jit_arg_type, jit_arg_attr
def _generate_execution_arguments_for_known_types(
@@ -396,6 +523,17 @@ class CutlassBaseDSL(BaseDSL):
blk_args = fop_args[iv_block_args : iv_block_args + n_args]
ir_arg.append(new_from_mlir_values(arg, blk_args))
iv_block_args += n_args
elif not hasattr(arg, "__extract_mlir_values__") and not hasattr(
arg, "__new_from_mlir_values__"
):
# Try tree_unflatten
try:
dyn_vals, tree_def = tree_flatten(arg)
block_args = fop_args[iv_block_args : iv_block_args + len(dyn_vals)]
ir_arg.append(tree_unflatten(tree_def, block_args))
iv_block_args += len(dyn_vals)
except DSLTreeFlattenError:
return ir_arg, iv_block_args
return ir_arg, iv_block_args
@@ -458,10 +596,7 @@ class KernelLauncher:
def _check_func_args(self, funcBody, *func_args, **func_kwargs):
# Get function signature
if isinstance(funcBody, DSLCallable):
sig = funcBody.get_signature()
else:
sig = inspect.signature(funcBody)
sig = inspect.signature(funcBody)
# func_args and func_kwargs should match funcBody's signature,
# no extra or missing arguments.
@@ -473,6 +608,12 @@ class KernelLauncher:
cause=e,
)
def smem_usage(self) -> int:
    """
    Check smem usage for this kernel, only available after `launch`

    Delegates to the owning DSL's tracker; returns 0 if no allocator was
    registered during kernel generation.
    """
    return self.dsl._get_smem_usage()
def launch(self, *args, **kwargs):
self.dsl.frame = inspect.currentframe().f_back
self.dsl._preprocess_launch_config_args(args, kwargs)
@@ -497,134 +638,151 @@ class KernelLauncher:
# =============================================================================
# Utils
# =============================================================================
def is_frozen_dataclass(obj_or_cls) -> bool:
def _filter_readonly_frozen_dataclass(
iter_args: List[Any], items_to_filter: List[Any], full_write_args_count: int
) -> List[Any]:
"""
Return True if obj_or_cls is a dataclass (class or instance) declared with frozen=True,
otherwise False.
"""
if not isinstance(obj_or_cls, type):
# If it's an instance, get its class
obj_or_cls = obj_or_cls.__class__
Filter items based on whether corresponding iter_args are frozen dataclasses.
# Must be a dataclass, and __dataclass_params__.frozen must be True
return (
is_dataclass(obj_or_cls)
and getattr(obj_or_cls, "__dataclass_params__", None) is not None
and obj_or_cls.__dataclass_params__.frozen
This function filters items (which can be values or names) based on the same
logic: keep items if they correspond to full-write arguments (index < full_write_args_count)
or if the corresponding iter_arg is not a frozen dataclass.
Args:
iter_args: List of arguments to check for frozen dataclass status
items_to_filter: List of items to filter (values or names)
full_write_args_count: Number of arguments that are always written (not read-only)
Returns:
Filtered list of items
Examples:
# Filter values (original remove_read_only_frozen_dataclass behavior)
filtered_values = _filter_readonly_frozen_dataclass(iter_args, iter_args, full_write_args_count)
# Filter names (original filter_readonly_frozen_dataclass_names behavior)
filtered_names = _filter_readonly_frozen_dataclass(iter_args, iter_args_names, full_write_args_count)
"""
return [
item
for i, item in enumerate(items_to_filter)
if i < full_write_args_count or not is_frozen_dataclass(iter_args[i])
]
def remove_read_only_frozen_dataclass(
    iter_args: List[Any], full_write_args_count: int
) -> List[Any]:
    """Filter out frozen dataclass arguments that are not full-write arguments."""
    # Filtering values against themselves: frozen dataclasses past the
    # full-write prefix are read-only and carried across the loop unchanged.
    return _filter_readonly_frozen_dataclass(
        iter_args, iter_args, full_write_args_count
    )
def filter_readonly_frozen_dataclass_names(
    iter_args: List[Any], iter_args_names: List[str], full_write_args_count: int
) -> List[str]:
    """Filter names based on whether corresponding iter_args are frozen dataclasses."""
    # Same selection as remove_read_only_frozen_dataclass, but applied to the
    # parallel list of argument names so names stay aligned with values.
    return _filter_readonly_frozen_dataclass(
        iter_args, iter_args_names, full_write_args_count
    )
def insert_read_only_frozen_dataclass(
    iter_args: List[Any], original_iter_args: List[Any], full_write_args_count: int
) -> List[Any]:
    """
    Insert read-only frozen dataclass arguments back into the iteration arguments.

    Frozen dataclass instances past the full-write prefix were filtered out
    before the loop body; this restores them from ``original_iter_args`` while
    every other slot consumes the next value from ``iter_args``.

    Args:
        iter_args: New iteration arguments to use for non-frozen dataclass instances
        original_iter_args: Original iteration arguments to preserve frozen dataclass instances from
        full_write_args_count: Number of arguments that are always written (not read-only)

    Returns:
        List of arguments with frozen dataclass instances preserved from original
    """
    # The full-write prefix always comes from the new arguments.
    head = iter_args[:full_write_args_count] if full_write_args_count > 0 else []
    # Walk the original tail, restoring frozen dataclasses in place and
    # consuming a new value for every other slot.
    replacements = iter(iter_args[full_write_args_count:])
    tail = []
    for original_arg in original_iter_args[full_write_args_count:]:
        if is_frozen_dataclass(original_arg):
            tail.append(original_arg)
        else:
            tail.append(next(replacements))
    return head + tail
def unpack_to_irvalue(
    mixed_values: List[Any], body_name: str, full_write_args_count: int
) -> Tuple[List[ir.Value], PyTreeDef]:
    """Flatten mixed Python/DSL values into a flat list of MLIR values.

    Read-only frozen dataclasses (past the full-write prefix) are dropped
    before flattening; the returned tree definition allows the flat values to
    be unflattened back into the original structure later.

    :param mixed_values: loop-carried values, possibly nested containers
    :param body_name: name of the enclosing construct, used in diagnostics
    :param full_write_args_count: number of leading arguments that are always
        written (never filtered as read-only)
    :raises DSLRuntimeError: when a value cannot be converted to a dynamic
        expression
    """
    log().debug("===--- Values UNPack")
    for idx, packed in enumerate(mixed_values):
        log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
    try:
        unpacked_values, treedef = tree_flatten(
            remove_read_only_frozen_dataclass(mixed_values, full_write_args_count)
        )
    except DSLTreeFlattenError as e:
        # Surface a user-actionable error instead of the raw flatten failure.
        raise DSLRuntimeError(
            f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression.",
            context={
                e.message: (
                    f"All expressions within '{body_name}' must be dynamic expressions, "
                    "mixing Python objects and dynamic expressions is not supported. "
                    "The DSL failed to convert the Python object into dynamic expressions."
                )
            },
            suggestion=(
                f"Please ensure '{e.type_str}' implements the '{DynamicExpression.__name__}' or mark with `dataclass`, "
                f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects."
            ),
        )
    log().debug("------------------ ")
    for idx, unpacked in enumerate(unpacked_values):
        log().debug("[%d]: unpacked values: %s", idx, unpacked)
    log().debug("treedef: %s", treedef)
    log().debug("------------------ ")
    return unpacked_values, treedef
def pack_from_irvalue(
ir_values: List["ir.Value"],
indices: Dict[int, Tuple[int, int]],
class_types: List[Any],
pytree_def: PyTreeDef,
mixed_values: List[Any],
full_write_args_count: int,
) -> List[Any]:
"""
Packs MLIR values into a list of mixed values.
"""
log().debug("===--- Values Pack (%d)", len(ir_values))
for idx, packed in enumerate(ir_values):
log().debug("[%d]: will-packed: %s", idx, ir_values)
for idx, unpacked in indices.items():
log().debug("[%d]: indices: %s", idx, unpacked)
for idx, c in enumerate(class_types):
log().debug("[%d]: obj-types: %s", idx, type(c))
mixed_values = [None] * len(indices)
for idx, (start, length) in sorted(indices.items()):
chunk = ir_values[start : start + length]
obj = class_types[idx]
if is_frozen_dataclass(obj):
mixed_values[idx] = obj
elif not isinstance(obj, type) and hasattr(obj, "__new_from_mlir_values__"):
mixed_values[idx] = obj.__new_from_mlir_values__(chunk)
elif isinstance(chunk, list) and chunk[0] is None:
mixed_values[idx] = class_types[idx]
else:
if len(chunk) == 1:
try:
mixed_values[idx] = t.as_numeric(chunk[0])
except ValueError:
# Suppress the conversion error and try new_from_mlir_values below
pass
if mixed_values[idx] is None:
mixed_values[idx] = new_from_mlir_values(obj, chunk)
log().debug("------------------ ")
for idx, packed in enumerate(mixed_values):
log().debug("[%d]: packed: %s", idx, packed)
log().debug("------------------ ")
return mixed_values
def unpack_to_irvalue(
mixed_values: List[Any], body_name: str
) -> Tuple[List[ir.Value], List[Any], Dict[int, Tuple[int, int]], List[Any]]:
"""
Unpacks mixed values into ir.Value values.
"""
unpacked_values = []
ir_values = []
indices = {}
class_types = []
current_offset = 0
log().debug("===--- Values UNPack (%d)", len(mixed_values))
for idx, packed in enumerate(mixed_values):
log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
for idx, item in enumerate(mixed_values):
class_types.append(item)
try:
if is_frozen_dataclass(item):
extracted_vals = [None]
else:
extracted_vals = extract_mlir_values(item)
# it's consexpr (python value), so we create mlir value for it
if extracted_vals == []:
if item is None:
extracted_vals = [None]
else:
dyn_expr = t.as_numeric(item)
extracted_vals = extract_mlir_values(dyn_expr)
ir_values.extend(extracted_vals)
else:
ir_values.extend(extracted_vals)
unpacked_values.extend(extracted_vals)
length = len(extracted_vals)
indices[idx] = (current_offset, length)
current_offset += length
except Exception as e:
raise DSLRuntimeError(
f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression (aka MLIR value).",
context={
item: (
f"All expressions within '{body_name}' must be dynamic expressions, "
"mixing Python objects and dynamic expressions (aka MLIR values) is not supported. "
"The DSL failed to convert the Python object into MLIR values."
)
},
suggestion=(
f"Please ensure '{item}' implements the '{DynamicExpression.__name__}', "
f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects."
),
) from e
log().debug("------------------ ")
for idx, unpacked in enumerate(unpacked_values):
log().debug("[%d]: unpacked values: %s", idx, unpacked)
for idx, unpacked in enumerate(ir_values):
log().debug("[%d]: unpacked ir_values: %s", idx, unpacked)
for idx, unpacked in indices.items():
log().debug("[%d]: indices: %s", idx, unpacked)
for idx, unpacked in enumerate(class_types):
log().debug("[%d]: initial-class-types: %s", idx, unpacked)
for idx, value in enumerate(ir_values):
log().debug("[%d]: will-packed: %s", idx, value)
log().debug("treedef: %s", pytree_def)
log().debug("------------------ ")
return ir_values, unpacked_values, indices, class_types
unflattened = tree_unflatten(pytree_def, ir_values)
return insert_read_only_frozen_dataclass(
unflattened, mixed_values, full_write_args_count
)
def to_index(value):
@@ -1015,8 +1173,8 @@ def any_(iterable):
def select_(cond, if_value, else_value):
def _as_scalar(value):
if const_expr(isinstance(value, list)):
if const_expr(len(value) == 1):
if isinstance(value, list):
if len(value) == 1:
return value[0]
else:
raise DSLRuntimeError(
@@ -1024,16 +1182,16 @@ def select_(cond, if_value, else_value):
)
return value
if const_expr(not is_dynamic_expression(cond)):
if not is_dynamic_expression(cond):
raise DSLRuntimeError("Conditional expression must be dynamic")
# Extract MLIR values
cond = extract_mlir_values(cond)
if const_expr(is_dynamic_expression(if_value)):
if is_dynamic_expression(if_value):
if_value = extract_mlir_values(if_value)
else:
if_value = const(if_value)
if const_expr(is_dynamic_expression(else_value)):
if is_dynamic_expression(else_value):
else_value = extract_mlir_values(else_value)
else:
else_value = const(else_value)
@@ -1089,7 +1247,7 @@ def for_generate(
iter_args: Optional[Sequence[ir.Value]] = None,
*,
unroll: LoopUnroll = None,
pipelining=None,
prefetch_stages=None,
loc=None,
ip=None,
):
@@ -1127,8 +1285,8 @@ def for_generate(
if unroll is not None:
for_op.attributes["loop_annotation"] = unroll
if pipelining is not None:
for_op.attributes["cutlass.pipelining"] = _createI32Attr(pipelining)
if prefetch_stages is not None:
for_op.attributes["cutlass.pipelining"] = _createI32Attr(prefetch_stages)
iv = for_op.induction_variable
new_results = new_from_mlir_values(iter_args, for_op.results)
@@ -1155,11 +1313,11 @@ def not_(lhs: Union[ir.Value, bool], *, loc=None, ip=None):
"""
res = None
# Handle Python bool first to prevent infinite recursion
if const_expr(type(lhs) == bool):
if type(lhs) == bool:
res = lhs ^ True
elif const_expr(hasattr(lhs, "__dsl_not__")):
elif hasattr(lhs, "__dsl_not__"):
res = lhs.__dsl_not__(loc=loc, ip=ip)
elif const_expr(is_dynamic_expression(lhs)):
elif is_dynamic_expression(lhs):
# If lhs is MLIR value, compute not using xor
res = arith.XOrIOp(lhs, const(1, lhs.type)).result
else:
@@ -1338,29 +1496,59 @@ def equal(lhs, rhs):
return lhs == rhs
def in_(lhs, rhs, op):
def not_equal(lhs, rhs):
    """Element-aware `!=` supporting dynamic (MLIR) expressions and sequences."""
    # Pure-Python fast path: neither side is a dynamic expression.
    if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
        return lhs != rhs
    # Both sequence
    if isinstance(lhs, Sequence) and isinstance(rhs, Sequence):
        # Short-circuit for unequal length
        if len(lhs) != len(rhs):
            return True
        # Elementwise recursive compare, reduced with the DSL-aware any_.
        return any_(not_equal(l, r) for l, r in zip(lhs, rhs))
    # NOTE(review): every object has `__ne__` in Python 3, so this hasattr
    # check is always True and the two branches below are unreachable —
    # confirm whether a DSL-specific attribute (e.g. `__dsl_ne__`) was intended.
    if hasattr(lhs, "__ne__"):
        return lhs != rhs
    elif hasattr(rhs, "__ne__"):
        return rhs != lhs
    else:
        return not_(equal(lhs, rhs))
def in_(lhs, rhs):
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
return lhs in rhs
if not isinstance(rhs, Sequence):
raise DSLRuntimeError(
f"'{op}' not supported between instances of {type(lhs)} and {type(rhs)}"
f"'in' not supported between instances of {type(lhs)} and {type(rhs)}"
)
return any_(equal(lhs, r) for r in rhs)
def _lt_gt(lhs, rhs, op):
def native_lt_gt(lhs, rhs, op):
if op == "<":
return lhs < rhs
elif op == ">":
return lhs > rhs
else:
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
def _lte_gte(lhs, rhs, op):
def native_lte_gte(lhs, rhs, op):
match op:
case "<":
return lhs < rhs
case "<=":
if hasattr(lhs, "__le__"):
return lhs <= rhs
else:
return not_(lhs > rhs)
case ">":
return lhs > rhs
case ">=":
if hasattr(lhs, "__ge__"):
return lhs >= rhs
else:
return not_(lhs < rhs)
case _:
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
return native_lt_gt(lhs, rhs, op)
return native_lte_gte(lhs, rhs, op)
# Both sequence, comparisons other than == and != do not allow mixing different types of sequences
if (
@@ -1375,7 +1563,7 @@ def _lt_gt(lhs, rhs, op):
is_equal = equal(l, r)
mask.append(not_(or_(is_equal, unequal_found)))
unequal_found = not_(is_equal)
comp_results.append(_lt_gt(l, r, op))
comp_results.append(_lte_gte(l, r, op))
result = any_(and_(r, m) for r, m in zip(comp_results, mask))
@@ -1383,62 +1571,126 @@ def _lt_gt(lhs, rhs, op):
# Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types
# If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one
has_valid_mask = any_(mask)
if op == "<":
length_result = len(lhs) < len(rhs)
elif op == ">":
length_result = len(lhs) > len(rhs)
match op:
case "<":
length_result = len(lhs) < len(rhs)
case ">":
length_result = len(lhs) > len(rhs)
case "<=":
length_result = len(lhs) <= len(rhs)
case ">=":
length_result = len(lhs) >= len(rhs)
if type(has_valid_mask) == bool:
return result if has_valid_mask else length_result
else:
return select_(has_valid_mask, result, length_result)
else:
return result
if op in {"<=", ">="}:
# If no unequal, return True
return select_(unequal_found, result, True)
else:
return result
else:
return native_lt_gt(lhs, rhs, op)
return native_lte_gte(lhs, rhs, op)
def greater_than(lhs, rhs):
return _lt_gt(lhs, rhs, ">")
return _lte_gte(lhs, rhs, ">")
def greater_equal(lhs, rhs):
return _lte_gte(lhs, rhs, ">=")
def less_than(lhs, rhs):
return _lt_gt(lhs, rhs, "<")
return _lte_gte(lhs, rhs, "<")
def less_equal(lhs, rhs):
return _lte_gte(lhs, rhs, "<=")
def _compare_dispatch(lhs, rhs, op):
    """
    Route a single comparison between lhs and rhs to the matching DSL helper.

    :param lhs: The left-hand side operand for the comparison.
    :param rhs: The right-hand side operand for the comparison.
    :param op: The comparison operator as a string. Supported operators are:
               - "is", "is not": Python identity comparisons.
               - "in", "not in": Membership tests.
               - "==", "!=": Equality and inequality.
               - "<", ">", "<=", ">=": Relational comparisons.
    :return: The result of the comparison, which may be a boolean or a DSL-specific type.
    :raises DSLRuntimeError: If the operator is not supported.
    """
    # 'is' and 'is not' are pure python operators
    if op == "is":
        return lhs is rhs
    if op == "is not":
        return lhs is not rhs
    if op == "in":
        return in_(lhs, rhs)
    if op == "not in":
        return not_(in_(lhs, rhs))
    if op == "==":
        return equal(lhs, rhs)
    if op == "!=":
        return not_equal(lhs, rhs)
    if op == "<":
        return less_than(lhs, rhs)
    if op == ">":
        return greater_than(lhs, rhs)
    if op == ">=":
        return greater_equal(lhs, rhs)
    if op == "<=":
        return less_equal(lhs, rhs)
    raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
def _compare_executor(left, comparators, ops):
    """
    Evaluate a (possibly chained) comparison expression.

    :param left: The left-most operand of the comparison chain.
    :param comparators: The remaining operands, one per operator.
    :param ops: The comparison operator strings (see `_compare_dispatch`).
    :return: The single dispatch result for `a OP b`, or for a chain like
             `a < b < c` the conjunction of all pairwise results.
    """
    # The stripped diff interleaved the removed per-operator if/elif body with
    # the new dispatch loop; this is the reconstructed added implementation.
    # Fast path for single comparison
    if len(comparators) == 1:
        return _compare_dispatch(left, comparators[0], ops[0])
    # Chain comparison, dispatch in a loop: each adjacent pair is compared and
    # the results are AND-ed together, mirroring Python chaining semantics
    # (NOTE(review): unlike native Python this does not short-circuit - every
    # pair is evaluated).
    result = True
    current = left
    for comparator, op in zip(comparators, ops):
        cmp_result = _compare_dispatch(current, comparator, op)
        result = and_(result, cmp_result)
        current = comparator
    return result
def _builtin_redirector(fcn):
    """
    Map a Python built-in callable to its DSL-aware counterpart.

    :param fcn: One of `builtins.max`, `builtins.min`, `builtins.any`, `builtins.all`.
    :return: The DSL replacement function.
    :raises DSLRuntimeError: If `fcn` is not a supported built-in.
    """
    # Dispatch table instead of an if/elif chain; built once per call so the
    # module-level DSL overrides are looked up lazily, as before.
    redirects = {
        builtins.max: max,
        builtins.min: min,
        builtins.any: any_,
        builtins.all: all_,
    }
    target = redirects.get(fcn)
    if target is None:
        raise DSLRuntimeError(f"Unsupported built-in function: {fcn}")
    return target
# =============================================================================
# Set the AST decorator
# =============================================================================

# Set the DSL specific functions.  The stripped diff duplicated the removed
# positional-argument call before the new keyword form; only the keyword form
# is kept.
executor.set_functions(
    is_dynamic_expression=is_dynamic_expression,
    loop_execute_range_dynamic=_loop_execute_range_dynamic,
    if_dynamic=_if_execute_dynamic,
    while_dynamic=_while_execute_dynamic,
    compare_executor=_compare_executor,
    any_executor=any_,
    all_executor=all_,
    builtin_redirector=_builtin_redirector,
)

View File

@@ -14,13 +14,22 @@ from types import NoneType
from cutlass._mlir import ir
from cutlass._mlir.dialects import scf, arith
from cutlass._mlir.extras import types as T
from collections.abc import Sequence
from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values
from ..base_dsl.dsl import is_dynamic_expression
from ..base_dsl.ast_helpers import *
from ..base_dsl.utils.logger import log
from ..base_dsl import typing as t
from ..base_dsl.typing import Int32, Float32, Boolean, Numeric, get_mlir_types
from ..base_dsl.typing import (
Int32,
Float32,
Boolean,
Numeric,
get_mlir_types,
as_numeric,
)
from . import cutlass as cutlass_dsl
from .tree_utils import PyTreeDef, check_tree_equal
# =============================================================================
# AST Helpers
@@ -57,14 +66,6 @@ class ScfGenerator:
def __init__(self):
pass
@staticmethod
def fill_none(ir_values, unpacked_values):
i = 0
for idx, item in enumerate(unpacked_values):
if item is not None:
unpacked_values[idx] = ir_values[i]
i += 1
@staticmethod
def _normalize_region_result_to_list(region_result: Any) -> List[Any]:
"""
@@ -82,34 +83,109 @@ class ScfGenerator:
return region_result_list
@staticmethod
def check_region_result(region_values, ir_values):
for i, (expected_value, actual_value) in enumerate(
zip(ir_values, region_values)
def _check_region_result(original_value, region_value, arg_name, op_type_name):
"""
Validate that a region result maintains the same type as the original value.
This method checks for type consistency between the original value passed to a dynamic
SCF operation (like for, if, while) and the value returned from the operation's region.
Args:
original_value: The value before entering the SCF operation region
region_value: The value returned from the SCF operation region
arg_name: Name of the argument being checked (for error reporting)
op_type_name: Type of SCF operation (e.g., 'for', 'if', 'while') for error reporting
Raises:
DSLRuntimeError: If the region value has a different type than the original value.
The error includes suggestions for using compile-time control flow instead.
Note:
This method performs relaxed type checking that allows inheritance relationships.
For example, a child class can be returned where a parent class was expected.
However, fundamental type changes (like None to non-None, different sequence types,
or different numeric types) are not allowed in dynamic SCF operations.
"""
def get_type_name(value):
if isinstance(value, NoneType):
return "None"
elif isinstance(value, Sequence):
return f"{type(value).__name__}<{len(value)}>"
else:
return type(value).__name__
# Check for type mismatches
type_mismatch = False
old_type_name = None
new_type_name = None
# Handle None type changes
if isinstance(original_value, NoneType) != isinstance(region_value, NoneType):
type_mismatch = True
old_type_name = get_type_name(original_value)
new_type_name = get_type_name(region_value)
# Handle sequence type/length changes
elif isinstance(original_value, Sequence) and isinstance(
region_value, Sequence
):
expected_value_type = get_mlir_types(expected_value)
actual_value_type = get_mlir_types(actual_value)
if expected_value_type != actual_value_type:
return False, i, expected_value_type, actual_value_type
return True, -1, None, None
if type(original_value) != type(region_value) or len(original_value) != len(
region_value
):
type_mismatch = True
old_type_name = get_type_name(original_value)
new_type_name = get_type_name(region_value)
# Handle numeric type changes
elif isinstance(
original_value, (Numeric, ArithValue, ir.Value, int, float, bool)
) or isinstance(
region_value, (Numeric, ArithValue, ir.Value, int, float, bool)
):
try:
original_numeric = as_numeric(original_value)
region_numeric = as_numeric(region_value)
if original_numeric.dtype != region_numeric.dtype:
type_mismatch = True
old_type_name = original_numeric.dtype.__name__
new_type_name = region_numeric.dtype.__name__
except Exception:
pass
# Handle general type changes (relaxed for inheritance)
elif type(original_value) != type(region_value):
old_type = type(original_value)
new_type = type(region_value)
if not (issubclass(old_type, new_type) or issubclass(new_type, old_type)):
type_mismatch = True
old_type_name = old_type.__name__
new_type_name = new_type.__name__
if type_mismatch:
raise DSLRuntimeError(
f"`{arg_name}` is {old_type_name} prior to this `{op_type_name}`, "
f"and update to {new_type_name} inside of this `{op_type_name}` is not supported.",
suggestion=(
f"Please avoid changing type inside a dynamic `{op_type_name}`, "
f"or change to compile-time control flow by marking this `{op_type_name}` with "
f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`."
),
)
def scf_execute_dynamic(
self,
op_type_name: str,
used_args: List[Any],
mix_iter_args: List[Any],
full_write_args_count: int,
mix_iter_arg_names: List[str],
create_op_func: Callable[
[List[ir.Value], Dict[int, Tuple[int, int]], List[Any]], ir.Operation
],
create_op_func: Callable[[List[ir.Value]], ir.Operation],
region_builders: List[
Callable[
[
"ir.Operation",
List["ir.Value"], # block_args
List[Any], # used_args
List["ir.Value"], # dyn_yield_ops
Dict[int, Tuple[int, int]],
PyTreeDef,
List[Any],
int,
],
Any,
]
@@ -119,11 +195,11 @@ class ScfGenerator:
block_term_op_builder: Dict[Callable, Callable] = {},
) -> Any:
# 1) Unpack
ir_values, dyn_unpacked_values, dyn_indices, dyn_class_types = (
cutlass_dsl.unpack_to_irvalue(mix_iter_args, op_type_name)
ir_values, pytree_def = cutlass_dsl.unpack_to_irvalue(
mix_iter_args, op_type_name, full_write_args_count
)
# 2) Create the SCF op
op = create_op_func(ir_values, dyn_indices, dyn_class_types)
op = create_op_func(ir_values)
log().debug("Generated scf.%s \n[%s]", op_type_name, op)
# 3) Build the regions
@@ -135,76 +211,61 @@ class ScfGenerator:
region_result = builder(
op,
block_args,
used_args,
dyn_unpacked_values,
dyn_indices,
dyn_class_types,
ir_values,
pytree_def,
mix_iter_args,
full_write_args_count,
)
# Use custom terminator if provided for this builder, otherwise use default YieldOp
if builder in block_term_op_builder:
# Use the provided terminator generator
block_term_op_builder[builder](region_result)
block_term_op_builder[builder](region_result, full_write_args_count)
else:
# For standard yield op, check result
for arg, result, name in zip(
mix_iter_args,
(
region_result
if isinstance(region_result, list)
else [region_result]
),
mix_iter_arg_names,
):
if isinstance(arg, NoneType) and not isinstance(
result, NoneType
):
raise DSLRuntimeError(
(
f"`{name}` is None prior to this `{op_type_name}`, "
f"and update to non-None value inside of this `{op_type_name}` is not supported."
),
suggestion=(
f"Please make sure `{name}` is not None prior to this `{op_type_name}`, "
f"or mark this `{op_type_name}` with "
f"`{'range' if op_type_name == 'for' else 'const_expr'}`."
),
)
# Normalize region_result
region_result_list = ScfGenerator._normalize_region_result_to_list(
region_result
)
# Default behavior - generate YieldOp
region_values, unpacked_values, _, _ = (
cutlass_dsl.unpack_to_irvalue(region_result_list, op_type_name)
)
is_match, mismatch_idx, expected_type, actual_type = (
ScfGenerator.check_region_result(region_values, ir_values)
)
if not is_match:
# From unpacked index, we need to find the original index
original_idx = -1
for unpacked_idx, (original_idx, length) in dyn_indices.items():
if (
mismatch_idx >= original_idx
and mismatch_idx < original_idx + length
):
original_idx = unpacked_idx
break
raise DSLRuntimeError(
f"`{op_type_name}` expects {expected_type} type for varible `{mix_iter_arg_names[original_idx]}`, but got {actual_type}.",
suggestion=f"Please make sure `{mix_iter_arg_names[original_idx]}` type is not changed inside of `{op_type_name}`.",
# For standard yield op, check result
for arg, result, name in zip(
mix_iter_args,
region_result_list,
mix_iter_arg_names,
):
ScfGenerator._check_region_result(
arg, result, name, op_type_name
)
# Default behavior - generate YieldOp
region_values, yield_pytree_def = cutlass_dsl.unpack_to_irvalue(
region_result_list, op_type_name, full_write_args_count
)
mismatch = check_tree_equal(pytree_def, yield_pytree_def)
if mismatch != -1:
# Get arg name
filterd_arg_names = (
cutlass_dsl.filter_readonly_frozen_dataclass_names(
mix_iter_args, mix_iter_arg_names, full_write_args_count
)
)
raise DSLRuntimeError(
f"`{filterd_arg_names[mismatch]}` is structured different after this `{op_type_name}`.",
suggestion=(
f"Please avoid changing type structure inside a dynamic `{op_type_name}`, "
f"or change to compile-time control flow by marking this `{op_type_name}` with "
f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`."
),
)
scf.YieldOp(region_values)
log().debug("Completed scf.%s \n[%s]", op_type_name, op)
ScfGenerator.fill_none(op.results, unpacked_values)
# 4) Pack final results
final_results = cutlass_dsl.pack_from_irvalue(
unpacked_values, dyn_indices, dyn_class_types
op.results, pytree_def, mix_iter_args, full_write_args_count
)
# 5) Return in a nice pattern
@@ -215,28 +276,32 @@ class ScfGenerator:
return final_results
def _attr_const_check(attr, expected_type, attr_name):
    """Validate that a loop attribute is a static Python value of exactly `expected_type`."""
    # Strict type equality (not isinstance) prevents `bool` from being
    # accepted where `int` is required.
    is_static = not is_dynamic_expression(attr)
    if not (is_static and type(attr) is expected_type):
        raise DSLRuntimeError(
            f"loop attribute `{attr_name}` must be a Python value of type `{expected_type.__name__}`, got `{type(attr).__name__}`."
        )
def _loop_execute_range_dynamic(
func: Callable,
start: Any,
stop: Any,
step: Any,
used_args: List[Any] = [],
mix_iter_args: List[Any] = [],
full_write_args_count: int = 0,
mix_iter_arg_names: List[str] = [],
unroll: int = -1,
unroll_full: bool = False,
pipelining: int = None,
prefetch_stages: int = None,
):
"""
Example: build an scf.for with optional unroll, using our universal helper.
"""
scf_gen = ScfGenerator()
def create_for_op(
dyn_yield_ops: List[ir.Value],
dyn_indices: Dict[int, Tuple[int, int]],
dyn_class_types: List[Any],
):
def create_for_op(dyn_yield_ops: List[ir.Value]):
for d in dyn_yield_ops:
if not isinstance(d, ir.Value):
raise DSLRuntimeError(
@@ -254,6 +319,10 @@ def _loop_execute_range_dynamic(
stop_ = stop_.ir_value()
step_ = step_.ir_value()
# Attributes must be pure Python value, add a check
_attr_const_check(unroll, int, "unroll")
_attr_const_check(unroll_full, bool, "unroll_full")
# Possibly attach unroll attributes
unroll_attr = None
if unroll_full:
@@ -262,17 +331,18 @@ def _loop_execute_range_dynamic(
unroll_attr = LoopUnroll(count=unroll)
log().debug("Unroll attribute: %s", unroll_attr)
pipelining_attr = None
if pipelining is not None:
if pipelining >= 0:
pipelining_attr = ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), pipelining
prefetch_stages_attr = None
if prefetch_stages is not None:
_attr_const_check(prefetch_stages, int, "prefetch_stages")
if prefetch_stages >= 0:
prefetch_stages_attr = ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), prefetch_stages
)
else:
raise DSLRuntimeError(
f"Pipelining must be non-negative, got {pipelining}"
f"loop attribute `prefetch_stages` must be non-negative, got `{prefetch_stages}`."
)
log().debug("Pipelining attribute: %s", pipelining_attr)
log().debug("prefetch_stages attribute: %s", prefetch_stages_attr)
log().debug(
"Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s",
@@ -303,47 +373,48 @@ def _loop_execute_range_dynamic(
if unroll_attr is not None:
for_op.attributes["loop_annotation"] = unroll_attr
if pipelining_attr is not None:
for_op.attributes["cutlass.pipelining"] = pipelining_attr
if prefetch_stages_attr is not None:
for_op.attributes["cutlass.pipelining"] = prefetch_stages_attr
return for_op
def for_body_builder(
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
op,
block_args,
_,
pytree_def,
mix_iter_args,
full_write_args_count,
):
# Insert induction variable at the beginning
dyn_yield_ops.insert(0, block_args[0])
ScfGenerator.fill_none(block_args, dyn_yield_ops)
block_args = dyn_yield_ops
# scf.ForOp block_args are typically [induction_var, iter_args...]
# But MLIR also gives you op.induction_variable
iv = t.as_numeric(op.induction_variable)
log().debug(
"For body builder: %s block_args: %s used_args: %s",
"For body builder: %s block_args: %s full_write_args_count: %s",
iv,
block_args,
used_args,
full_write_args_count,
)
if len(block_args) <= 1:
# block_args[1:] are iteration variables
func_args = []
func_args.extend(
cutlass_dsl.pack_from_irvalue(
block_args[1:], pytree_def, mix_iter_args, full_write_args_count
)
)
if not func_args:
# No iteration arguments, or only the induction var
func(iv, *used_args)
func(iv)
return [] # yield nothing
else:
# block_args[1:] are iteration variables
func_args = [*used_args]
func_args.extend(
cutlass_dsl.pack_from_irvalue(
block_args[1:], dyn_indices, dyn_class_types
)
)
updated_func_args = func(iv, *func_args)
return updated_func_args
# Now call the universal SCF executor with a single region builder
return scf_gen.scf_execute_dynamic(
op_type_name="for",
used_args=used_args,
mix_iter_args=mix_iter_args,
full_write_args_count=full_write_args_count,
mix_iter_arg_names=mix_iter_arg_names,
create_op_func=create_for_op,
region_builders=[for_body_builder],
@@ -354,8 +425,8 @@ def _if_execute_dynamic(
pred: "ir.Value",
then_block: Callable,
else_block: Callable = None,
used_args: List[Any] = [],
mix_yield_args: List[Any] = [],
full_write_args_count: int = 0,
mix_yield_arg_names: List[str] = [],
if_constexpr=None, # ignoring for brevity
):
@@ -364,11 +435,7 @@ def _if_execute_dynamic(
"""
scf_gen = ScfGenerator()
def create_if_op(
dyn_yield_ops: List[ir.Value],
dyn_indices: Dict[int, Tuple[int, int]],
dyn_class_types: List[Any],
):
def create_if_op(dyn_yield_ops: List[ir.Value]):
# Assume final result types match the dynamic yields
result_types = [arg.type for arg in dyn_yield_ops]
@@ -387,11 +454,18 @@ def _if_execute_dynamic(
return if_op
def then_builder(
if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
if_op,
_,
dyn_yield_ops,
pytree_def,
mix_iter_args,
full_write_args_count,
):
flat_args = [*used_args]
flat_args = []
flat_args.extend(
cutlass_dsl.pack_from_irvalue(dyn_yield_ops, dyn_indices, dyn_class_types)
cutlass_dsl.pack_from_irvalue(
dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count
)
)
return then_block(*flat_args)
@@ -400,12 +474,17 @@ def _if_execute_dynamic(
if else_block is not None:
def else_builder(
if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
if_op,
_,
dyn_yield_ops,
pytree_def,
mix_iter_args,
full_write_args_count,
):
flat_args = [*used_args]
flat_args = []
flat_args.extend(
cutlass_dsl.pack_from_irvalue(
dyn_yield_ops, dyn_indices, dyn_class_types
dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count
)
)
return else_block(*flat_args)
@@ -414,8 +493,8 @@ def _if_execute_dynamic(
return scf_gen.scf_execute_dynamic(
op_type_name="if",
used_args=used_args,
mix_iter_args=mix_yield_args,
full_write_args_count=full_write_args_count,
mix_iter_arg_names=mix_yield_arg_names,
create_op_func=create_if_op,
region_builders=region_builders,
@@ -425,9 +504,9 @@ def _if_execute_dynamic(
def _while_execute_dynamic(
while_before_block: Callable,
while_after_block: Callable = None,
used_args=[],
yield_args=[],
yield_arg_names=[],
write_args=[],
full_write_args_count=0,
write_args_names=[],
):
"""
Create and return an SCF WhileOp for dynamic loops.
@@ -436,8 +515,7 @@ def _while_execute_dynamic(
Args:
while_before_block: Function that returns (condition, updated_values)
while_after_block: Function that returns updated values
used_args: Additional arguments used in the loop body
yield_args: Values that are updated in the loop
write_args: Values that are updated in the loop
See create_while_function in ast_preprocessor.py for details on the input structure.
"""
@@ -445,11 +523,7 @@ def _while_execute_dynamic(
while_op_type_name = "while"
scf_gen = ScfGenerator()
def create_while_op(
dyn_yield_ops: List[ir.Value],
dyn_indices: Dict[int, Tuple[int, int]],
dyn_class_types: List[Any],
):
def create_while_op(dyn_yield_ops: List[ir.Value]):
# Create the while operation with the types from yield_args
result_types = [arg.type for arg in dyn_yield_ops]
try:
@@ -468,14 +542,19 @@ def _while_execute_dynamic(
) from e
def before_block_builder(
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
op,
block_args,
_,
pytree_def,
mix_iter_args,
full_write_args_count,
):
# Build the before (condition) block
ScfGenerator.fill_none(block_args, dyn_yield_ops)
block_args = dyn_yield_ops
flat_args = [*used_args]
flat_args = []
flat_args.extend(
cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
cutlass_dsl.pack_from_irvalue(
block_args, pytree_def, mix_iter_args, full_write_args_count
)
)
log().debug("before block args: %s", flat_args)
@@ -493,18 +572,15 @@ def _while_execute_dynamic(
return cond, before_results
def before_block_terminator(cond_and_results):
def before_block_terminator(cond_and_results, full_write_args_count):
# Generate a condition op instead of yield op
cond = cond_and_results[0]
before_result_list = ScfGenerator._normalize_region_result_to_list(
cond_and_results[1]
)
ir_cond_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
[cond], while_op_type_name
)
ir_cond = ir_cond_list[0]
ir_results_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
before_result_list, while_op_type_name
ir_cond = as_numeric(cond).ir_value()
ir_results_list, pytree_def = cutlass_dsl.unpack_to_irvalue(
before_result_list, while_op_type_name, full_write_args_count
)
log().debug(
"creating scf.ConditionOp with [%s], [%s]",
@@ -514,14 +590,19 @@ def _while_execute_dynamic(
scf.ConditionOp(ir_cond, ir_results_list)
def after_block_builder(
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
op,
block_args,
_,
pytree_def,
mix_iter_args,
full_write_args_count,
):
# Build the after (body) block
ScfGenerator.fill_none(block_args, dyn_yield_ops)
block_args = dyn_yield_ops
flat_args = [*used_args]
flat_args = []
flat_args.extend(
cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
cutlass_dsl.pack_from_irvalue(
block_args, pytree_def, mix_iter_args, full_write_args_count
)
)
log().debug("after block args: %s", flat_args)
@@ -541,9 +622,9 @@ def _while_execute_dynamic(
# Call the universal SCF executor with two region builders
return scf_gen.scf_execute_dynamic(
op_type_name=while_op_type_name,
used_args=used_args,
mix_iter_args=yield_args,
mix_iter_arg_names=yield_arg_names,
mix_iter_args=write_args,
full_write_args_count=full_write_args_count,
mix_iter_arg_names=write_args_names,
create_op_func=create_while_op,
region_builders=[before_block_builder, after_block_builder],
block_term_op_builder={

View File

@@ -0,0 +1,763 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Callable, Any, Iterable, Iterator, NamedTuple, Union, get_origin
import dataclasses
import itertools as it
from types import SimpleNamespace
from ..base_dsl.typing import as_numeric, Numeric, Constexpr
from ..base_dsl._mlir_helpers.arith import ArithValue
from ..base_dsl.common import DSLBaseError
from .._mlir import ir
# =============================================================================
# Tree Utils
# =============================================================================
class DSLTreeFlattenError(DSLBaseError):
    """Exception raised when tree flattening fails due to unsupported types."""

    def __init__(self, msg: str, type_str: str):
        # type_str records the fully qualified name of the offending type so
        # callers can report which node could not be flattened.
        super().__init__(msg)
        self.type_str = type_str
def unzip2(pairs: Iterable[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]:
    """Split an iterable of 2-tuples into two parallel lists."""
    firsts: list[Any] = []
    seconds: list[Any] = []
    append_first = firsts.append
    append_second = seconds.append
    for first, second in pairs:
        append_first(first)
        append_second(second)
    return firsts, seconds
def get_fully_qualified_class_name(x: Any) -> str:
    """
    Return the fully qualified class name of an object.

    Args:
        x: Any object

    Returns:
        str: 'module.qualname' of the object's class

    Example:
        >>> get_fully_qualified_class_name([1, 2, 3])
        'builtins.list'
    """
    cls = x.__class__
    return f"{cls.__module__}.{cls.__qualname__}"
def is_frozen_dataclass(obj_or_cls: Any) -> bool:
    """
    Check whether an object or class is a dataclass declared with frozen=True.

    Args:
        obj_or_cls: Either a dataclass instance or a class

    Returns:
        bool: True only for frozen dataclasses (instances or classes)

    Example:
        >>> from dataclasses import dataclass
        >>> @dataclass(frozen=True)
        ... class Point:
        ...     x: int
        ...     y: int
        >>> is_frozen_dataclass(Point)
        True
        >>> is_frozen_dataclass(Point(1, 2))
        True
    """
    cls = obj_or_cls if isinstance(obj_or_cls, type) else obj_or_cls.__class__
    if not dataclasses.is_dataclass(cls):
        return False
    params = getattr(cls, "__dataclass_params__", None)
    return params is not None and params.frozen
def is_dynamic_expression(x: Any) -> bool:
    """
    Check whether an object implements the DynamicExpression protocol.

    An object qualifies only if it exposes both `__extract_mlir_values__`
    and `__new_from_mlir_values__`.

    Args:
        x: Any object to check

    Returns:
        bool: True if both protocol methods are present, False otherwise
    """
    return hasattr(x, "__extract_mlir_values__") and hasattr(
        x, "__new_from_mlir_values__"
    )
def is_constexpr_field(field: dataclasses.Field) -> bool:
    """
    Check whether a dataclass field is annotated as Constexpr,
    either bare (`Constexpr`) or parameterized (`Constexpr[...]`).
    """
    annotation = field.type
    return annotation is Constexpr or get_origin(annotation) is Constexpr
# =============================================================================
# PyTreeDef
# =============================================================================
class NodeType(NamedTuple):
    """
    Represents a node in a pytree structure.

    Attributes:
        name: String representation of the node type
        to_iterable: Function converting a node to (metadata, children) iterable form
        from_iterable: Function reconstructing a node from (metadata, children)
    """

    name: str
    to_iterable: Callable
    from_iterable: Callable
class PyTreeDef(NamedTuple):
    """
    Represents the structure definition of a pytree.

    Attributes:
        node_type: The NodeType of this interior node
        node_metadata: SimpleNamespace metadata associated with this node
        child_treedefs: Tuple of child tree definitions (PyTreeDef or Leaf entries)
    """

    node_type: NodeType
    node_metadata: SimpleNamespace
    child_treedefs: tuple["PyTreeDef", ...]
@dataclasses.dataclass(frozen=True)
class Leaf:
    """
    Represents a leaf node in a pytree structure.

    Attributes:
        is_numeric: Whether this leaf contains a `Numeric` value
        is_none: Whether this leaf represents None
        node_metadata: SimpleNamespace metadata associated with this leaf
        ir_type_str: String representation of the IR type
    """

    is_numeric: bool = False
    is_none: bool = False
    # None when the leaf carries no extra metadata.
    node_metadata: SimpleNamespace = None
    # None when the leaf has no associated IR type.
    ir_type_str: str = None
# =============================================================================
# Default to_iterable and from_iterable
# =============================================================================
def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any], list[str]]:
    """
    Extract flattenable members from a dataclass instance.

    Constexpr-annotated fields are excluded from the flattened members and
    reported separately so they can be carried through unchanged.

    Args:
        x: A dataclass instance

    Returns:
        tuple: (field_names, field_values, constexpr_field_names) where
            field_names/field_values cover the non-constexpr declared fields.

    Raises:
        DSLTreeFlattenError: If the instance carries attributes that are not
            declared fields, or a Constexpr-annotated field holds a dynamic
            expression.
    """
    declared = dataclasses.fields(x)
    declared_names = [field.name for field in declared]
    # If the dataclass has extra (undeclared) attributes, flattening would
    # silently drop them - reject instead.
    for k in x.__dict__.keys():
        if k not in declared_names:
            raise DSLTreeFlattenError(
                f"`{x}` has extra field `{k}`",
                type_str=get_fully_qualified_class_name(x),
            )
    # BUGFIX: the empty case previously returned a 2-tuple while callers
    # unpack three values; keep the return shape consistent.
    if not declared_names:
        return [], [], []
    fields = []
    members = []
    constexpr_fields = []
    for field in declared:
        if is_constexpr_field(field):
            constexpr_fields.append(field.name)
            v = getattr(x, field.name)
            # A Constexpr field must hold a static Python value, never a
            # dynamic (MLIR-backed) expression.
            if is_dynamic_expression(v):
                raise DSLTreeFlattenError(
                    f"`{x}` has dynamic expression field `{field.name}` with a Constexpr type annotation `{field.type}`",
                    type_str=get_fully_qualified_class_name(x),
                )
        else:
            fields.append(field.name)
            members.append(getattr(x, field.name))
    return fields, members, constexpr_fields
def default_dataclass_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Flatten a dataclass instance into (metadata, members) for pytree traversal.

    The metadata records the qualified type name, the flattened field names,
    the Constexpr field names, and the original instance (needed to rebuild).

    Args:
        x: A dataclass instance

    Returns:
        tuple: (metadata, members)
    """
    field_names, member_values, constexpr_names = extract_dataclass_members(x)
    node_info = SimpleNamespace(
        type_str=get_fully_qualified_class_name(x),
        fields=field_names,
        constexpr_fields=constexpr_names,
        original_obj=x,
    )
    return node_info, member_values
def set_dataclass_attributes(
    instance: Any,
    fields: list[str],
    values: Iterable[Any],
    constexpr_fields: list[str],
) -> Any:
    """
    Return a copy of a dataclass instance with the given fields replaced.

    Args:
        instance: The dataclass instance to copy
        fields: Names of the fields to replace
        values: New values, aligned with `fields`
        constexpr_fields: Names of Constexpr fields; their current values on
            `instance` are carried over unchanged

    Returns:
        A new instance produced by `dataclasses.replace`, or `instance`
        itself when `fields` is empty.
    """
    # Docstring fixed: it previously documented a nonexistent `is_frozen`
    # parameter and omitted `constexpr_fields`.
    if not fields:
        # Nothing to replace - avoid an unnecessary copy.
        return instance
    kwargs = dict(zip(fields, values))
    # Constexpr fields are not part of the flattened values; keep them as-is.
    for field in constexpr_fields:
        kwargs[field] = getattr(instance, field)
    return dataclasses.replace(instance, **kwargs)
def default_dataclass_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a dataclass instance from flattened children.

    The previous instance stored in `metadata.original_obj` supplies the
    Constexpr field values; the metadata is then updated to point at the
    rebuilt instance so subsequent unflattens reuse it.

    Args:
        metadata: Metadata containing type information and field names
        children: Iterable of attribute values to reconstruct the instance

    Returns:
        The reconstructed dataclass instance
    """
    previous = metadata.original_obj
    rebuilt = set_dataclass_attributes(
        previous, metadata.fields, children, metadata.constexpr_fields
    )
    metadata.original_obj = rebuilt
    return rebuilt
def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Flatten a DynamicExpression object into (metadata, mlir_values).

    The metadata flags the node as a dynamic expression and keeps the original
    object so it can rebuild itself from new MLIR values later.

    Args:
        x: A dynamic expression object

    Returns:
        tuple: (metadata, extracted MLIR values)
    """
    metadata = SimpleNamespace(is_dynamic_expression=1, original_obj=x)
    return metadata, x.__extract_mlir_values__()
def dynamic_expression_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a DynamicExpression object from flattened MLIR values.

    Delegates to the original object's `__new_from_mlir_values__` hook.

    Args:
        metadata: Metadata containing the original object
        children: Iterable of MLIR values to reconstruct from

    Returns:
        The reconstructed dynamic expression object
    """
    original = metadata.original_obj
    return original.__new_from_mlir_values__(list(children))
def default_dict_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Flatten a dict or SimpleNamespace into (metadata, values).

    The key order is recorded in the metadata so unflattening can restore
    entries in place.
    """
    # A SimpleNamespace exposes its entries through __dict__; plain dicts are
    # used directly.
    mapping = x.__dict__ if isinstance(x, SimpleNamespace) else x
    keys = list(mapping.keys())
    values = list(mapping.values())
    metadata = SimpleNamespace(
        type_str=get_fully_qualified_class_name(x), original_obj=x, fields=keys
    )
    return metadata, values
def default_dict_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Restore a dict or SimpleNamespace in place from flattened values.

    Writes each child back under its recorded key on the original object and
    returns that same object (mutation, not a copy).
    """
    target = metadata.original_obj
    if isinstance(target, SimpleNamespace):
        for key, value in zip(metadata.fields, children):
            setattr(target, key, value)
    else:
        for key, value in zip(metadata.fields, children):
            target[key] = value
    return target
# =============================================================================
# Register pytree nodes
# =============================================================================
# Registry mapping container types to their NodeType handlers.
_node_types: dict[type, NodeType] = {}


def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable) -> NodeType:
    """
    Register a container type for pytree operations.

    Args:
        ty: The type to register
        to_iter: Converts an instance of `ty` to (metadata, children)
        from_iter: Rebuilds an instance of `ty` from (metadata, children)

    Returns:
        NodeType: The registered NodeType record
    """
    node_type = NodeType(str(ty), to_iter, from_iter)
    _node_types[ty] = node_type
    return node_type
def register_default_node_types() -> None:
    """Register pytree handlers for the built-in container types."""

    def _sequence_to_iterable(seq):
        # Sequences only need their length as metadata to be rebuilt.
        return SimpleNamespace(length=len(seq)), list(seq)

    register_pytree_node(tuple, _sequence_to_iterable, lambda _, xs: tuple(xs))
    register_pytree_node(list, _sequence_to_iterable, lambda _, xs: list(xs))
    # dicts and SimpleNamespace share the same key/value-based handlers.
    register_pytree_node(dict, default_dict_to_iterable, default_dict_from_iterable)
    register_pytree_node(
        SimpleNamespace, default_dict_to_iterable, default_dict_from_iterable
    )
# Initialize default registrations at import time so built-in containers
# (tuple, list, dict, SimpleNamespace) are flattenable out of the box.
register_default_node_types()
# =============================================================================
# tree_flatten and tree_unflatten
# =============================================================================
"""
Behavior of tree_flatten and tree_unflatten, for example:
```python
a = (1, 2, 3)
b = MyClass(a=1, b =[1,2,3])
```
yields the following tree:
```python
tree_a = PyTreeDef(type = 'tuple',
metadata = {length = 3},
children = [
Leaf(type = int),
Leaf(type = int),
Leaf(type = int),
],
)
flattened_a = [1, 2, 3]
tree_b = PyTreeDef(type = 'MyClass',
metadata = {fields = ['a','b']},
children = [
PyTreeDef(type = `list`,
metadata = {length = 3},
children = [
Leaf(type=`int`),
Leaf(type=`int`),
Leaf(type=`int`),
],
),
Leaf(type=int),
],
)
flattened_b = [1, 1, 2, 3]
```
Pass the flattened values and the PyTreeDef to tree_unflatten to reconstruct the original structure.
``` python
unflattened_a = tree_unflatten(tree_a, flattened_a)
unflattened_b = tree_unflatten(tree_b, flattened_b)
```
yields the following structure:
``` python
unflattened_a = (1, 2, 3)
unflattened_b = MyClass(a=1, b =[1,2,3])
```
unflattened_a should be structurally identical to a, and unflattened_b should be structurally identical to b.
"""
def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
    """
    Flatten a nested structure into leaf values plus a tree definition.

    Recursively walks nested containers and registered node types, collecting
    leaves in order while recording the shape in a PyTreeDef so the operation
    can be inverted by tree_unflatten.

    Args:
        x: The nested structure to flatten

    Returns:
        tuple: (flat_values, treedef) where flat_values is the ordered list of
        leaf values and treedef is the tree structure definition

    Raises:
        DSLTreeFlattenError: If the structure contains unsupported types

    Example:
        >>> tree_flatten([1, [2, 3], 4])
        ([1, 2, 3, 4], PyTreeDef(...))
    """
    flat_iter, treedef = _tree_flatten(x)
    return list(flat_iter), treedef
def get_registered_node_types_or_insert(x: Any) -> NodeType | None:
    """
    Look up the node type for an object, auto-registering it when possible.

    Resolution order:
      1. An existing entry in the registry wins.
      2. Objects implementing the DynamicExpression protocol are registered
         with the dynamic-expression handlers (checked before dataclasses).
      3. Dataclasses fall back to the default dataclass handlers.

    Args:
        x: The object to get or register a node type for

    Returns:
        NodeType or None: The (possibly newly created) registration, or None
        if the type cannot be registered
    """
    ty = type(x)
    existing = _node_types.get(ty)
    if existing is not None:
        return existing
    if is_dynamic_expression(x):
        return register_pytree_node(
            ty, dynamic_expression_to_iterable, dynamic_expression_from_iterable
        )
    if dataclasses.is_dataclass(x):
        return register_pytree_node(
            ty, default_dataclass_to_iterable, default_dataclass_from_iterable
        )
    return None
def create_leaf_for_value(
    x: Any,
    is_numeric: bool = False,
    is_none: bool = False,
    node_metadata: SimpleNamespace | None = None,
    ir_type_str: str | None = None,
) -> Leaf:
    """
    Create a Leaf node for a given value.

    Args:
        x: The value to create a leaf for
        is_numeric: Whether this is a numeric value
        is_none: Whether this represents None
        node_metadata: Optional metadata
        ir_type_str: Optional IR type string; when omitted, falls back to
            str(x.type) if the value exposes a `.type` attribute, else None

    Returns:
        Leaf: The created leaf node
    """
    return Leaf(
        is_numeric=is_numeric,
        is_none=is_none,
        node_metadata=node_metadata,
        # Derive the IR type string from the value itself when not provided.
        ir_type_str=ir_type_str or (str(x.type) if hasattr(x, "type") else None),
    )
def _tree_flatten(x: Any) -> tuple[Iterable[Any], PyTreeDef | Leaf]:
    """
    Internal function to flatten a tree structure.

    This is the core implementation of tree flattening that handles different
    types of objects including None, ArithValue, ir.Value, Numeric types,
    and registered pytree node types.

    Args:
        x: The object to flatten

    Returns:
        tuple: (flattened_values, treedef) where flattened_values is an iterable
        of leaf values and treedef is the tree structure

    Raises:
        DSLTreeFlattenError: If the object type is not supported
    """
    match x:
        case None:
            # None contributes no leaf values; the Leaf itself records None-ness.
            return [], create_leaf_for_value(x, is_none=True)
        case ArithValue() if is_dynamic_expression(x):
            # Dynamic-expression ArithValues provide their own MLIR extraction
            # hook; keep the original object so it can rebuild itself later.
            v = x.__extract_mlir_values__()
            return v, create_leaf_for_value(
                x,
                node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x),
                ir_type_str=str(v[0].type),
            )
        case ArithValue():
            # Plain arithmetic value: a single numeric leaf.
            return [x], create_leaf_for_value(x, is_numeric=True)
        case ir.Value():
            # Raw MLIR value: pass through as an opaque leaf.
            return [x], create_leaf_for_value(x)
        case Numeric():
            # Numeric wrappers are treated like dynamic expressions.
            v = x.__extract_mlir_values__()
            return v, create_leaf_for_value(
                x,
                node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x),
                ir_type_str=str(v[0].type),
            )
        case _:
            node_type = get_registered_node_types_or_insert(x)
            if node_type:
                # Container-like node: flatten each child and chain the results.
                node_metadata, children = node_type.to_iterable(x)
                children_flat, child_trees = unzip2(map(_tree_flatten, children))
                flattened = it.chain.from_iterable(children_flat)
                return flattened, PyTreeDef(
                    node_type, node_metadata, tuple(child_trees)
                )
            # Try to convert to numeric
            try:
                nval = as_numeric(x).ir_value()
                return [nval], create_leaf_for_value(nval, is_numeric=True)
            except Exception:
                raise DSLTreeFlattenError(
                    "Flatten Error", get_fully_qualified_class_name(x)
                )
def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
    """
    Rebuild a nested structure from flat leaf values and a tree definition.

    Inverse of tree_flatten: consumes the leaves in order while following the
    recorded structure.

    Args:
        treedef: The tree structure definition produced by tree_flatten
        xs: Flat leaf values, in flatten order

    Returns:
        The reconstructed nested structure

    Example:
        >>> flat_values, treedef = tree_flatten([1, [2, 3], 4])
        >>> tree_unflatten(treedef, flat_values)
        [1, [2, 3], 4]
    """
    leaf_stream = iter(xs)
    return _tree_unflatten(treedef, leaf_stream)
def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any:
    """
    Internal function to reconstruct a tree structure.

    This is the core implementation of tree unflattening that handles
    different types of tree definitions including Leaf nodes and PyTreeDef nodes.
    Leaves consume values from `xs` in order; interior nodes recurse over their
    children, sharing the same iterator so consumption stays in flatten order.

    Args:
        treedef: The tree structure definition
        xs: Iterator of flat values to reconstruct from

    Returns:
        The reconstructed object
    """
    match treedef:
        case Leaf(is_none=True):
            # None leaves consumed no values when flattening; consume none here.
            return None
        case Leaf(
            node_metadata=metadata
        ) if metadata and metadata.is_dynamic_expression:
            # Let the original object rebuild itself from the next MLIR value.
            return metadata.original_obj.__new_from_mlir_values__([next(xs)])
        case Leaf(is_numeric=True):
            return as_numeric(next(xs))
        case Leaf():
            return next(xs)
        case PyTreeDef():
            # Lazy generator: each child pulls its own leaves from `xs` in turn.
            children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
            return treedef.node_type.from_iterable(treedef.node_metadata, children)
def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool:
    """
    Recursively compare two tree definitions for structural equality.

    Leaves compare their None-ness and IR type string; interior nodes compare
    node type, recorded fields / constexpr fields, and all children pairwise.

    Args:
        lhs: Left tree definition (PyTreeDef or Leaf)
        rhs: Right tree definition (PyTreeDef or Leaf)

    Returns:
        bool: True if the trees are structurally equal, False otherwise
    """
    if isinstance(lhs, Leaf) and isinstance(rhs, Leaf):
        return lhs.is_none == rhs.is_none and lhs.ir_type_str == rhs.ir_type_str
    if isinstance(lhs, PyTreeDef) and isinstance(rhs, PyTreeDef):
        if lhs.node_type != rhs.node_type:
            return False
        lhs_meta, rhs_meta = lhs.node_metadata, rhs.node_metadata
        if getattr(lhs_meta, "fields", []) != getattr(rhs_meta, "fields", []):
            return False
        if getattr(lhs_meta, "constexpr_fields", []) != getattr(
            rhs_meta, "constexpr_fields", []
        ):
            return False
        if len(lhs.child_treedefs) != len(rhs.child_treedefs):
            return False
        return all(
            _check_tree_equal(left, right)
            for left, right in zip(lhs.child_treedefs, rhs.child_treedefs)
        )
    # Mixed Leaf/PyTreeDef pair: structurally different by definition.
    return False
def check_tree_equal(lhs: PyTreeDef, rhs: PyTreeDef) -> int:
    """
    Compare two tree definitions child-by-child.

    Walks the direct children of both trees in parallel and reports where
    they first diverge structurally (per _check_tree_equal).

    Args:
        lhs: Left tree definition
        rhs: Right tree definition

    Returns:
        int: Index of the first differing child, or -1 if all children match
    """
    # Both trees must have the same arity for a positional comparison.
    assert len(lhs.child_treedefs) == len(rhs.child_treedefs)
    for index, (left, right) in enumerate(
        zip(lhs.child_treedefs, rhs.child_treedefs)
    ):
        if not _check_tree_equal(left, right):
            return index
    return -1

View File

@@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.1.0
nvidia-cutlass-dsl==4.2.0