mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
v4.2 tag release. (#2638)
This commit is contained in:
@@ -207,7 +207,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
|
||||
|
||||
if dst_width == src_width:
|
||||
return a
|
||||
elif src_signed and not dst_signed:
|
||||
elif src_signed != False and not dst_signed:
|
||||
# Signed -> Unsigned
|
||||
if dst_width > src_width:
|
||||
return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
|
||||
@@ -216,7 +216,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None):
|
||||
elif src_signed == dst_signed:
|
||||
# Same signedness
|
||||
if dst_width > src_width:
|
||||
if src_signed and src_width > 1:
|
||||
if src_signed != False and src_width > 1:
|
||||
return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip)
|
||||
else:
|
||||
return arith.extui(dst_mlir_type, a, loc=loc, ip=ip)
|
||||
@@ -479,7 +479,7 @@ class ArithValue(ir.Value):
|
||||
if self.is_float:
|
||||
q = arith.divf(self, other, loc=loc, ip=ip)
|
||||
return math.floor(q, loc=loc, ip=ip)
|
||||
elif self.signed:
|
||||
elif self.signed != False:
|
||||
return arith.floordivsi(self, other, loc=loc, ip=ip)
|
||||
else:
|
||||
return arith.divui(self, other, loc=loc, ip=ip)
|
||||
@@ -489,7 +489,7 @@ class ArithValue(ir.Value):
|
||||
def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue":
|
||||
if self.is_float:
|
||||
return arith.remf(self, other, loc=loc, ip=ip)
|
||||
elif self.signed:
|
||||
elif self.signed != False:
|
||||
return arith.remsi(self, other, loc=loc, ip=ip)
|
||||
else:
|
||||
return arith.remui(self, other, loc=loc, ip=ip)
|
||||
@@ -524,7 +524,7 @@ class ArithValue(ir.Value):
|
||||
def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue":
|
||||
if self.is_float:
|
||||
return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip)
|
||||
elif self.signed:
|
||||
elif self.signed != False:
|
||||
return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip)
|
||||
else:
|
||||
return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip)
|
||||
@@ -534,7 +534,7 @@ class ArithValue(ir.Value):
|
||||
def __le__(self, other, *, loc=None, ip=None) -> "ArithValue":
|
||||
if self.is_float:
|
||||
return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip)
|
||||
elif self.signed:
|
||||
elif self.signed != False:
|
||||
return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip)
|
||||
else:
|
||||
return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip)
|
||||
@@ -561,7 +561,7 @@ class ArithValue(ir.Value):
|
||||
def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue":
|
||||
if self.is_float:
|
||||
return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip)
|
||||
elif self.signed:
|
||||
elif self.signed != False:
|
||||
return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip)
|
||||
else:
|
||||
return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip)
|
||||
@@ -571,7 +571,7 @@ class ArithValue(ir.Value):
|
||||
def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue":
|
||||
if self.is_float:
|
||||
return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip)
|
||||
elif self.signed:
|
||||
elif self.signed != False:
|
||||
return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip)
|
||||
else:
|
||||
return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip)
|
||||
@@ -599,7 +599,7 @@ class ArithValue(ir.Value):
|
||||
@_dispatch_to_rhs_r_op
|
||||
@_binary_op
|
||||
def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue":
|
||||
if self.signed:
|
||||
if self.signed != False:
|
||||
return arith.shrsi(self, other, loc=loc, ip=ip)
|
||||
else:
|
||||
return arith.shrui(self, other, loc=loc, ip=ip)
|
||||
@@ -633,7 +633,7 @@ class ArithValue(ir.Value):
|
||||
return super().__hash__()
|
||||
|
||||
def __str__(self):
|
||||
return super().__str__().replace(ir.Value.__name__, ArithValue.__name__)
|
||||
return "?"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
@@ -657,7 +657,7 @@ def _min(lhs, rhs, *, loc=None, ip=None):
|
||||
rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)
|
||||
|
||||
if arith._is_integer_like_type(lhs.type):
|
||||
if lhs.signed:
|
||||
if lhs.signed != False:
|
||||
return arith.minsi(lhs, rhs, loc=loc, ip=ip)
|
||||
else:
|
||||
return arith.minui(lhs, rhs, loc=loc, ip=ip)
|
||||
@@ -683,7 +683,7 @@ def _max(lhs, rhs, *, loc=None, ip=None):
|
||||
rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip)
|
||||
|
||||
if arith._is_integer_like_type(lhs.type):
|
||||
if lhs.signed:
|
||||
if lhs.signed != False:
|
||||
return arith.maxsi(lhs, rhs, loc=loc, ip=ip)
|
||||
else:
|
||||
return arith.maxui(lhs, rhs, loc=loc, ip=ip)
|
||||
|
||||
@@ -17,12 +17,16 @@ The preprocessor read through python's ast and changes the input code.
|
||||
from typing import Callable, Iterator, Optional, overload
|
||||
from typing_extensions import deprecated
|
||||
import warnings
|
||||
import inspect
|
||||
from types import BuiltinFunctionType
|
||||
from functools import lru_cache
|
||||
|
||||
from .utils.logger import log
|
||||
from .common import *
|
||||
|
||||
from ._mlir_helpers.arith import ArithValue
|
||||
|
||||
|
||||
class Executor:
|
||||
"""
|
||||
The Executor class handles dynamic and compile-time (constexpr) execution
|
||||
@@ -45,9 +49,11 @@ class Executor:
|
||||
self._compare_executor = None
|
||||
self._any_executor = None
|
||||
self._all_executor = None
|
||||
self._builtin_redirector = None
|
||||
|
||||
def set_functions(
|
||||
self,
|
||||
*,
|
||||
is_dynamic_expression: Callable,
|
||||
loop_execute_range_dynamic: Callable,
|
||||
if_dynamic: Callable,
|
||||
@@ -55,6 +61,7 @@ class Executor:
|
||||
compare_executor: Callable,
|
||||
any_executor: Callable = None,
|
||||
all_executor: Callable = None,
|
||||
builtin_redirector: Callable = None,
|
||||
):
|
||||
self._is_dynamic_expression = is_dynamic_expression
|
||||
self._loop_execute_range_dynamic = loop_execute_range_dynamic
|
||||
@@ -63,6 +70,7 @@ class Executor:
|
||||
self._compare_executor = compare_executor
|
||||
self._any_executor = any_executor
|
||||
self._all_executor = all_executor
|
||||
self._builtin_redirector = builtin_redirector
|
||||
|
||||
@staticmethod
|
||||
def convert_to_list(x):
|
||||
@@ -90,42 +98,18 @@ class Executor:
|
||||
return res[0]
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def for_constexpr(
|
||||
func: Callable,
|
||||
start: int,
|
||||
stop: int,
|
||||
step: int,
|
||||
used_args: list,
|
||||
iter_args: list,
|
||||
):
|
||||
log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
|
||||
loop_results = iter_args
|
||||
log().debug("iter_args [%s]", iter_args)
|
||||
for i in range(start, stop, step):
|
||||
log().debug("i [%s] iter_args [%s]", i, iter_args)
|
||||
loop_results = func(i, *used_args, *loop_results)
|
||||
log().debug("loop_results [%s]", loop_results)
|
||||
if loop_results is None:
|
||||
loop_results = []
|
||||
if not isinstance(loop_results, list):
|
||||
loop_results = [loop_results]
|
||||
|
||||
log().debug("done loop_results [%s]", loop_results)
|
||||
return Executor.converge_ret_val(loop_results)
|
||||
|
||||
def for_execute(
|
||||
self,
|
||||
func,
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args=[],
|
||||
iter_args=[],
|
||||
iter_arg_names=[],
|
||||
write_args=[],
|
||||
full_write_args_count=0,
|
||||
write_args_names=[],
|
||||
unroll=-1,
|
||||
unroll_full=False,
|
||||
pipelining=None,
|
||||
prefetch_stages=None,
|
||||
):
|
||||
assert (
|
||||
self._loop_execute_range_dynamic
|
||||
@@ -137,12 +121,12 @@ class Executor:
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args,
|
||||
iter_args,
|
||||
iter_arg_names,
|
||||
write_args,
|
||||
full_write_args_count,
|
||||
write_args_names,
|
||||
unroll,
|
||||
unroll_full,
|
||||
pipelining,
|
||||
prefetch_stages,
|
||||
)
|
||||
|
||||
def if_execute(
|
||||
@@ -150,15 +134,20 @@ class Executor:
|
||||
pred,
|
||||
then_block: Callable,
|
||||
else_block: Optional[Callable] = None,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
write_args=[],
|
||||
full_write_args_count=0,
|
||||
write_args_names=[],
|
||||
):
|
||||
assert self._if_dynamic, "Functions must be set before execution."
|
||||
|
||||
# MLIR generation
|
||||
return self._if_dynamic(
|
||||
pred, then_block, else_block, used_args, yield_args, yield_arg_names
|
||||
pred,
|
||||
then_block,
|
||||
else_block,
|
||||
write_args,
|
||||
full_write_args_count,
|
||||
write_args_names,
|
||||
)
|
||||
|
||||
def while_execute(
|
||||
@@ -166,9 +155,9 @@ class Executor:
|
||||
pred,
|
||||
while_before_block: Callable,
|
||||
while_after_block: Callable,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
write_args=[],
|
||||
full_write_args_count=0,
|
||||
write_args_names=[],
|
||||
):
|
||||
assert self._while_dynamic, "Functions must be set before execution."
|
||||
|
||||
@@ -176,9 +165,9 @@ class Executor:
|
||||
return self._while_dynamic(
|
||||
while_before_block,
|
||||
while_after_block,
|
||||
used_args,
|
||||
yield_args,
|
||||
yield_arg_names,
|
||||
write_args,
|
||||
full_write_args_count,
|
||||
write_args_names,
|
||||
)
|
||||
|
||||
|
||||
@@ -194,23 +183,24 @@ def loop_selector(
|
||||
stop,
|
||||
step,
|
||||
*,
|
||||
used_args=[],
|
||||
iter_args=[],
|
||||
iter_arg_names=[],
|
||||
write_args=[],
|
||||
full_write_args_count=0,
|
||||
write_args_names=[],
|
||||
unroll=-1,
|
||||
unroll_full=False,
|
||||
pipelining=None,
|
||||
prefetch_stages=None,
|
||||
):
|
||||
log().debug(
|
||||
"start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]",
|
||||
"start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s]",
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args,
|
||||
iter_args,
|
||||
write_args,
|
||||
full_write_args_count,
|
||||
write_args_names,
|
||||
unroll,
|
||||
unroll_full,
|
||||
pipelining,
|
||||
prefetch_stages,
|
||||
)
|
||||
from .typing import Integer, Numeric
|
||||
|
||||
@@ -230,19 +220,19 @@ def loop_selector(
|
||||
start,
|
||||
stop,
|
||||
step,
|
||||
used_args,
|
||||
iter_args,
|
||||
iter_arg_names,
|
||||
write_args,
|
||||
full_write_args_count,
|
||||
write_args_names,
|
||||
unroll,
|
||||
unroll_full,
|
||||
pipelining,
|
||||
prefetch_stages,
|
||||
)
|
||||
|
||||
return ir_loop
|
||||
|
||||
|
||||
def if_selector(pred, used_args=[], yield_args=[]):
|
||||
log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
|
||||
def if_selector(pred, write_args=[]):
|
||||
log().debug("pred [%s] write_args [%s]", pred, write_args)
|
||||
# Handle Numeric types here?
|
||||
|
||||
from .typing import Numeric
|
||||
@@ -251,14 +241,14 @@ def if_selector(pred, used_args=[], yield_args=[]):
|
||||
pred = pred.value
|
||||
|
||||
def ir_loop(func):
|
||||
return func(pred, *used_args, *yield_args)
|
||||
return func(pred, *write_args)
|
||||
|
||||
return ir_loop
|
||||
|
||||
|
||||
def while_selector(pred, used_args=[], yield_args=[]):
|
||||
def while_selector(pred, write_args=[]):
|
||||
def ir_while_loop(func):
|
||||
return func(pred, *used_args, *yield_args)
|
||||
return func(pred, *write_args)
|
||||
|
||||
return ir_while_loop
|
||||
|
||||
@@ -267,17 +257,17 @@ def while_executor(
|
||||
pred,
|
||||
while_before_block: Callable,
|
||||
while_after_block: Callable,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
write_args=[],
|
||||
full_write_args_count=0,
|
||||
write_args_names=[],
|
||||
):
|
||||
return executor.while_execute(
|
||||
pred,
|
||||
while_before_block,
|
||||
while_after_block,
|
||||
used_args,
|
||||
yield_args,
|
||||
yield_arg_names,
|
||||
write_args,
|
||||
full_write_args_count,
|
||||
write_args_names,
|
||||
)
|
||||
|
||||
|
||||
@@ -285,12 +275,17 @@ def if_executor(
|
||||
pred,
|
||||
then_block: Callable,
|
||||
else_block: Optional[Callable] = None,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
write_args=[],
|
||||
full_write_args_count=0,
|
||||
write_args_names=[],
|
||||
):
|
||||
return executor.if_execute(
|
||||
pred, then_block, else_block, used_args, yield_args, yield_arg_names
|
||||
pred,
|
||||
then_block,
|
||||
else_block,
|
||||
write_args,
|
||||
full_write_args_count,
|
||||
write_args_names,
|
||||
)
|
||||
|
||||
|
||||
@@ -313,14 +308,17 @@ class range:
|
||||
|
||||
- unroll: Number of iterations to unroll (0 or 1 = no unrolling)
|
||||
- unroll_full: Whether to fully unroll the loop
|
||||
- pipelining: Compiler generated pipeline configuration
|
||||
- prefetch_stages: Number of prefetch stages to generate
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None):
|
||||
def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None):
|
||||
pass
|
||||
|
||||
@overload
|
||||
def __new__(cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None):
|
||||
def __new__(
|
||||
cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None
|
||||
):
|
||||
pass
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
@@ -340,6 +338,7 @@ def range_dynamic(*args, **kwargs):
|
||||
def range_constexpr(*args):
|
||||
raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# If expressions
|
||||
# =============================================================================
|
||||
@@ -405,7 +404,7 @@ def assert_executor(test, msg=None):
|
||||
else:
|
||||
raise DSLRuntimeError(
|
||||
"Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
|
||||
suggestion = "Please replace with runtime assert."
|
||||
suggestion="Please replace with runtime assert.",
|
||||
)
|
||||
|
||||
|
||||
@@ -413,10 +412,11 @@ def bool_cast(value):
|
||||
if executor._is_dynamic_expression(value):
|
||||
raise DSLRuntimeError(
|
||||
"Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
|
||||
suggestion = "Please explicitly convert to boolean with expressions like comparision."
|
||||
suggestion="Please explicitly convert to boolean with expressions like comparision.",
|
||||
)
|
||||
return bool(value)
|
||||
|
||||
|
||||
def compare_executor(left, comparators, ops):
|
||||
"""
|
||||
Executes comparison operations with a left operand and a list of comparators.
|
||||
@@ -470,6 +470,19 @@ def all_executor(iterable):
|
||||
# =============================================================================
|
||||
# Control flow checks
|
||||
# =============================================================================
|
||||
class DSLOptimizationWarning(Warning):
|
||||
"""
|
||||
This warning is used to warn the user about the optimization related issues in DSL.
|
||||
"""
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
super().__init__()
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
def range_value_check(*args):
|
||||
"""
|
||||
Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
|
||||
@@ -495,7 +508,7 @@ def range_value_check(*args):
|
||||
if range_length >= 64:
|
||||
warnings.warn(
|
||||
f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
|
||||
category=UserWarning,
|
||||
category=DSLOptimizationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@@ -519,7 +532,50 @@ def range_perf_warning(filename, lineno, *args):
|
||||
"This loop is no longer unrolled and may cause performance regression. "
|
||||
"Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
|
||||
),
|
||||
category=UserWarning,
|
||||
category=DSLOptimizationWarning,
|
||||
filename=filename,
|
||||
lineno=lineno,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_self_module():
|
||||
"""
|
||||
This function is used to get the owning module of this function.
|
||||
"""
|
||||
return inspect.getmodule(_get_self_module)
|
||||
|
||||
|
||||
def cf_symbol_check(symbol):
|
||||
"""
|
||||
Check if the symbol is control flow symbol from current module.
|
||||
"""
|
||||
|
||||
failed = False
|
||||
name = symbol.__name__
|
||||
self_module = _get_self_module()
|
||||
if inspect.ismodule(symbol):
|
||||
name = "range"
|
||||
if not self_module.__name__.startswith(symbol.__name__):
|
||||
failed = True
|
||||
else:
|
||||
owning_module = inspect.getmodule(symbol)
|
||||
if owning_module != self_module:
|
||||
failed = True
|
||||
|
||||
if failed:
|
||||
raise DSLRuntimeError(
|
||||
f"Incorrect {symbol.__name__} is used.",
|
||||
suggestion=f"Please avoid overriding `{symbol.__name__}` from DSL package.",
|
||||
)
|
||||
|
||||
|
||||
def redirect_builtin_function(fcn):
|
||||
"""
|
||||
This function is used to redirect built-in function call
|
||||
to the function defined in DSL package.
|
||||
"""
|
||||
# Only redirect if it's a built-in
|
||||
if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector:
|
||||
return executor._builtin_redirector(fcn)
|
||||
return fcn
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -139,8 +139,7 @@ def dump_cache_to_path(
|
||||
dsl_name, jit_cache, cache_limit, path=default_generated_ir_path
|
||||
):
|
||||
log().info("JIT cache : dumping [%s] items=[%s]", dsl_name, len(jit_cache))
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
original_path = os.getcwd()
|
||||
try:
|
||||
os.chdir(path)
|
||||
|
||||
@@ -205,6 +205,8 @@ class CompileOptions:
|
||||
self._parser.add_argument(
|
||||
"--enable-device-assertions", action="store_true", default=False
|
||||
)
|
||||
self._parser.add_argument("--link-libraries", type=str, default="")
|
||||
|
||||
try:
|
||||
self._options = self._parser.parse_args(options.split())
|
||||
except SystemExit as e:
|
||||
|
||||
@@ -32,13 +32,14 @@ import hashlib
|
||||
from functools import lru_cache, wraps
|
||||
from collections import namedtuple
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Union, Tuple, get_origin, get_args
|
||||
from types import FunctionType
|
||||
from typing import Any, Union, Tuple, get_origin, get_args, List
|
||||
from types import FunctionType, SimpleNamespace
|
||||
import warnings
|
||||
|
||||
from . import typing as t
|
||||
from .env_manager import EnvironmentVarManager
|
||||
from .compiler import CompileOptions
|
||||
from .ast_helpers import DSLOptimizationWarning
|
||||
|
||||
# =============================================================================
|
||||
# CUDA Python
|
||||
@@ -56,7 +57,7 @@ from .utils.timer import timer
|
||||
from .utils.logger import setup_log, log
|
||||
from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe
|
||||
from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry
|
||||
from .runtime.tensor_descriptor import TensorDescriptor
|
||||
|
||||
from .ast_preprocessor import DSLPreprocessor
|
||||
from .common import *
|
||||
from .typing import (
|
||||
@@ -73,12 +74,6 @@ from .._mlir import runtime as rt
|
||||
from .._mlir.extras import types as T
|
||||
from .._mlir.dialects import arith, math, func
|
||||
|
||||
# =============================================================================
|
||||
# cutlass.dlpack_runtime
|
||||
# =============================================================================
|
||||
|
||||
from .runtime.dlpack_runtime import dlpack_to_tensor_desc, mark_layout_dynamic
|
||||
|
||||
# =============================================================================
|
||||
# Global Variables
|
||||
# =============================================================================
|
||||
@@ -177,6 +172,7 @@ def is_dynamic_expression(value):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def extract_mlir_values(obj):
|
||||
"""
|
||||
Given the `obj`, recursively go through it to extract all contained IR values as list of MLIR values
|
||||
@@ -186,6 +182,10 @@ def extract_mlir_values(obj):
|
||||
res = obj.__extract_mlir_values__()
|
||||
elif isinstance(obj, (tuple, list)):
|
||||
res = sum((extract_mlir_values(x) for x in obj), [])
|
||||
elif isinstance(obj, SimpleNamespace):
|
||||
res = []
|
||||
for k, v in obj.__dict__.items():
|
||||
res.extend(extract_mlir_values(v))
|
||||
# Can't call is_dynamic_expression as _is_dynamic_expression depends on extract_mlir_values
|
||||
elif isinstance(obj, set):
|
||||
raise DSLRuntimeError(
|
||||
@@ -215,6 +215,13 @@ def new_from_mlir_values(obj, values):
|
||||
values = values[n_items:]
|
||||
obj_ty = type(obj)
|
||||
return obj_ty(res)
|
||||
elif isinstance(obj, SimpleNamespace):
|
||||
res = SimpleNamespace()
|
||||
for k, v in obj.__dict__.items():
|
||||
n_items = len(get_mlir_types(v))
|
||||
res.__dict__[k] = new_from_mlir_values(v, values[:n_items])
|
||||
values = values[n_items:]
|
||||
return res
|
||||
elif isinstance(obj, set):
|
||||
raise DSLRuntimeError(
|
||||
"Sets are not supported in new_from_mlir_values to ensure order preservation",
|
||||
@@ -249,8 +256,6 @@ class DSLCallable:
|
||||
|
||||
Methods:
|
||||
__call__(*args, **kwargs): Calls the wrapped function and clears it.
|
||||
get_arg_spec(): Returns the argument specification of the function.
|
||||
get_signature(): Returns the signature of the function.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
@@ -266,23 +271,23 @@ class DSLCallable:
|
||||
assert self.func is not None, "DSLCallable is already called"
|
||||
return self.func
|
||||
|
||||
@property
|
||||
def __signature__(self):
|
||||
return inspect.signature(self.__func__)
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
return self.__func__.__name__
|
||||
|
||||
def get_arg_spec(self):
|
||||
return inspect.getfullargspec(self.__func__)
|
||||
|
||||
def get_signature(self):
|
||||
return inspect.signature(self.__func__)
|
||||
|
||||
|
||||
class BaseDSL:
|
||||
gpu_module = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
dsl_package_name: List[str],
|
||||
compiler_provider: Any,
|
||||
pass_sm_arch_name: str,
|
||||
device_compilation_only=False,
|
||||
@@ -293,6 +298,7 @@ class BaseDSL:
|
||||
|
||||
Parameters:
|
||||
- name (str): Name of DSL, used for environment variables and logging.
|
||||
- package_name (str): Name of the package, used for the preprocessor.
|
||||
- compiler_provider (MLIR dialect): Provider for compiler.
|
||||
- pass_sm_arch_name (str): The keyword name of the SM.
|
||||
- device_compilation_only (bool) : Only device code, and call it via cuda driver
|
||||
@@ -330,6 +336,9 @@ class BaseDSL:
|
||||
self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}"
|
||||
|
||||
# set warning
|
||||
if not self.envar.enable_optimization_warnings:
|
||||
# By default, optimization warnings are disabled
|
||||
warnings.filterwarnings("ignore", category=DSLOptimizationWarning)
|
||||
if self.envar.warnings_as_errors:
|
||||
warnings.filterwarnings("error")
|
||||
if self.envar.warnings_ignore:
|
||||
@@ -355,7 +364,7 @@ class BaseDSL:
|
||||
self.compile_options = CompileOptions()
|
||||
|
||||
if preprocess:
|
||||
self.preprocessor = DSLPreprocessor()
|
||||
self.preprocessor = DSLPreprocessor(dsl_package_name)
|
||||
log().info(f"Initializing {name} DSL")
|
||||
log().debug(f"Logger initialized for {self.name}")
|
||||
|
||||
@@ -656,7 +665,7 @@ class BaseDSL:
|
||||
return ir_args, ir_kwargs
|
||||
|
||||
@abstractmethod
|
||||
def _generate_mlir_type_for_tensor_descriptor(self, tensor: TensorDescriptor):
|
||||
def _generate_mlir_type_for_tensor_descriptor(self, tensor):
|
||||
"""
|
||||
Generate MLIR type for the tensor descriptor.
|
||||
"""
|
||||
@@ -671,13 +680,6 @@ class BaseDSL:
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_module_globals(self):
|
||||
"""
|
||||
Get the module's globals.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _get_globals(self):
|
||||
"""
|
||||
Combines global and local variables from the current context and the
|
||||
@@ -690,43 +692,21 @@ class BaseDSL:
|
||||
AST preprocessor generates a new python code, so the resulting globals
|
||||
dictionary is used to execute the python code.
|
||||
"""
|
||||
all_globals = self._get_module_globals().copy()
|
||||
all_globals = {}
|
||||
if self.frame:
|
||||
all_globals.update(self.frame.f_globals)
|
||||
all_globals.update(self.frame.f_locals)
|
||||
return all_globals
|
||||
|
||||
@abstractmethod
|
||||
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
|
||||
return isinstance(
|
||||
maybe_tensor_descriptor, TensorDescriptor
|
||||
) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _handle_tensor_descriptor(
|
||||
self, maybe_tensor, arg_name: str, need_gpu_memory: bool
|
||||
) -> TensorDescriptor:
|
||||
if self._is_tensor_descriptor(maybe_tensor):
|
||||
tensor = (
|
||||
maybe_tensor
|
||||
if isinstance(maybe_tensor, TensorDescriptor)
|
||||
else TensorDescriptor(maybe_tensor)
|
||||
)
|
||||
if need_gpu_memory and not tensor.is_in_device:
|
||||
log().info(
|
||||
"FAIL name=[%s] tensor=[%s] in_gpu=[%s]",
|
||||
arg_name,
|
||||
tensor,
|
||||
tensor.is_in_device,
|
||||
)
|
||||
raise DSLRuntimeError(
|
||||
f'Tensor "{arg_name}" is tensor "{tensor}" '
|
||||
"is not in the GPU memory. "
|
||||
)
|
||||
|
||||
return tensor
|
||||
|
||||
raise DSLRuntimeError(
|
||||
f"Argument {arg_name} could not be transformed into a TensorDescriptor."
|
||||
)
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
def _validate_arg(self, arg, arg_index, arg_name, arg_spec):
|
||||
"""
|
||||
@@ -882,10 +862,11 @@ class BaseDSL:
|
||||
cluster: list = None
|
||||
grid: list = field(default_factory=lambda: [1, 1, 1])
|
||||
block: list = field(default_factory=lambda: [1, 1, 1])
|
||||
smem: int = 0
|
||||
smem: int = None
|
||||
async_deps: list = field(default_factory=list)
|
||||
has_cluster: bool = False
|
||||
min_blocks_per_mp: int = 0
|
||||
auto_smem: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if len(self.grid) != 3:
|
||||
@@ -893,6 +874,10 @@ class BaseDSL:
|
||||
if len(self.block) != 3:
|
||||
raise DSLRuntimeError(f"Expect 3d block!")
|
||||
|
||||
if self.smem is None:
|
||||
self.smem = 0
|
||||
self.auto_smem = True
|
||||
|
||||
self.has_cluster = self.cluster is not None
|
||||
if self.cluster is None:
|
||||
self.cluster = [None, None, None]
|
||||
@@ -1116,8 +1101,6 @@ class BaseDSL:
|
||||
try:
|
||||
result = funcBody(*ir_args, **ir_kwargs)
|
||||
func.ReturnOp([])
|
||||
except DSLAstPreprocessorError as pp_error:
|
||||
raise pp_error
|
||||
except NameError as name_error:
|
||||
raise DSLRuntimeError(
|
||||
f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥",
|
||||
@@ -1127,11 +1110,6 @@ class BaseDSL:
|
||||
except DSLRuntimeError as dsl_error:
|
||||
# Throw it's already a DSL error
|
||||
raise dsl_error
|
||||
except Exception as general_e:
|
||||
# Transform internal error to a DSL error
|
||||
raise DSLRuntimeError(
|
||||
f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥"
|
||||
) from general_e
|
||||
return module, result
|
||||
|
||||
# Build IR module
|
||||
@@ -1328,10 +1306,8 @@ class BaseDSL:
|
||||
raise DSLRuntimeError("Function body is not set.")
|
||||
|
||||
# Pass the actual function object to inspect.signature to get the signature.
|
||||
if isinstance(self.funcBody, DSLCallable):
|
||||
sig = self.funcBody.get_signature()
|
||||
else:
|
||||
sig = inspect.signature(self.funcBody)
|
||||
sig = inspect.signature(self.funcBody)
|
||||
|
||||
function_name = self.funcBody.__name__
|
||||
|
||||
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
|
||||
@@ -1382,10 +1358,7 @@ class BaseDSL:
|
||||
# Check the number of arguments
|
||||
sig = self._check_arg_count(*args, **kwargs)
|
||||
|
||||
if isinstance(funcBody, DSLCallable):
|
||||
args_spec = funcBody.get_arg_spec()
|
||||
else:
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
|
||||
# Canonicalize the input arguments
|
||||
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
|
||||
@@ -1447,7 +1420,7 @@ class BaseDSL:
|
||||
return cuda_helpers.stream_create()
|
||||
|
||||
def _execute_cuda(
|
||||
self, fname_cubin, kernel_name, grid_size, block_size, stream=None
|
||||
self, fname_cubin, kernel_name, grid_size, block_size, smem_size, stream=None
|
||||
):
|
||||
"""
|
||||
Executes a specified CUDA kernel from a cubin file, handling module loading,
|
||||
@@ -1471,7 +1444,7 @@ class BaseDSL:
|
||||
grid_size,
|
||||
block_size,
|
||||
stream,
|
||||
smem_size=16000,
|
||||
smem_size=smem_size,
|
||||
kernel_args=self.exe_args,
|
||||
)
|
||||
|
||||
@@ -1480,7 +1453,13 @@ class BaseDSL:
|
||||
cuda_helpers.stream_sync(stream)
|
||||
|
||||
def _execute_by_cuda_driver(
|
||||
self, kernel_generator, generate_cubin, grid_size, block_size, stream=None
|
||||
self,
|
||||
kernel_generator,
|
||||
generate_cubin,
|
||||
grid_size,
|
||||
block_size,
|
||||
smem_size,
|
||||
stream=None,
|
||||
):
|
||||
"""
|
||||
This function builds IR and execute the module using cuda driver.
|
||||
@@ -1511,10 +1490,9 @@ class BaseDSL:
|
||||
fname_cubin = generate_cubin(module, kernel_name)
|
||||
|
||||
# Execute a cuda kernel from cubin
|
||||
if block_size is None:
|
||||
# The TileIR driver should set this automatically.
|
||||
block_size = self.block_size
|
||||
self._execute_cuda(fname_cubin, kernel_name, grid_size, block_size, stream)
|
||||
self._execute_cuda(
|
||||
fname_cubin, kernel_name, grid_size, block_size, smem_size, stream
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
@@ -1587,10 +1565,7 @@ class BaseDSL:
|
||||
kernelGenHelper = dkwargs.get("kernelGenHelper", None)
|
||||
|
||||
kernel_name = funcBody.__name__
|
||||
if isinstance(funcBody, DSLCallable):
|
||||
args_spec = funcBody.get_arg_spec()
|
||||
else:
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
self.funcBody = funcBody
|
||||
|
||||
# Give each kernel a unique name. (The same kernel may be
|
||||
|
||||
@@ -58,6 +58,11 @@ def get_int_env_var(var_name, default_value=0):
|
||||
return int(value) if value and value.isdigit() else default_value
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def has_env_var(var_name):
|
||||
return os.getenv(var_name) is not None
|
||||
|
||||
|
||||
def detect_gpu_arch(prefix):
|
||||
"""
|
||||
Attempts to detect the machine's GPU architecture.
|
||||
@@ -256,6 +261,7 @@ class EnvironmentVarManager:
|
||||
- [DSL_NAME]_ARCH: GPU architecture (default: "sm_100")
|
||||
- [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False)
|
||||
- [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False)
|
||||
- [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False)
|
||||
- [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False)
|
||||
- [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False)
|
||||
- [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000)
|
||||
@@ -267,7 +273,6 @@ class EnvironmentVarManager:
|
||||
self.prefix = prefix # change if needed
|
||||
|
||||
# Printing options
|
||||
self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
|
||||
self.print_after_preprocessor = get_bool_env_var(
|
||||
f"{prefix}_PRINT_AFTER_PREPROCESSOR", False
|
||||
)
|
||||
@@ -275,15 +280,29 @@ class EnvironmentVarManager:
|
||||
self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True)
|
||||
# File options
|
||||
self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False)
|
||||
# Logging options
|
||||
self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False)
|
||||
self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False)
|
||||
# Other options
|
||||
if (
|
||||
has_env_var(f"{prefix}_LOG_LEVEL")
|
||||
and not self.log_to_console
|
||||
and not self.log_to_file
|
||||
):
|
||||
log().warning(
|
||||
f"Log level was set, but neither logging to file ({prefix}_LOG_TO_FILE) nor logging to console ({prefix}_LOG_TO_CONSOLE) is enabled!"
|
||||
)
|
||||
self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1)
|
||||
|
||||
# Other options
|
||||
self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False)
|
||||
self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix))
|
||||
self.warnings_as_errors = get_bool_env_var(
|
||||
f"{prefix}_WARNINGS_AS_ERRORS", False
|
||||
)
|
||||
self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False)
|
||||
self.enable_optimization_warnings = get_bool_env_var(
|
||||
f"{prefix}_ENABLE_OPTIMIZATION_WARNINGS", False
|
||||
)
|
||||
self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False)
|
||||
self.disable_file_caching = get_bool_env_var(
|
||||
f"{prefix}_DISABLE_FILE_CACHING", False
|
||||
|
||||
@@ -14,16 +14,12 @@ This module provides a runtime utility functions that are needed for
|
||||
the DSL.
|
||||
"""
|
||||
|
||||
from . import device_tensor
|
||||
from . import dlpack_types
|
||||
from . import cuda
|
||||
from . import tensor_descriptor
|
||||
from . import jit_arg_adapters
|
||||
|
||||
__all__ = [
|
||||
"device_tensor",
|
||||
"dlpack_types",
|
||||
"cuda",
|
||||
"tensor_descriptor",
|
||||
"jit_arg_adapters",
|
||||
]
|
||||
|
||||
@@ -309,7 +309,7 @@ def get_kernel_function(module, kernel_name):
|
||||
return kernel
|
||||
|
||||
|
||||
def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size=0, kernel_args=None):
|
||||
def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None):
|
||||
"""
|
||||
Launches the CUDA kernel.
|
||||
"""
|
||||
|
||||
@@ -183,6 +183,13 @@ class TensorDescriptor:
|
||||
"""
|
||||
return self.device_type == _dpack.DLDeviceType.kDLGPU
|
||||
|
||||
@staticmethod
|
||||
def is_compatible(maybe_tensor_descriptor) -> bool:
|
||||
"""Check if the object is a TensorDescriptor or can be converted to one."""
|
||||
return isinstance(
|
||||
maybe_tensor_descriptor, TensorDescriptor
|
||||
) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
|
||||
|
||||
|
||||
def from_tensor(tensor) -> TensorDescriptor:
|
||||
"""Create a TensorDescriptor from a tensor object."""
|
||||
@@ -192,10 +199,3 @@ def from_tensor(tensor) -> TensorDescriptor:
|
||||
def to_tensor(tensor_descriptor: TensorDescriptor):
|
||||
"""Return tensor object from tensor descriptor."""
|
||||
return tensor_descriptor.tensor
|
||||
|
||||
|
||||
def is_tensor_descriptor(maybe_tensor_descriptor) -> bool:
|
||||
"""Check if the object is a TensorDescriptor."""
|
||||
return isinstance(
|
||||
maybe_tensor_descriptor, TensorDescriptor
|
||||
) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
|
||||
|
||||
@@ -126,6 +126,7 @@ from .core import (
|
||||
basic_copy_if,
|
||||
autovec_copy,
|
||||
copy,
|
||||
copy_atom_call,
|
||||
gemm,
|
||||
# Wrapper classes
|
||||
ComposedLayout,
|
||||
@@ -290,6 +291,7 @@ __all__ = [
|
||||
"basic_copy_if",
|
||||
"autovec_copy",
|
||||
"copy",
|
||||
"copy_atom_call",
|
||||
"gemm",
|
||||
# Tensor creation
|
||||
"full",
|
||||
|
||||
@@ -315,3 +315,35 @@ def mbarrier_arrive(
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Arrives on an mbarrier for async load **without incrementing** the arrival count
|
||||
(`cp.async.mbarrier.arrive.shared ..., noinc=1`).
|
||||
Used in the warp-specialized kernel when the non-TMA load warp(producer) is not the same
|
||||
as the math/epilogue warp(consumer).
|
||||
|
||||
:param mbar_ptr: A pointer to the mbarrier in SMEM
|
||||
:type mbar_ptr: Pointer
|
||||
"""
|
||||
arch = CuTeDSL._get_dsl().envar.arch
|
||||
check_value_in(
|
||||
arch,
|
||||
[
|
||||
"sm_90",
|
||||
"sm_90a",
|
||||
"sm_100a",
|
||||
"sm_100f",
|
||||
],
|
||||
"arch",
|
||||
)
|
||||
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
nvvm.cp_async_mbarrier_arrive_shared(
|
||||
mbar_llvm_ptr,
|
||||
noinc=True,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, Union, Callable
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from cutlass.cutlass_dsl import T, dsl_user_op
|
||||
|
||||
@@ -642,6 +643,9 @@ def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None):
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
@deprecated(
|
||||
"cute.arch.exp2 is deprecated, use cute.math.exp2 with `fastmath=True` instead"
|
||||
)
|
||||
def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
|
||||
return Float32(
|
||||
llvm.inline_asm(
|
||||
@@ -656,15 +660,19 @@ def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
|
||||
)
|
||||
|
||||
|
||||
# TODO: add `fastmath` flag for this op
|
||||
@dsl_user_op
|
||||
@deprecated(
|
||||
"cute.arch.exp is deprecated, use cute.math.exp with `fastmath=True` instead"
|
||||
)
|
||||
def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32:
|
||||
LOG2_E = 1.4426950408889634
|
||||
return exp2(a * LOG2_E, loc=loc, ip=ip)
|
||||
|
||||
|
||||
# TODO: add `fastmath` flag for this op
|
||||
@dsl_user_op
|
||||
@deprecated(
|
||||
"cute.arch.exp_packed_f32x2 is deprecated, use cute.arch.mul_packed_f32x2 and cute.math.exp2 with `fastmath=True` instead"
|
||||
)
|
||||
def exp_packed_f32x2(
|
||||
a: Tuple[Float32, Float32], *, loc=None, ip=None
|
||||
) -> Tuple[Float32, Float32]:
|
||||
|
||||
@@ -31,7 +31,6 @@ from typing import (
|
||||
Optional,
|
||||
)
|
||||
from enum import Enum, auto
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from cutlass.cutlass_dsl import (
|
||||
const,
|
||||
@@ -1662,7 +1661,9 @@ class _Tensor(Tensor):
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def print_tensor(tensor: Tensor, *, verbose: bool = False, loc=None, ip=None):
|
||||
def print_tensor(
|
||||
tensor: Union[Tensor, "TensorSSA"], *, verbose: bool = False, loc=None, ip=None
|
||||
):
|
||||
"""Print content of the tensor in human readable format.
|
||||
|
||||
Outputs the tensor data in a structured format showing both metadata
|
||||
@@ -1693,6 +1694,11 @@ def print_tensor(tensor: Tensor, *, verbose: bool = False, loc=None, ip=None):
|
||||
[ 0.9159, 0.7577, 0.6918, 0.0754, 0.0591],
|
||||
[ 0.6551, 0.1626, 0.1189, 0.0292, 0.8655]])
|
||||
"""
|
||||
if isinstance(tensor, TensorSSA):
|
||||
tmp = make_fragment(tensor.shape, tensor.dtype)
|
||||
tmp.store(tensor)
|
||||
tensor = tmp
|
||||
|
||||
if not isinstance(tensor.type, _cute_ir.MemRefType):
|
||||
raise NotImplementedError(
|
||||
f"printing {tensor} is not supported because it doesn't support trivial dereferencing. "
|
||||
@@ -1769,7 +1775,7 @@ def is_static(x: Union[ir.Type, ir.Value, XTuple]) -> bool:
|
||||
return False
|
||||
elif is_dynamic_expression(x):
|
||||
return _cute_ir.is_static(x.type)
|
||||
elif isinstance(x, int) or x is None:
|
||||
elif isinstance(x, (bool, int, float)) or x is None:
|
||||
return True
|
||||
elif isinstance(x, ScaledBasis):
|
||||
return x.is_static()
|
||||
@@ -2241,7 +2247,7 @@ def is_weakly_congruent(
|
||||
* X is a non-tuple value, OR
|
||||
* X and Y are both tuples of the same rank AND all corresponding elements are weakly congruent.
|
||||
|
||||
Weak congruence allows scalar values to match with tuples, making it useful
|
||||
Weak congruence allows scalar values to match with tuples, making it useful
|
||||
for determining whether an object has a hierarchical structure "up to" another.
|
||||
|
||||
:param a: First object to compare
|
||||
@@ -2921,33 +2927,46 @@ def flatten_to_tuple(a: Union[IntTuple, Coord, Shape, Stride]) -> tuple:
|
||||
return tuple(chain.from_iterable(tuple(flatten_to_tuple(x) for x in a)))
|
||||
|
||||
|
||||
def flatten(a: Union[IntTuple, Coord, Shape, Stride, Layout, Tensor]) -> tuple:
|
||||
@overload
|
||||
def flatten(a: Union[IntTuple, Coord, Shape, Stride]) -> IntTuple: ...
|
||||
@overload
|
||||
def flatten(a: Tensor) -> Tensor: ...
|
||||
@overload
|
||||
def flatten(a: Layout) -> Layout: ...
|
||||
|
||||
|
||||
def flatten(a):
|
||||
"""Flattens a CuTe data structure into a simpler form.
|
||||
|
||||
For tuples, this function flattens the structure into a single-level tuple.
|
||||
For non-tuple types, it returns the input unchanged.
|
||||
For layouts, it returns a new layout with flattened shape and stride.
|
||||
For tensors, it returns a new tensor with flattened layout.
|
||||
For other types, it returns the input unchanged.
|
||||
|
||||
:param a: The structure to flatten
|
||||
:type a: Union[IntTuple, Coord, Shape, Stride, Layout, Tensor]
|
||||
:return: The flattened structure
|
||||
:rtype: Union[tuple, Any]
|
||||
:raises NotImplementedError: If input is a Layout or Tensor
|
||||
|
||||
**Examples:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
flatten((1, 2, 3)) # Returns (1, 2, 3)
|
||||
flatten(((1, 2), (3, 4))) # Returns (1, 2, 3, 4)
|
||||
flatten(5) # Returns 5
|
||||
"""
|
||||
if isinstance(a, (Layout, Tensor)):
|
||||
raise NotImplementedError("flatten layout and tensor is not supported")
|
||||
flatten((1, 2, 3)) # Returns (1, 2, 3)
|
||||
flatten(((1, 2), (3, 4))) # Returns (1, 2, 3, 4)
|
||||
flatten(5) # Returns 5
|
||||
flatten(Layout(shape, stride)) # Returns Layout(flatten(shape), flatten(stride))
|
||||
flatten(Tensor(layout)) # Returns Tensor(flatten(layout))
|
||||
|
||||
if not isinstance(a, tuple):
|
||||
return a
|
||||
else:
|
||||
"""
|
||||
if isinstance(a, Tensor):
|
||||
return make_tensor(a.iterator, flatten(a.layout))
|
||||
elif isinstance(a, Layout):
|
||||
return make_layout(flatten(a.shape), stride=flatten(a.stride))
|
||||
elif isinstance(a, tuple):
|
||||
return flatten_to_tuple(a)
|
||||
else:
|
||||
return a
|
||||
|
||||
|
||||
def unflatten(
|
||||
@@ -4120,14 +4139,14 @@ def complement(
|
||||
@dsl_user_op
|
||||
def right_inverse(input: Layout, *, loc=None, ip=None) -> Layout:
|
||||
if not isinstance(input, Layout):
|
||||
raise TypeError(f"expects input of type Layout, but got {type(Layout)}")
|
||||
raise TypeError(f"expects input of type Layout, but got {type(input)}")
|
||||
return _cute_ir.right_inverse(input=input, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def left_inverse(input: Layout, *, loc=None, ip=None) -> Layout:
|
||||
if not isinstance(input, Layout):
|
||||
raise TypeError(f"expects input of type Layout, but got {type(Layout)}")
|
||||
raise TypeError(f"expects input of type Layout, but got {type(input)}")
|
||||
return _cute_ir.left_inverse(input=input, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@@ -5156,7 +5175,6 @@ def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
|
||||
return TiledCopy(atom.op, trait)
|
||||
|
||||
|
||||
@deprecated("Use make_tiled_copy_tv instead")
|
||||
def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None):
|
||||
"""Create a tiled type given a TV partitioner and tiler.
|
||||
|
||||
@@ -5434,6 +5452,14 @@ def gemm(
|
||||
For MMA Atoms that require single-threaded execution, the gemm op automatically handles thread
|
||||
election internally. Manual thread selection is not required in such cases.
|
||||
|
||||
Following dispatch rules are supported:
|
||||
|
||||
- Dispatch [1]: (V) x (V) => (V) => (V,1,1) x (V,1,1) => (V,1,1)
|
||||
- Dispatch [2]: (M) x (N) => (M,N) => (1,M,1) x (1,N,1) => (1,M,N)
|
||||
- Dispatch [3]: (M,K) x (N,K) => (M,N) => (1,M,K) x (1,N,K) => (1,M,N)
|
||||
- Dispatch [4]: (V,M) x (V,N) => (V,M,N) => (V,M,1) x (V,N,1) => (V,M,N)
|
||||
- Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
|
||||
|
||||
:param atom: MMA atom
|
||||
:type atom: MmaAtom
|
||||
:param d: Destination tensor
|
||||
@@ -5454,6 +5480,27 @@ def gemm(
|
||||
:rtype: None
|
||||
"""
|
||||
|
||||
a_rank = rank(a.shape)
|
||||
b_rank = rank(b.shape)
|
||||
c_rank = rank(c.shape)
|
||||
d_rank = rank(d.shape)
|
||||
|
||||
if a_rank != b_rank:
|
||||
raise ValueError("`a` and `b` must have the same rank")
|
||||
|
||||
if c_rank != d_rank:
|
||||
raise ValueError("`c` and `d` must have the same rank")
|
||||
|
||||
if a_rank == 1:
|
||||
if c_rank > 2:
|
||||
raise ValueError("`c` must have rank <= 2 when `a` has rank 1")
|
||||
elif a_rank == 2:
|
||||
if c_rank not in (2, 3):
|
||||
raise ValueError("`c` must have rank 2 or 3 when `a` has rank 2")
|
||||
elif a_rank == 3:
|
||||
if c_rank != 3:
|
||||
raise ValueError("`c` must have rank 3 when `a` has rank 3")
|
||||
|
||||
value = atom._unpack(loc=loc, ip=ip, **kwargs)
|
||||
return _cute_ir.gemm(value, d.value, a.value, b.value, c.value, loc=loc, ip=ip)
|
||||
|
||||
@@ -5645,6 +5692,76 @@ def copy(
|
||||
return _cute_ir.copy(value, src.value, dst.value, pred=pred, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def copy_atom_call(
|
||||
atom: CopyAtom,
|
||||
src: Tensor,
|
||||
dst: Tensor,
|
||||
*,
|
||||
pred: Optional[Tensor] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Execute a single copy atom operation.
|
||||
|
||||
The copy_atom_call operation executes a copy atom with the given operands.
|
||||
Following src/dst layout of atom are valid:
|
||||
* ((atom_v))
|
||||
* (atom_v)
|
||||
|
||||
Note: The format ((atom_v, rest_v)) is NOT valid for copy_atom_call since it would
|
||||
require multiple atom operations, which contradicts the definition of a single copy atom call.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Call a copy atom operation
|
||||
cute.copy_atom_call(copy_atom, src_tensor, dst_tensor)
|
||||
|
||||
An additional predication tensor can be provided. If the partitioned tensors have the following
|
||||
logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile
|
||||
consistent with ``(ATOM_REST,REST_M,...)``.
|
||||
"""
|
||||
if isinstance(src.type, _cute_ir.MemRefType) and isinstance(
|
||||
dst.type, _cute_ir.MemRefType
|
||||
):
|
||||
if src.element_type.width != dst.element_type.width:
|
||||
raise TypeError(
|
||||
"`copy_atom_call` currently only supports equal source and destination "
|
||||
"element type bit width"
|
||||
)
|
||||
|
||||
value = atom._unpack(loc=loc, ip=ip, **kwargs)
|
||||
if isinstance(pred, Tensor):
|
||||
pred = pred.value
|
||||
return _cute_ir.copy_atom_call(
|
||||
value, src.value, dst.value, pred=pred, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
def prefetch(atom: CopyAtom, src: Tensor, *, loc=None, ip=None) -> None:
|
||||
"""
|
||||
The Prefetch algorithm.
|
||||
|
||||
The "prefetch" expects source tensors to be partitioned according to the provided Copy Atom.
|
||||
Prefetch is used for loading tensors from global memory to L2.
|
||||
|
||||
Prefetch accepts Copy Atom but not all are allowed. Currently, only support for tma load tensor prefetch.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
cute.prefetch(tma_atom, src)
|
||||
|
||||
For Copy Atoms that require single-threaded execution, the copy op automatically handles thread
|
||||
election internally. Manual thread selection is not required in such cases.
|
||||
"""
|
||||
dummy_tma_bar_ptr = make_ptr(Int64, 0, AddressSpace.smem, loc=loc, ip=ip)
|
||||
value = atom._unpack(loc=loc, ip=ip, tma_bar_ptr=dummy_tma_bar_ptr)
|
||||
return _cute_ir.prefetch(value, src.value, loc=loc, ip=ip)
|
||||
|
||||
####################################################################################################
|
||||
#
|
||||
# TensorSSA class (experimental)
|
||||
@@ -5657,6 +5774,11 @@ class ReductionOp(Enum):
|
||||
MUL = auto()
|
||||
MAX = auto()
|
||||
MIN = auto()
|
||||
INC = auto()
|
||||
DEC = auto()
|
||||
AND = auto()
|
||||
OR = auto()
|
||||
XOR = auto()
|
||||
|
||||
def __str__(self):
|
||||
return self.name.lower()
|
||||
@@ -5697,6 +5819,7 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
|
||||
self._shape = shape
|
||||
self._dtype = dtype
|
||||
self._layout = None
|
||||
|
||||
@property
|
||||
def dtype(self) -> Type[Numeric]:
|
||||
@@ -5776,13 +5899,26 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
):
|
||||
res_type = Boolean
|
||||
|
||||
if lhs.shape != rhs.shape:
|
||||
raise ValueError(
|
||||
f"lhs and rhs must have the same shape type, but got {lhs.shape} and {rhs.shape}"
|
||||
)
|
||||
assert isinstance(rhs, TensorSSA), f"rhs must be TensorSSA but got {rhs}"
|
||||
|
||||
if not isinstance(rhs, TensorSSA):
|
||||
raise TypeError(f"rhs must be TensorSSA but got {rhs}")
|
||||
def _broadcast(s, t):
|
||||
if s == 1:
|
||||
return t
|
||||
elif t == 1:
|
||||
return s
|
||||
elif s == t:
|
||||
return s
|
||||
else:
|
||||
raise ValueError(f"cannot broadcast {s} and {t}")
|
||||
|
||||
max_rank = max(rank(lhs.shape), rank(rhs.shape))
|
||||
lhs_shape = append(lhs.shape, 1, up_to_rank=max_rank)
|
||||
rhs_shape = append(rhs.shape, 1, up_to_rank=max_rank)
|
||||
res_shape = transform_leaf(_broadcast, lhs_shape, rhs_shape)
|
||||
|
||||
# broadcast to the same shape
|
||||
lhs = lhs.broadcast_to(res_shape)
|
||||
rhs = rhs.broadcast_to(res_shape)
|
||||
|
||||
if (
|
||||
op in (operator.add, operator.sub)
|
||||
@@ -5807,6 +5943,38 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
|
||||
return res
|
||||
|
||||
def broadcast_to(self, target_shape: Shape, *, loc=None, ip=None) -> "TensorSSA":
|
||||
"""
|
||||
Broadcast the tensor to the target shape.
|
||||
"""
|
||||
# pad source shape to the same rank
|
||||
shape = append(self.shape, 1, up_to_rank=rank(target_shape))
|
||||
if shape == target_shape:
|
||||
return self
|
||||
|
||||
def _check_broadcast(s, t):
|
||||
if s != t and s != 1:
|
||||
raise ValueError(
|
||||
f"src_shape and target_shape must be the same when src_shape is not 1, but got {s} and {t}"
|
||||
)
|
||||
|
||||
transform_leaf(_check_broadcast, shape, target_shape)
|
||||
|
||||
# reshape to flatten N-D vector
|
||||
flat_shp = flatten_to_tuple(shape)
|
||||
temp_ty = ir.VectorType.get(list(flat_shp), self.dtype.mlir_type)
|
||||
temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip)
|
||||
|
||||
# broadcast to result N-D vector
|
||||
flat_tgt_shp = flatten_to_tuple(target_shape)
|
||||
temp_tgt_ty = ir.VectorType.get(list(flat_tgt_shp), self.dtype.mlir_type)
|
||||
temp_tgt_vect = vector.broadcast(temp_tgt_ty, temp_vect, loc=loc, ip=ip)
|
||||
|
||||
res_1d_ty = ir.VectorType.get([size(target_shape)], self.dtype.mlir_type) # type: ignore
|
||||
res_1d_vect = vector.shape_cast(res_1d_ty, temp_tgt_vect, loc=loc, ip=ip)
|
||||
|
||||
return TensorSSA(res_1d_vect, target_shape, self.dtype)
|
||||
|
||||
def __pow__(self, other, *, loc=None, ip=None) -> "TensorSSA":
|
||||
"""
|
||||
Returns the results of tensor^other.
|
||||
@@ -6093,6 +6261,16 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
"""
|
||||
return self._apply_op(operator.and_, other, flip=True, loc=loc, ip=ip)
|
||||
|
||||
def __neg__(self, *, loc=None, ip=None) -> "TensorSSA":
|
||||
"""
|
||||
Returns the negation of the tensor.
|
||||
|
||||
:return: The element-wise negation of the tensor
|
||||
:rtype: TensorSSA
|
||||
"""
|
||||
|
||||
return self._apply_op(operator.sub, 0, flip=True, loc=loc, ip=ip)
|
||||
|
||||
def _flatten_shape_and_coord(self, crd, *, loc=None, ip=None):
|
||||
# Coalesce and flatten source layout at terminal of coordinate
|
||||
# (N_0,(N_1,...), ...) -> (N_0,N_1,N_2,...)
|
||||
@@ -6158,17 +6336,13 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
if crd is None:
|
||||
return self
|
||||
|
||||
if not has_underscore(crd) or depth(crd) == 0:
|
||||
idx = crd2idx(crd, make_layout(self._shape))
|
||||
if is_static(idx):
|
||||
res = vector.extract(
|
||||
self, dynamic_position=[], static_position=[idx], loc=loc, ip=ip
|
||||
)
|
||||
else:
|
||||
res = vector.extract(
|
||||
self, dynamic_position=[crd], static_position=[], loc=loc, ip=ip
|
||||
)
|
||||
return self.dtype(res)
|
||||
if not has_underscore(crd):
|
||||
if self._layout is None:
|
||||
self._layout = make_layout(self._shape, loc=loc, ip=ip)
|
||||
idx = crd2idx(crd, self._layout, loc=loc, ip=ip)
|
||||
idx_val = as_numeric(idx).ir_value(loc=loc, ip=ip)
|
||||
res_val = vector.extractelement(self, position=idx_val, loc=loc, ip=ip)
|
||||
return self.dtype(res_val)
|
||||
|
||||
if not is_static(crd):
|
||||
raise ValueError("dynamic coordinate is not supported")
|
||||
@@ -6274,7 +6448,7 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
:type op: operator
|
||||
:param init_val: The initial value for the reduction
|
||||
:type init_val: numeric
|
||||
:param reduction_profile: Specifies which dimensions to reduce. Dimensions marked with '_' are kept.
|
||||
:param reduction_profile: Specifies which dimensions to reduce. Dimensions marked with `None` are kept.
|
||||
:type reduction_profile: Coord
|
||||
|
||||
:return: The reduced tensor
|
||||
@@ -6289,9 +6463,9 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
|
||||
reduce(f32 o (4, 5))
|
||||
=> f32
|
||||
reduce(f32 o (4, (5, 4)), reduction_profile=(_, 1))
|
||||
reduce(f32 o (4, (5, 4)), reduction_profile=(None, 1))
|
||||
=> f32 o (4,)
|
||||
reduce(f32 o (4, (5, 4)), reduction_profile=(_, (_, 1)))
|
||||
reduce(f32 o (4, (5, 4)), reduction_profile=(None, (None, 1)))
|
||||
=> f32 o (4, (5,))
|
||||
"""
|
||||
# short-cut to no-op
|
||||
@@ -6354,21 +6528,6 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
return self._build_result(res_vect, res_shp, loc=loc, ip=ip)
|
||||
|
||||
|
||||
def _get_attr_for_type(ty, value):
|
||||
if isinstance(ty, ir.IntegerType):
|
||||
return ir.IntegerAttr.get(ty, value.to(int))
|
||||
elif isinstance(ty, ir.FloatType):
|
||||
return ir.FloatAttr.get(ty, value.to(float))
|
||||
else:
|
||||
raise TypeError(f"unsupported type: {ty}")
|
||||
|
||||
|
||||
def _splat(res_ty, fill_value):
|
||||
elem_attr = _get_attr_for_type(res_ty.element_type, fill_value)
|
||||
vect_attr = ir.DenseElementsAttr.get_splat(res_ty, elem_attr)
|
||||
return arith.constant(res_ty, vect_attr)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> TensorSSA:
|
||||
"""
|
||||
@@ -6389,9 +6548,14 @@ def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> Tenso
|
||||
|
||||
if isinstance(fill_value, (ir.Value, int, float, bool)):
|
||||
fill_value = dtype(fill_value)
|
||||
elif isinstance(fill_value, Numeric):
|
||||
fill_value = fill_value.to(dtype, loc=loc, ip=ip)
|
||||
else:
|
||||
raise ValueError(f"Expected fill_value be numeric type, but got {fill_value}")
|
||||
|
||||
res_mlir_type = T.vector(size, dtype.mlir_type)
|
||||
return TensorSSA(_splat(res_mlir_type, fill_value), shape, dtype)
|
||||
res_ty = T.vector(size, dtype.mlir_type)
|
||||
res_val = vector.splat(res_ty, fill_value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)
|
||||
return TensorSSA(res_val, shape, dtype)
|
||||
|
||||
|
||||
def full_like(
|
||||
@@ -6547,7 +6711,7 @@ class struct:
|
||||
|
||||
**Usage:**
|
||||
|
||||
.. code-block::
|
||||
.. code-block:: python
|
||||
|
||||
# Supports base_dsl scalar int/float elements, array and nested struct:
|
||||
@cute.struct
|
||||
@@ -6661,7 +6825,8 @@ class struct:
|
||||
Initializes a new memory range.
|
||||
|
||||
:param dtype: The data type.
|
||||
:param size: The size of the memory range in bytes.
|
||||
:param size: Size of the memory range in bytes. A size of **0** is accepted, but in that
|
||||
case the range can only be used for its address (e.g. as a partition marker).
|
||||
:param base: The base address of the memory range.
|
||||
"""
|
||||
self._dtype = dtype
|
||||
@@ -6673,9 +6838,9 @@ class struct:
|
||||
Returns start pointer to the data in this memory range.
|
||||
|
||||
:return: A pointer to the start of the memory range.
|
||||
:raises AssertionError: If the size of the memory range is not greater than zero.
|
||||
:raises AssertionError: If the size of the memory range is negative.
|
||||
"""
|
||||
assert self._size > 0
|
||||
assert self._size >= 0
|
||||
return recast_ptr(self._base, dtype=self._dtype)
|
||||
|
||||
def get_tensor(self, layout, swizzle=None, dtype=None):
|
||||
@@ -6716,31 +6881,48 @@ class struct:
|
||||
|
||||
:param v: The object to align. Must be a struct, MemRange, or a scalar type.
|
||||
:param align: The alignment value to set.
|
||||
:return: A copy of the object with the specified alignment.
|
||||
:raises TypeError: If the object is not a struct, MemRange, or a scalar type.
|
||||
|
||||
:ivar _dtype: The data type to be aligned.
|
||||
:ivar _align: The alignment of the data type.
|
||||
"""
|
||||
|
||||
_dtype = None
|
||||
_align = None
|
||||
|
||||
def __new__(cls, name, bases, dct):
|
||||
return super().__new__(cls, name, bases, dct)
|
||||
|
||||
def __getitem__(cls, params) -> Any:
|
||||
if len(params) == 2:
|
||||
obj, align = params
|
||||
dtype, align = params
|
||||
assert align > 0
|
||||
else:
|
||||
raise TypeError("Invalid struct.Align Arguments")
|
||||
|
||||
# make a copy of type and mark alignment
|
||||
if struct._is_scalar_type(obj) or isinstance(
|
||||
obj, (struct, struct._MemRangeMeta)
|
||||
if not struct._is_scalar_type(dtype) and not isinstance(
|
||||
dtype, (struct, struct._MemRangeMeta)
|
||||
):
|
||||
new_obj = py_copy.copy(obj)
|
||||
setattr(new_obj, "_struct_alignment_", align)
|
||||
return new_obj
|
||||
else:
|
||||
raise TypeError(
|
||||
"align only can be applied to sturct/MemRange/base_dsl scalar"
|
||||
"align only can be applied to struct/MemRange/base_dsl scalar"
|
||||
)
|
||||
|
||||
# Create new class with alignment
|
||||
new_cls = type(
|
||||
f"struct.Align[{dtype.__name__}, {align}]",
|
||||
(struct.Align,),
|
||||
{"_dtype": dtype, "_align": align},
|
||||
)
|
||||
return new_cls
|
||||
|
||||
@property
|
||||
def dtype(cls):
|
||||
return cls._dtype
|
||||
|
||||
@property
|
||||
def align(cls):
|
||||
return cls._align
|
||||
|
||||
class Align(metaclass=_AlignMeta):
|
||||
"""
|
||||
Aligns the given type by `Align[T, alignment]`.
|
||||
@@ -6768,6 +6950,7 @@ class struct:
|
||||
:raises TypeError: If the struct is empty.
|
||||
"""
|
||||
self._cls = cls
|
||||
self.__name__ = f"struct::{cls.__name__}"
|
||||
# Get the class annotations
|
||||
self._annotations = cls.__annotations__
|
||||
# Create a dictionary to store the offsets
|
||||
@@ -6780,12 +6963,10 @@ class struct:
|
||||
raise TypeError("Empty struct is not supported!")
|
||||
for name, object in self._annotations.items():
|
||||
# get alignment of object
|
||||
def alignof(object, default: int = 1):
|
||||
return getattr(object, "_struct_alignment_", default)
|
||||
|
||||
# alignment for the next offset
|
||||
def align_offset(offset, align):
|
||||
return (offset + (align - 1)) & ~(align - 1)
|
||||
sub_align = 1
|
||||
if isinstance(object, struct._AlignMeta):
|
||||
sub_align = object.align
|
||||
object = object.dtype
|
||||
|
||||
# switch addition order to support dynamic size
|
||||
def add_offset(val):
|
||||
@@ -6793,35 +6974,37 @@ class struct:
|
||||
|
||||
# size of scalar
|
||||
if struct._is_scalar_type(object):
|
||||
dtype_size = object.width // 8
|
||||
sub_align = alignof(object, dtype_size)
|
||||
offset = align_offset(offset, sub_align)
|
||||
dtype_size = max(1, object.width // 8)
|
||||
sub_align = max(dtype_size, sub_align)
|
||||
offset = self.align_offset(offset, sub_align)
|
||||
self._offsets[name] = offset
|
||||
offset = add_offset(dtype_size)
|
||||
# size of array is size_in_bytes, alignment is elem_size
|
||||
elif isinstance(object, struct._MemRangeMeta):
|
||||
if object.size == 0:
|
||||
continue # skip empty array
|
||||
sub_align = alignof(object, max(1, object.elem_width // 8))
|
||||
offset = align_offset(offset, sub_align)
|
||||
# Allow empty array as a free marker-only struct member.
|
||||
# Use max(sub_align, ) because we might have in the future some
|
||||
# object.elem_width less than 8, such as fp4, bit and others,
|
||||
# and align_offset() does not support an alignment of 0.
|
||||
sub_align = max(object.elem_width // 8, sub_align)
|
||||
offset = self.align_offset(offset, sub_align)
|
||||
self._offsets[name] = offset
|
||||
offset = add_offset(object.size_in_bytes)
|
||||
# size of struct
|
||||
elif isinstance(object, struct):
|
||||
sub_align = max(object.__alignof__(), alignof(object))
|
||||
offset = align_offset(offset, sub_align)
|
||||
sub_align = max(object.__alignof__(), sub_align)
|
||||
offset = self.align_offset(offset, sub_align)
|
||||
self._offsets[name] = offset
|
||||
offset = add_offset(object.__sizeof__())
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Struct element only support sturct/array/base_dsl scalar, "
|
||||
f"Struct element only support struct/array/base_dsl scalar, "
|
||||
f"but got {object}"
|
||||
)
|
||||
# Total aligment determined by the strictest requirement
|
||||
alignment = max(alignment, sub_align)
|
||||
# Total size determined by alignment
|
||||
self._align_of = alignment
|
||||
self._size_of = align_offset(offset, alignment)
|
||||
self._size_of = self.align_offset(offset, alignment)
|
||||
|
||||
# create the __init__ method for decorated struct
|
||||
def __call__(self, base: Any) -> None:
|
||||
@@ -6840,6 +7023,8 @@ class struct:
|
||||
setattr(cls, "_base", base)
|
||||
for name, off in self._offsets.items():
|
||||
obj = self._annotations[name]
|
||||
if isinstance(obj, struct._AlignMeta):
|
||||
obj = obj.dtype
|
||||
if struct._is_scalar_type(obj):
|
||||
new_obj = recast_ptr(base + off, dtype=obj)
|
||||
setattr(cls, name, new_obj)
|
||||
@@ -6851,7 +7036,7 @@ class struct:
|
||||
setattr(cls, name, new_obj)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Struct element only support sturct/array/base_dsl scalar, "
|
||||
f"Struct element only support struct/array/base_dsl scalar, "
|
||||
f"but got {obj}"
|
||||
)
|
||||
return cls
|
||||
@@ -6872,3 +7057,14 @@ class struct:
|
||||
# get alignment
|
||||
def __alignof__(self) -> int:
|
||||
return self._align_of
|
||||
|
||||
# util func for aligning offset
|
||||
@staticmethod
|
||||
def align_offset(offset, align):
|
||||
"""
|
||||
Return the round-up offset up to the next multiple of align.
|
||||
"""
|
||||
assert align > 0 and not (
|
||||
align & (align - 1)
|
||||
), "align should be a strictly positive power of 2."
|
||||
return (offset + (align - 1)) & ~(align - 1)
|
||||
|
||||
@@ -10,16 +10,53 @@
|
||||
# is strictly prohibited.
|
||||
|
||||
from .core import TensorSSA
|
||||
from .typing import Numeric
|
||||
from cutlass._mlir.dialects import math, arith
|
||||
|
||||
from typing import Callable, Union
|
||||
|
||||
def acos(a: TensorSSA) -> TensorSSA:
|
||||
|
||||
def _math_op(func: Callable, fastmath: bool, *args, **kwargs):
|
||||
"""Dispatch the function to either a TensorSSA or a Numeric(Float).
|
||||
|
||||
:param func: The function to dispatch
|
||||
:param args: The input tensor or scalar
|
||||
:param kwargs: The input tensor or scalar
|
||||
"""
|
||||
arg_type = type(args[0])
|
||||
for arg in args:
|
||||
if not isinstance(arg, TensorSSA) and (
|
||||
not isinstance(arg, Numeric) or not type(arg).is_float
|
||||
):
|
||||
raise TypeError(
|
||||
f"Expected a TensorSSA or Numeric(Float), but got {type(arg)}"
|
||||
)
|
||||
if not isinstance(arg, arg_type):
|
||||
raise TypeError(
|
||||
f"Expected all inputs to be of type {arg_type}, but got {type(arg)}"
|
||||
)
|
||||
|
||||
fastmath_flag = arith.FastMathFlags.fast if fastmath else arith.FastMathFlags.none
|
||||
if isinstance(args[0], TensorSSA):
|
||||
return TensorSSA(
|
||||
func(*args, fastmath=fastmath_flag), args[0].shape, args[0].dtype
|
||||
)
|
||||
else:
|
||||
args = [a.ir_value() for a in args]
|
||||
return func(*args, fastmath=fastmath_flag)
|
||||
|
||||
|
||||
def acos(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise arc cosine of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the arc cosine of each element in input tensor
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -29,16 +66,20 @@ def acos(a: TensorSSA) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = acos(y) # Compute arc cosine
|
||||
"""
|
||||
return TensorSSA(math.acos(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.acos, fastmath, a)
|
||||
|
||||
|
||||
def asin(a: TensorSSA) -> TensorSSA:
|
||||
def asin(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise arc sine of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the arc sine of each element in input tensor
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -48,18 +89,20 @@ def asin(a: TensorSSA) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = asin(y) # Compute arc sine
|
||||
"""
|
||||
return TensorSSA(math.asin(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.asin, fastmath, a)
|
||||
|
||||
|
||||
def atan(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def atan(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise arc tangent of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the arc tangent of each element in input tensor
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -70,23 +113,25 @@ def atan(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
z = atan(y) # Compute arc tangent
|
||||
"""
|
||||
raise NotImplementedError("atan is not implemented")
|
||||
return TensorSSA(math.atan(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.atan, fastmath, a)
|
||||
|
||||
|
||||
def atan2(a: TensorSSA, b: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def atan2(
|
||||
a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise arc tangent of two tensors.
|
||||
|
||||
Computes atan2(a, b) element-wise. The function atan2(a, b) is the angle in radians
|
||||
between the positive x-axis and the point given by the coordinates (b, a).
|
||||
|
||||
:param a: First input tensor (y-coordinates)
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param b: Second input tensor (x-coordinates)
|
||||
:type b: TensorSSA
|
||||
:type b: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the arc tangent of a/b element-wise
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -96,20 +141,20 @@ def atan2(a: TensorSSA, b: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
x = cute.make_fragment(ptr2, layout).load() # x coordinates
|
||||
theta = atan2(y, x) # Compute angles
|
||||
"""
|
||||
return TensorSSA(
|
||||
math.atan2(a, b, fastmath=arith.FastMathFlags.none), a.shape, a.dtype
|
||||
)
|
||||
return _math_op(math.atan2, fastmath, a, b)
|
||||
|
||||
|
||||
def cos(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def cos(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise cosine of the input tensor.
|
||||
|
||||
:param a: Input tensor (in radians)
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the cosine of each element
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -119,21 +164,23 @@ def cos(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = cos(y) # Compute cosine
|
||||
"""
|
||||
return TensorSSA(math.cos(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.cos, fastmath, a)
|
||||
|
||||
|
||||
def erf(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def erf(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise error function of the input tensor.
|
||||
|
||||
The error function is defined as:
|
||||
erf(x) = 2/√π ∫[0 to x] exp(-t²) dt
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the error function value for each element
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -143,18 +190,43 @@ def erf(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = erf(y) # Compute error function
|
||||
"""
|
||||
return TensorSSA(math.erf(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.erf, fastmath, a)
|
||||
|
||||
|
||||
def exp2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def exp(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise exponential of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the exponential of each element
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
x = cute.make_fragment(layout) # Create tensor
|
||||
y = x.load() # Load values
|
||||
z = exp(y) # Compute exponential
|
||||
"""
|
||||
return _math_op(math.exp, fastmath, a)
|
||||
|
||||
|
||||
def exp2(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise base-2 exponential of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing 2 raised to the power of each element
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -164,18 +236,20 @@ def exp2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = exp2(y) # Compute 2^x
|
||||
"""
|
||||
return TensorSSA(math.exp2(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.exp2, fastmath, a)
|
||||
|
||||
|
||||
def log(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def log(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise natural logarithm of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the natural logarithm of each element
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -185,18 +259,20 @@ def log(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = log(y) # Compute natural logarithm
|
||||
"""
|
||||
return TensorSSA(math.log(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.log, fastmath, a)
|
||||
|
||||
|
||||
def log2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def log2(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise base-2 logarithm of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the base-2 logarithm of each element
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -206,18 +282,20 @@ def log2(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = log2(y) # Compute log base 2
|
||||
"""
|
||||
return TensorSSA(math.log2(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.log2, fastmath, a)
|
||||
|
||||
|
||||
def log10(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def log10(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise base-10 logarithm of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the base-10 logarithm of each element
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -227,20 +305,22 @@ def log10(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = log10(y) # Compute log base 10
|
||||
"""
|
||||
return TensorSSA(math.log10(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.log10, fastmath, a)
|
||||
|
||||
|
||||
def rsqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def rsqrt(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise reciprocal square root of the input tensor.
|
||||
|
||||
Computes 1/√x element-wise.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the reciprocal square root of each element
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -250,18 +330,20 @@ def rsqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = rsqrt(y) # Compute 1/√x
|
||||
"""
|
||||
return TensorSSA(math.rsqrt(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.rsqrt, fastmath, a)
|
||||
|
||||
|
||||
def sin(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def sin(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise sine of the input tensor.
|
||||
|
||||
:param a: Input tensor (in radians)
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the sine of each element
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -271,18 +353,20 @@ def sin(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = sin(y) # Compute sine
|
||||
"""
|
||||
return TensorSSA(math.sin(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.sin, fastmath, a)
|
||||
|
||||
|
||||
def sqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def sqrt(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise square root of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the square root of each element
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -292,16 +376,20 @@ def sqrt(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = sqrt(y) # Compute square root
|
||||
"""
|
||||
return TensorSSA(math.sqrt(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.sqrt, fastmath, a)
|
||||
|
||||
|
||||
def tan(a: TensorSSA) -> TensorSSA:
|
||||
def tan(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise tangent of the input tensor.
|
||||
|
||||
:param a: Input tensor (in radians)
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the tangent of each element
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -311,18 +399,20 @@ def tan(a: TensorSSA) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = tan(y) # Compute tangent
|
||||
"""
|
||||
return TensorSSA(math.tan(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.tan, fastmath, a)
|
||||
|
||||
|
||||
def tanh(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
def tanh(
|
||||
a: Union[TensorSSA, Numeric], fastmath: bool = False
|
||||
) -> Union[TensorSSA, Numeric]:
|
||||
"""Compute element-wise hyperbolic tangent of the input tensor.
|
||||
|
||||
:param a: Input tensor
|
||||
:type a: TensorSSA
|
||||
:type a: Union[TensorSSA, Numeric]
|
||||
:param fastmath: Enable fast math optimizations, defaults to False
|
||||
:type fastmath: bool, optional
|
||||
:return: Tensor containing the hyperbolic tangent of each element
|
||||
:rtype: TensorSSA
|
||||
:rtype: Union[TensorSSA, Numeric]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -332,7 +422,7 @@ def tanh(a: TensorSSA, fastmath: bool = False) -> TensorSSA:
|
||||
y = x.load() # Load values
|
||||
z = tanh(y) # Compute hyperbolic tangent
|
||||
"""
|
||||
return TensorSSA(math.tanh(a, fastmath=arith.FastMathFlags.none), a.shape, a.dtype)
|
||||
return _math_op(math.tanh, fastmath, a)
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -342,6 +432,7 @@ __all__ = [
|
||||
"atan2",
|
||||
"cos",
|
||||
"erf",
|
||||
"exp",
|
||||
"exp2",
|
||||
"log",
|
||||
"log10",
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Type, Optional
|
||||
|
||||
@@ -101,6 +101,42 @@ class MmaUniversalTrait(core.Trait):
|
||||
####################################################################################################
|
||||
|
||||
|
||||
class MemoryOrder(enum.Enum):
|
||||
WEAK = _cute_ir.MemOrderKind.WEAK
|
||||
RELAXED = _cute_ir.MemOrderKind.RELAXED
|
||||
ACQUIRE = _cute_ir.MemOrderKind.ACQUIRE
|
||||
RELEASE = _cute_ir.MemOrderKind.RELEASE
|
||||
ACQ_REL = _cute_ir.MemOrderKind.ACQ_REL
|
||||
SC = _cute_ir.MemOrderKind.SC
|
||||
MMIO = _cute_ir.MemOrderKind.MMIO
|
||||
CONSTANT = _cute_ir.MemOrderKind.CONSTANT
|
||||
VOLATILE = _cute_ir.MemOrderKind.VOLATILE
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
def _to_ir(self) -> _cute_ir.MemOrderKind:
|
||||
return self.value
|
||||
|
||||
|
||||
class MemoryScope(enum.Enum):
|
||||
CTA = _cute_ir.MemScopeKind.CTA
|
||||
CLUSTER = _cute_ir.MemScopeKind.CLUSTER
|
||||
GPU = _cute_ir.MemScopeKind.GPU
|
||||
SYS = _cute_ir.MemScopeKind.SYS
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
def _to_ir(self) -> _cute_ir.MemScopeKind:
|
||||
return self.value
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyUniversalOp(core.CopyOp):
|
||||
"""
|
||||
@@ -133,13 +169,18 @@ class CopyUniversalOp(core.CopyOp):
|
||||
**kwargs,
|
||||
) -> "CopyUniversalTrait":
|
||||
num_bits_per_copy = kwargs.get("num_bits_per_copy", 0)
|
||||
memory_order = kwargs.get("memory_order", MemoryOrder.WEAK)
|
||||
memory_scope = kwargs.get("memory_scope", MemoryScope.CTA)
|
||||
if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0):
|
||||
raise ValueError(
|
||||
"expects a 'num_bits_per_copy' kw argument of type int that is non-negative "
|
||||
f"when creating a copy Atom for {self.__class__.__name__}"
|
||||
)
|
||||
ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get(
|
||||
copy_internal_type.mlir_type, num_bits_per_copy
|
||||
copy_internal_type.mlir_type,
|
||||
num_bits_per_copy,
|
||||
memory_order._to_ir(),
|
||||
memory_scope._to_ir(),
|
||||
)
|
||||
return CopyUniversalTrait(_cute_ir.atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ __all__ = [
|
||||
"CopyBulkTensorTileG2SOp",
|
||||
"CopyBulkTensorTileG2SMulticastOp",
|
||||
"CopyBulkTensorTileS2GOp",
|
||||
"CopyReduceBulkTensorTileS2GOp",
|
||||
#
|
||||
# helpers.py
|
||||
#
|
||||
|
||||
@@ -19,7 +19,7 @@ import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
|
||||
from ...core import CopyOp, Trait
|
||||
from ...core import CopyOp, Trait, ReductionOp
|
||||
from ...typing import Int16, Pointer, Integer, Numeric
|
||||
from ..common import OpError
|
||||
from ..tcgen05.mma import CtaGroup
|
||||
@@ -80,6 +80,12 @@ class CopyG2SOp(CopyOp):
|
||||
**kwargs,
|
||||
) -> "CopyG2STrait":
|
||||
num_bits_per_copy = kwargs.get("num_bits_per_copy", None)
|
||||
# Verify that the user provided enum values
|
||||
if not isinstance(self.cache_mode, LoadCacheMode):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'cache_mode' Op parameter to be a LoadCacheMode instance",
|
||||
)
|
||||
if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0):
|
||||
raise ValueError(
|
||||
"expects a 'num_bits_per_copy' kw argument of type int that is positive "
|
||||
@@ -330,7 +336,7 @@ class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait):
|
||||
@dataclass(frozen=True)
|
||||
class CopyBulkTensorTileS2GOp(CopyOp):
|
||||
"""
|
||||
Bulk tensor asynchrnous SMEM to GMEM Copy Operation using the TMA unit.
|
||||
Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor>`__.
|
||||
This Operation uses TMA in the ``.tile`` mode.
|
||||
@@ -379,3 +385,87 @@ class CopyBulkTensorTileS2GTrait(Trait):
|
||||
exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
|
||||
)
|
||||
return exec_value
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyReduceBulkTensorTileS2GOp(CopyOp):
|
||||
"""
|
||||
Bulk tensor asynchronous SMEM to GMEM Reduction Operation using the TMA unit.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk>`__.
|
||||
This Operation uses TMA in the ``.tile`` mode.
|
||||
"""
|
||||
|
||||
reduction_kind: ReductionOp = ReductionOp.ADD
|
||||
|
||||
admissible_archs = [
|
||||
"sm_90",
|
||||
"sm_90a",
|
||||
"sm_100a",
|
||||
"sm_100f",
|
||||
]
|
||||
|
||||
def __post__init__(self):
|
||||
# Arch verification
|
||||
arch = CuTeDSL.__get_dsl().envar.arch
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "cp.async SMEM -> GMEM bulk tensor reduction Operation"
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "CopyReduceBulkTensorTileS2GTrait":
|
||||
raise NotImplementedError(
|
||||
"Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA"
|
||||
)
|
||||
|
||||
def _to_ir(self) -> _cute_nvgpu_ir.ReductionKind:
|
||||
if self.reduction_kind == ReductionOp.ADD:
|
||||
return _cute_nvgpu_ir.ReductionKind.ADD
|
||||
elif self.reduction_kind == ReductionOp.MIN:
|
||||
return _cute_nvgpu_ir.ReductionKind.MIN
|
||||
elif self.reduction_kind == ReductionOp.MAX:
|
||||
return _cute_nvgpu_ir.ReductionKind.MAX
|
||||
elif self.reduction_kind == ReductionOp.INC:
|
||||
return _cute_nvgpu_ir.ReductionKind.INC
|
||||
elif self.reduction_kind == ReductionOp.DEC:
|
||||
return _cute_nvgpu_ir.ReductionKind.DEC
|
||||
elif self.reduction_kind == ReductionOp.AND:
|
||||
return _cute_nvgpu_ir.ReductionKind.AND
|
||||
elif self.reduction_kind == ReductionOp.OR:
|
||||
return _cute_nvgpu_ir.ReductionKind.OR
|
||||
elif self.reduction_kind == ReductionOp.XOR:
|
||||
return _cute_nvgpu_ir.ReductionKind.XOR
|
||||
else:
|
||||
assert False, "unrecognized self.reduction_kind"
|
||||
|
||||
|
||||
class CopyReduceBulkTensorTileS2GTrait(Trait):
|
||||
def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None):
|
||||
"""
|
||||
Custom implementation of unpack for non-executable TMAs.
|
||||
"""
|
||||
exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip)
|
||||
if isinstance(tma_desc_ptr, Pointer):
|
||||
attr_str = (
|
||||
f"#cute_nvgpu.atom_copy_field_tmareduce<{TMA_DESC_PTR_FIELD_NAME}>"
|
||||
)
|
||||
attr = ir.Attribute.parse(attr_str)
|
||||
exec_value = _cute_nvgpu_ir.atom_set_value(
|
||||
exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip
|
||||
)
|
||||
return exec_value
|
||||
|
||||
__all__ = [
|
||||
"LoadCacheMode",
|
||||
"CopyG2SOp",
|
||||
"CopyBulkTensorTileG2SOp",
|
||||
"CopyBulkTensorTileG2SMulticastOp",
|
||||
"CopyBulkTensorTileS2GOp",
|
||||
"CopyReduceBulkTensorTileS2GOp",
|
||||
]
|
||||
|
||||
@@ -22,9 +22,11 @@ from .copy import (
|
||||
CopyBulkTensorTileG2SOp,
|
||||
CopyBulkTensorTileG2SMulticastOp,
|
||||
CopyBulkTensorTileS2GOp,
|
||||
CopyReduceBulkTensorTileS2GOp,
|
||||
CopyBulkTensorTileG2SNonExecTrait,
|
||||
CopyBulkTensorTileG2SMulticastNonExecTrait,
|
||||
CopyBulkTensorTileS2GTrait,
|
||||
CopyReduceBulkTensorTileS2GTrait,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,6 +36,7 @@ def make_tiled_tma_atom(
|
||||
CopyBulkTensorTileG2SOp,
|
||||
CopyBulkTensorTileG2SMulticastOp,
|
||||
CopyBulkTensorTileS2GOp,
|
||||
CopyReduceBulkTensorTileS2GOp,
|
||||
],
|
||||
gmem_tensor: Tensor,
|
||||
smem_layout: Union[Layout, core.ComposedLayout],
|
||||
@@ -67,7 +70,7 @@ def make_tiled_tma_atom(
|
||||
similarly to any other CuTe tensors using the algebra.
|
||||
|
||||
:param op: The Copy Operation to construct an Atom for
|
||||
:type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp]
|
||||
:type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, CopyReduceBulkTensorTileS2GOp]
|
||||
:param gmem_tensor: The GMEM tensor involved in the Copy
|
||||
:type gmem_tensor: Tensor
|
||||
:param smem_layout: The SMEM layout to construct the Copy Atom for
|
||||
@@ -141,6 +144,17 @@ def make_tiled_tma_atom(
|
||||
ip=ip,
|
||||
)
|
||||
return core.CopyAtom(op, CopyBulkTensorTileS2GTrait(res[0])), res[1]
|
||||
elif isinstance(op, CopyReduceBulkTensorTileS2GOp):
|
||||
res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_reduce(
|
||||
gmem_tensor.value,
|
||||
smem_layout,
|
||||
cta_v_map,
|
||||
op._to_ir(),
|
||||
internal_type=internal_type,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return core.CopyAtom(op, CopyReduceBulkTensorTileS2GTrait(res[0])), res[1]
|
||||
else:
|
||||
raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}")
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from cutlass._mlir import ir
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
|
||||
from cutlass.base_dsl.dsl import is_dynamic_expression
|
||||
from cutlass.cutlass_dsl import TensorFormat, JitArgAdapterRegistry
|
||||
from cutlass.cutlass_dsl import JitArgAdapterRegistry
|
||||
|
||||
# Local modules imports
|
||||
from .typing import (
|
||||
@@ -82,42 +82,36 @@ class _Pointer(Pointer):
|
||||
self._dtype = dtype
|
||||
self._addr_space = mem_space
|
||||
|
||||
is_in_device = mem_space == _cute_ir.AddressSpace.gmem
|
||||
if assumed_align is None:
|
||||
if is_in_device:
|
||||
self._assumed_align = 32
|
||||
else:
|
||||
self._assumed_align = dtype.width // 8
|
||||
self._assumed_align = dtype.width // 8
|
||||
else:
|
||||
self._assumed_align = assumed_align
|
||||
|
||||
class PtrDescriptor(ctypes.Structure):
|
||||
"""A ctype descriptor for CuTe memref ptr"""
|
||||
|
||||
_fields_ = [("ptr", ctypes.c_void_p)]
|
||||
|
||||
def __str__(self):
|
||||
return f"0x{self.ptr:016x}"
|
||||
|
||||
self._desc = PtrDescriptor(int(self._pointer))
|
||||
self._c_pointer = ctypes.cast(ctypes.pointer(self._desc), ctypes.c_void_p)
|
||||
self._c_pointer = None
|
||||
assert (
|
||||
self._desc.ptr % self._assumed_align == 0
|
||||
int(self._pointer) % self._assumed_align == 0
|
||||
), f"pointer must be {self._assumed_align} bytes aligned"
|
||||
|
||||
def size_in_bytes(self) -> int:
|
||||
self._desc = ctypes.c_void_p(int(self._pointer))
|
||||
return ctypes.sizeof(self._desc)
|
||||
|
||||
def __get_mlir_types__(self):
|
||||
return [self.mlir_type]
|
||||
|
||||
def __c_pointers__(self):
|
||||
if self._c_pointer is None:
|
||||
self._desc = ctypes.c_void_p(int(self._pointer))
|
||||
self._c_pointer = ctypes.addressof(self._desc)
|
||||
return [self._c_pointer]
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
assert len(values) == 1
|
||||
return values[0]
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
return [self._c_pointer]
|
||||
|
||||
# Move mlir Type out of __init__ to decouple with mlir Context
|
||||
@property
|
||||
def mlir_type(self) -> ir.Type:
|
||||
@@ -145,7 +139,7 @@ class _Pointer(Pointer):
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Ptr<0x{self._desc.ptr:016x}@{self._addr_space}>"
|
||||
return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
@@ -31,9 +31,12 @@ from .helpers import (
|
||||
|
||||
from .sm90 import (
|
||||
PipelineAsync,
|
||||
PipelineCpAsync,
|
||||
PipelineTmaAsync,
|
||||
PipelineTmaMultiConsumersAsync,
|
||||
PipelineTmaStore,
|
||||
PipelineProducer,
|
||||
PipelineConsumer,
|
||||
)
|
||||
|
||||
from .sm100 import (
|
||||
@@ -53,10 +56,13 @@ __all__ = [
|
||||
"PipelineUserType",
|
||||
"PipelineState",
|
||||
"PipelineAsync",
|
||||
"PipelineCpAsync",
|
||||
"PipelineTmaAsync",
|
||||
"PipelineTmaUmma",
|
||||
"PipelineTmaMultiConsumersAsync",
|
||||
"PipelineAsyncUmma",
|
||||
"PipelineUmmaAsync",
|
||||
"PipelineTmaStore",
|
||||
"PipelineProducer",
|
||||
"PipelineConsumer",
|
||||
]
|
||||
|
||||
@@ -89,6 +89,8 @@ class PipelineOp(enum.Enum):
|
||||
TmaStore = enum.auto()
|
||||
# Composite of multiple PipelineOps
|
||||
Composite = enum.auto()
|
||||
# Async load without TMA
|
||||
AsyncLoad = enum.auto()
|
||||
|
||||
|
||||
def _get_pipeline_op(type_str):
|
||||
@@ -226,6 +228,8 @@ class MbarrierArray(SyncObject):
|
||||
self.arrive_tcgen05mma(index, dst, cta_group)
|
||||
elif self.op_type in [PipelineOp.TmaLoad]:
|
||||
self.arrive_and_expect_tx(index, self.tx_count)
|
||||
elif self.op_type is PipelineOp.AsyncLoad:
|
||||
self.arrive_cp_async_mbarrier(index)
|
||||
else:
|
||||
assert (
|
||||
False
|
||||
@@ -237,6 +241,9 @@ class MbarrierArray(SyncObject):
|
||||
else:
|
||||
cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank)
|
||||
|
||||
def arrive_cp_async_mbarrier(self, index: int):
|
||||
cute.arch.cp_async_mbarrier_arrive_noinc(self.get_barrier(index))
|
||||
|
||||
def arrive_tcgen05mma(
|
||||
self, index: int, mask: Optional[int], cta_group: cute.nvgpu.tcgen05.CtaGroup
|
||||
) -> None:
|
||||
|
||||
@@ -19,6 +19,7 @@ import cutlass.cute as cute
|
||||
from cutlass.cutlass_dsl import Boolean, if_generate
|
||||
|
||||
from cutlass.pipeline import (
|
||||
Agent,
|
||||
CooperativeGroup,
|
||||
PipelineOp,
|
||||
PipelineState,
|
||||
@@ -106,9 +107,9 @@ class PipelineTmaUmma(PipelineAsync):
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:param producer_group: `CooperativeGroup` for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group: CooperativeGroup for the consumer agent
|
||||
:param consumer_group: `CooperativeGroup` for the consumer agent
|
||||
:type consumer_group: CooperativeGroup
|
||||
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
|
||||
:type tx_count: int
|
||||
@@ -258,9 +259,9 @@ class PipelineAsyncUmma(PipelineAsync):
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:param producer_group: `CooperativeGroup` for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group: CooperativeGroup for the consumer agent
|
||||
:param consumer_group: `CooperativeGroup` for the consumer agent
|
||||
:type consumer_group: CooperativeGroup
|
||||
:param cta_layout_vmnk: Layout of the cluster shape
|
||||
:type cta_layout_vmnk: cute.Layout | None
|
||||
@@ -368,9 +369,9 @@ class PipelineUmmaAsync(PipelineAsync):
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:param producer_group: `CooperativeGroup` for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group: CooperativeGroup for the consumer agent
|
||||
:param consumer_group: `CooperativeGroup` for the consumer agent
|
||||
:type consumer_group: CooperativeGroup
|
||||
:param cta_layout_vmnk: Layout of the cluster shape
|
||||
:type cta_layout_vmnk: cute.Layout | None
|
||||
|
||||
@@ -10,15 +10,18 @@
|
||||
# is strictly prohibited.
|
||||
|
||||
import enum
|
||||
from typing import Type, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
import warnings
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cutlass_dsl import Boolean, Int32, if_generate
|
||||
|
||||
from cutlass.pipeline import (
|
||||
Agent,
|
||||
CooperativeGroup,
|
||||
PipelineOp,
|
||||
SyncObject,
|
||||
@@ -91,6 +94,30 @@ class PipelineAsync:
|
||||
- D: Data ready (producer has written data to buffer)
|
||||
- R: Consumer reading (consumer is consuming data from buffer)
|
||||
|
||||
**Example:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Create pipeline with 5 stages
|
||||
pipeline = PipelineAsync.create(
|
||||
num_stages=5, # number of pipeline stages
|
||||
producer_group=producer_warp,
|
||||
consumer_group=consumer_warp
|
||||
barrier_storage=smem_ptr, # smem pointer for array of mbarriers in shared memory
|
||||
)
|
||||
|
||||
producer, consumer = pipeline.make_participants()
|
||||
# Producer side
|
||||
for i in range(num_iterations):
|
||||
handle = producer.acquire_and_advance() # Wait for buffer to be empty & Move index to next stage
|
||||
# Write data to pipeline buffer
|
||||
handle.commit() # Signal buffer is full
|
||||
|
||||
# Consumer side
|
||||
for i in range(num_iterations):
|
||||
handle = consumer.wait_and_advance() # Wait for buffer to be full & Move index to next stage
|
||||
# Read data from pipeline buffer
|
||||
handle.release() # Signal buffer is empty
|
||||
"""
|
||||
|
||||
sync_object_full: SyncObject
|
||||
@@ -114,6 +141,7 @@ class PipelineAsync:
|
||||
PipelineOp.TmaLoad,
|
||||
PipelineOp.TCGen05Mma,
|
||||
PipelineOp.Composite,
|
||||
PipelineOp.AsyncLoad,
|
||||
]:
|
||||
return MbarrierArray(
|
||||
barrier_storage=barrier_storage,
|
||||
@@ -232,6 +260,74 @@ class PipelineAsync:
|
||||
state.advance()
|
||||
self.producer_acquire(state)
|
||||
|
||||
# Util methods to manage produer and consumer
|
||||
def make_producer(self):
|
||||
state = make_pipeline_state(PipelineUserType.Producer, self.num_stages)
|
||||
return PipelineProducer(self, state, self.sync_object_full.cg)
|
||||
|
||||
def make_consumer(self):
|
||||
state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages)
|
||||
return PipelineConsumer(self, state, self.sync_object_empty.cg)
|
||||
|
||||
def make_participants(self):
|
||||
return self.make_producer(), self.make_consumer()
|
||||
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineCpAsync(PipelineAsync):
|
||||
"""
|
||||
PipelineCpAsync is used for CpAsync producers and AsyncThread consumers (e.g. Hopper non-TMA mainloops).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
barrier_storage: cute.Pointer,
|
||||
num_stages: Int32,
|
||||
producer_group: CooperativeGroup,
|
||||
consumer_group: CooperativeGroup,
|
||||
producer_mask: Int32 = None,
|
||||
consumer_mask: Int32 = None,
|
||||
):
|
||||
"""
|
||||
This helper function computes any necessary attributes and returns an instance of PipelineAsync.
|
||||
:param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group: CooperativeGroup for the consumer agent
|
||||
:type consumer_group: CooperativeGroup
|
||||
:param producer_mask: Mask for signaling arrives for the producer agent
|
||||
:type producer_mask: Int32 | None
|
||||
:param consumer_mask: Mask for signaling arrives for the consumer agent
|
||||
:type consumer_mask: Int32 | None
|
||||
"""
|
||||
producer_type = PipelineOp.AsyncLoad
|
||||
consumer_type = PipelineOp.AsyncThread
|
||||
|
||||
producer = (producer_type, producer_group)
|
||||
consumer = (consumer_type, consumer_group)
|
||||
|
||||
sync_object_array_full = PipelineCpAsync._make_sync_object(
|
||||
barrier_storage.align(min_align=8), num_stages, producer
|
||||
)
|
||||
sync_object_array_empty = PipelineCpAsync._make_sync_object(
|
||||
barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
|
||||
)
|
||||
|
||||
pipeline_init_wait()
|
||||
|
||||
return PipelineCpAsync(
|
||||
sync_object_array_full,
|
||||
sync_object_array_empty,
|
||||
num_stages,
|
||||
producer_mask,
|
||||
consumer_mask,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineTmaAsync(PipelineAsync):
|
||||
"""
|
||||
@@ -294,9 +390,9 @@ class PipelineTmaAsync(PipelineAsync):
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:param producer_group: `CooperativeGroup` for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group: CooperativeGroup for the consumer agent
|
||||
:param consumer_group: `CooperativeGroup` for the consumer agent
|
||||
:type consumer_group: CooperativeGroup
|
||||
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
|
||||
:type tx_count: int
|
||||
@@ -404,11 +500,11 @@ class PipelineTmaMultiConsumersAsync(PipelineAsync):
|
||||
:type barrier_storage: cute.Pointer
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:param producer_group: `CooperativeGroup` for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
:param consumer_group_umma: CooperativeGroup for the UMMA consumer agent
|
||||
:param consumer_group_umma: `CooperativeGroup` for the UMMA consumer agent
|
||||
:type consumer_group_umma: CooperativeGroup
|
||||
:param consumer_group_async: CooperativeGroup for the AsyncThread consumer agent
|
||||
:param consumer_group_async: `CooperativeGroup` for the AsyncThread consumer agent
|
||||
:type consumer_group_async: CooperativeGroup
|
||||
:param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
|
||||
:type tx_count: int
|
||||
@@ -529,9 +625,10 @@ class PipelineTmaStore(PipelineAsync):
|
||||
This helper function computes any necessary attributes and returns an instance of PipelineTmaStore.
|
||||
:param num_stages: Number of buffer stages for this pipeline
|
||||
:type num_stages: Int32
|
||||
:param producer_group: CooperativeGroup for the producer agent
|
||||
:param producer_group: `CooperativeGroup` for the producer agent
|
||||
:type producer_group: CooperativeGroup
|
||||
"""
|
||||
|
||||
producer_type = PipelineOp.TmaStore
|
||||
|
||||
producer = (producer_type, producer_group)
|
||||
@@ -556,3 +653,333 @@ class PipelineTmaStore(PipelineAsync):
|
||||
self.sync_object_full.tail()
|
||||
|
||||
|
||||
#################################################################
|
||||
# Utilities to help user of pipeline to simplify the workflow
|
||||
#################################################################
|
||||
|
||||
|
||||
class ImmutableResourceHandle:
|
||||
__origin: PipelineAsync
|
||||
__immutable_state: PipelineState
|
||||
|
||||
def __init__(self, origin: PipelineAsync, immutable_state: PipelineState):
|
||||
self.__origin = origin
|
||||
self.__immutable_state = immutable_state
|
||||
|
||||
@property
|
||||
def index(self):
|
||||
"""Get the index of the current pipeline stage."""
|
||||
return self.__immutable_state.index
|
||||
|
||||
@property
|
||||
def count(self):
|
||||
"""Get the count of how many handles this producer has committed.
|
||||
This is useful for tracking the number of blocks that have been loaded from gmem.
|
||||
"""
|
||||
return self.__immutable_state.count
|
||||
|
||||
def get_origin(self):
|
||||
"""Get the original pipeline this resource handle belongs to."""
|
||||
return self.__origin
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
"""Extract MLIR values from the current state.
|
||||
|
||||
:return: List of MLIR values representing the current state
|
||||
:rtype: list
|
||||
"""
|
||||
# TODO: need to handle pipeline as well
|
||||
return self.__immutable_state.__extract_mlir_values__()
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
"""Create a new Producer instance from MLIR values.
|
||||
|
||||
:param values: MLIR values to initialize the state
|
||||
:type values: Any
|
||||
:return: New Producer instance with state initialized from values
|
||||
:rtype: Producer
|
||||
"""
|
||||
return self.__class__(
|
||||
self.__origin, self.__immutable_state.__new_from_mlir_values__(values)
|
||||
)
|
||||
|
||||
class PipelineProducer:
|
||||
"""A class representing a producer in an asynchronous pipeline.
|
||||
|
||||
The Producer class manages the producer side of an asynchronous pipeline, handling
|
||||
synchronization and state management for producing data. It provides methods for
|
||||
acquiring, committing, and advancing through pipeline stages.
|
||||
|
||||
:ivar __pipeline: The asynchronous pipeline this producer belongs to
|
||||
:type __pipeline: PipelineAsync
|
||||
:ivar __state: The current state of the producer in the pipeline
|
||||
:type __state: PipelineState
|
||||
:ivar __group: The cooperative group this producer operates in
|
||||
:type __group: CooperativeGroup
|
||||
|
||||
**Examples:**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
pipeline = PipelineAsync.create(...)
|
||||
producer = pipeline.create_producer(producer_group, stages)
|
||||
for i in range(iterations):
|
||||
handle = producer.acquire_and_advance() # Wait for buffer to be empty
|
||||
# Produce data
|
||||
producer.commit(handle) # Signal data is ready
|
||||
# An alternative way to do this is:
|
||||
# handle.commit() # Signal data is ready
|
||||
"""
|
||||
|
||||
__pipeline: PipelineAsync
|
||||
__state: PipelineState
|
||||
__group: CooperativeGroup
|
||||
|
||||
class ImmutableResourceHandle(ImmutableResourceHandle):
|
||||
@property
|
||||
def barrier(self):
|
||||
"""Get the barrier pointer for the current pipeline stage.
|
||||
|
||||
:return: Pointer to the barrier for the current stage
|
||||
:rtype: cute.Pointer
|
||||
"""
|
||||
return self.get_origin().producer_get_barrier(
|
||||
self._ImmutableResourceHandle__immutable_state
|
||||
)
|
||||
|
||||
def commit(self):
|
||||
"""Signal that data production is complete for the current stage.
|
||||
This allows consumers to start processing the data.
|
||||
"""
|
||||
self.get_origin().producer_commit(
|
||||
self._ImmutableResourceHandle__immutable_state
|
||||
)
|
||||
|
||||
def __init__(self, pipeline, state, group: CooperativeGroup):
|
||||
"""Initialize a new Producer instance.
|
||||
|
||||
:param pipeline: The pipeline this producer belongs to
|
||||
:type pipeline: PipelineAsync
|
||||
:param state: Initial pipeline state
|
||||
:type state: PipelineState
|
||||
:param group: The cooperative group for synchronization
|
||||
:type group: CooperativeGroup
|
||||
"""
|
||||
self.__pipeline = pipeline
|
||||
self.__state = state
|
||||
self.__group = group
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
try_acquire_token: Optional[Boolean] = None,
|
||||
) -> ImmutableResourceHandle:
|
||||
"""Wait for the current buffer to be empty before producing data.
|
||||
This is a blocking operation.
|
||||
|
||||
:param try_acquire_token: Optional token to try to acquire the buffer
|
||||
:type try_acquire_token: Optional[Boolean]
|
||||
:return: A handle to the producer for committing the data
|
||||
:rtype: ImmutableResourceHandle
|
||||
"""
|
||||
self.__pipeline.producer_acquire(self.__state, try_acquire_token)
|
||||
handle = PipelineProducer.ImmutableResourceHandle(
|
||||
self.__pipeline, self.__state.clone()
|
||||
)
|
||||
return handle
|
||||
|
||||
def advance(self):
|
||||
"""Move to the next pipeline stage."""
|
||||
self.__state.advance()
|
||||
|
||||
def acquire_and_advance(
|
||||
self, try_acquire_token: Optional[Boolean] = None
|
||||
) -> ImmutableResourceHandle:
|
||||
"""Wait for the current buffer to be empty before producing data.
|
||||
Then advance to the next stage.
|
||||
This is a blocking operation.
|
||||
|
||||
:param try_acquire_token: Optional token to try to acquire the buffer
|
||||
:type try_acquire_token: Optional[Boolean]
|
||||
:return: A handle to the producer for committing the data
|
||||
:rtype: ImmutableResourceHandle
|
||||
"""
|
||||
handle = self.acquire(try_acquire_token)
|
||||
self.advance()
|
||||
return handle
|
||||
|
||||
def try_acquire(self) -> Boolean:
|
||||
"""Try to acquire the current buffer without blocking.
|
||||
|
||||
:return: True if acquisition was successful, False otherwise
|
||||
:rtype: Boolean
|
||||
"""
|
||||
return self.__pipeline.producer_try_acquire(self.__state)
|
||||
|
||||
def commit(self, handle: Optional[ImmutableResourceHandle] = None):
|
||||
"""Signal that data production is complete for the current stage.
|
||||
This allows consumers to start processing the data.
|
||||
"""
|
||||
if handle is not None:
|
||||
assert (
|
||||
handle.get_origin() is self
|
||||
), "ResourceHandle does not belong to this PipelineProducer instance"
|
||||
handle.commit()
|
||||
else:
|
||||
self.__pipeline.producer_commit(self.__state)
|
||||
|
||||
def tail(self):
|
||||
"""Ensure all used buffers are properly synchronized before producer exit.
|
||||
This should be called before the producer finishes to avoid dangling signals.
|
||||
"""
|
||||
self.__pipeline.producer_tail(self.__state)
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
"""Extract MLIR values from the current state.
|
||||
|
||||
:return: List of MLIR values representing the current state
|
||||
:rtype: list
|
||||
"""
|
||||
# TODO: need to handle pipeline as well
|
||||
return self.__state.__extract_mlir_values__()
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
"""Create a new Producer instance from MLIR values.
|
||||
|
||||
:param values: MLIR values to initialize the state
|
||||
:type values: Any
|
||||
:return: New Producer instance with state initialized from values
|
||||
:rtype: Producer
|
||||
"""
|
||||
return PipelineProducer(
|
||||
self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group
|
||||
)
|
||||
|
||||
class PipelineConsumer:
|
||||
"""A class representing a consumer in an asynchronous pipeline.
|
||||
|
||||
The Consumer class manages the consumer side of an asynchronous pipeline, handling
|
||||
synchronization and state management for consuming data. It provides methods for
|
||||
waiting, releasing, and advancing through pipeline stages.
|
||||
|
||||
:ivar __pipeline: The asynchronous pipeline this consumer belongs to
|
||||
:type __pipeline: PipelineAsync
|
||||
:ivar __state: The current state of the consumer in the pipeline
|
||||
:type __state: PipelineState
|
||||
:ivar __group: The cooperative group this consumer operates in
|
||||
:type __group: CooperativeGroup
|
||||
|
||||
**Examples:**
|
||||
.. code-block:: python
|
||||
|
||||
pipeline = PipelineAsync.create(...)
|
||||
consumer = pipeline.create_consumer(consumer_group, stages)
|
||||
for i in range(iterations):
|
||||
handle = consumer.wait_and_advance() # Wait for data to be ready
|
||||
# Consume data
|
||||
consumer.release(handle) # Signal buffer is empty
|
||||
# An alternative way to do this is:
|
||||
# handle.release() # Signal buffer is empty
|
||||
"""
|
||||
|
||||
__pipeline: PipelineAsync
|
||||
__state: PipelineState
|
||||
__group: CooperativeGroup
|
||||
|
||||
class ImmutableResourceHandle(ImmutableResourceHandle):
|
||||
def release(self):
|
||||
"""Signal that data production is complete for the current stage.
|
||||
This allows consumers to start processing the data.
|
||||
"""
|
||||
self.get_origin().consumer_release(
|
||||
self._ImmutableResourceHandle__immutable_state
|
||||
)
|
||||
|
||||
def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup):
|
||||
"""Initialize a new Consumer instance.
|
||||
|
||||
:param pipeline: The pipeline this consumer belongs to
|
||||
:type pipeline: PipelineAsync
|
||||
:param state: Initial pipeline state
|
||||
:type state: PipelineState
|
||||
:param group: The cooperative group for synchronization
|
||||
:type group: CooperativeGroup
|
||||
"""
|
||||
self.__pipeline = pipeline
|
||||
self.__group = group
|
||||
self.__state = state
|
||||
|
||||
def wait(self, try_wait_token: Optional[Boolean] = None) -> ImmutableResourceHandle:
|
||||
"""Wait for data to be ready in the current buffer.
|
||||
This is a blocking operation.
|
||||
|
||||
:param try_wait_token: Optional token to try to wait for the buffer
|
||||
:type try_wait_token: Optional[Boolean]
|
||||
:return: A handle to the consumer for releasing the data
|
||||
:rtype: PipelineConsumerHandle
|
||||
"""
|
||||
self.__pipeline.consumer_wait(self.__state, try_wait_token)
|
||||
handle = PipelineConsumer.ImmutableResourceHandle(
|
||||
self.__pipeline, self.__state.clone()
|
||||
)
|
||||
return handle
|
||||
|
||||
def advance(self):
|
||||
"""Move to the next pipeline stage."""
|
||||
self.__state.advance()
|
||||
|
||||
def wait_and_advance(
|
||||
self, try_wait_token: Optional[Boolean] = None
|
||||
) -> ImmutableResourceHandle:
|
||||
"""Wait for data to be ready in the current buffer.
|
||||
Then advance to the next stage.
|
||||
This is a blocking operation.
|
||||
|
||||
:param try_wait_token: Optional token to try to wait for the buffer
|
||||
:type try_wait_token: Optional[Boolean]
|
||||
:return: A handle to the consumer for releasing the data
|
||||
:rtype: PipelineConsumerHandle
|
||||
"""
|
||||
handle = self.wait(try_wait_token)
|
||||
self.advance()
|
||||
return handle
|
||||
|
||||
def try_wait(self) -> Boolean:
|
||||
"""Try to check if data is ready without blocking.
|
||||
|
||||
:return: True if data is ready, False otherwise
|
||||
:rtype: Boolean
|
||||
"""
|
||||
return self.__pipeline.consumer_try_wait(self.__state)
|
||||
|
||||
def release(self, handle: Optional[ImmutableResourceHandle] = None):
|
||||
"""Signal that data consumption is complete for the current stage.
|
||||
This allows producers to start producing new data.
|
||||
"""
|
||||
if handle is not None:
|
||||
assert (
|
||||
handle.get_origin() is self
|
||||
), "ResourceHandle does not belong to this PipelineConsumer instance"
|
||||
handle.release()
|
||||
else:
|
||||
self.__pipeline.consumer_release(self.__state)
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
"""Extract MLIR values from the current state.
|
||||
|
||||
:return: List of MLIR values representing the current state
|
||||
:rtype: list
|
||||
"""
|
||||
return self.__state.__extract_mlir_values__()
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
"""Create a new Consumer instance from MLIR values.
|
||||
|
||||
:param values: MLIR values to initialize the state
|
||||
:type values: Any
|
||||
:return: New Consumer instance with state initialized from values
|
||||
:rtype: Consumer
|
||||
"""
|
||||
# TODO: need to call pipeline.__new_from_mlir_values__ recursively
|
||||
return PipelineConsumer(
|
||||
self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group
|
||||
)
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
import ctypes
|
||||
from math import prod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, Union
|
||||
@@ -54,6 +56,25 @@ def dtype(ty: Type[Numeric]):
|
||||
return torch_dtype
|
||||
|
||||
|
||||
def as_tensor(pointer, shape, torch_type):
|
||||
"""Convert a pointer to a torch tensor"""
|
||||
if torch_type.itemsize == 1:
|
||||
cytype = ctypes.c_uint8
|
||||
elif torch_type.itemsize == 2:
|
||||
cytype = ctypes.c_uint16
|
||||
elif torch_type.itemsize == 4:
|
||||
cytype = ctypes.c_uint32
|
||||
elif torch_type.itemsize == 8:
|
||||
cytype = ctypes.c_uint64
|
||||
else:
|
||||
raise ValueError(f"Unsupported torch dtype: {torch_type}")
|
||||
cpointer = ctypes.cast(pointer, ctypes.POINTER(cytype))
|
||||
arr = (cpointer._type_ * prod(shape)).from_address(
|
||||
ctypes.addressof(cpointer.contents)
|
||||
)
|
||||
return torch.frombuffer(arr, dtype=torch_type).view(*shape)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScalarInitConfig:
|
||||
"""Configuration for scalar initialization"""
|
||||
@@ -128,7 +149,7 @@ def create_and_permute_torch_tensor(
|
||||
if not isinstance(init_config, GaussianInitConfig):
|
||||
raise ValueError("init_config must be GaussianInitConfig()")
|
||||
f32_torch_tensor = init_torch_tensor.normal_(init_config.mean, init_config.std)
|
||||
f32_torch_tensor = f32_torch_tensor * (1 << init_config.scale)
|
||||
f32_torch_tensor = f32_torch_tensor * init_config.scale
|
||||
else:
|
||||
raise ValueError(f"Invalid init type: {init_type}")
|
||||
|
||||
|
||||
@@ -64,6 +64,18 @@ from .smem_capacity import (
|
||||
get_smem_capacity_in_bytes,
|
||||
)
|
||||
|
||||
from .distributed_helpers import (
|
||||
spin_lock_wait,
|
||||
spin_lock_multimem_arrive,
|
||||
multimem_ld_reduce_8xf16,
|
||||
multimem_ld_reduce_4xf32,
|
||||
multimem_ld_reduce_8xbf16,
|
||||
multimem_ld_reduce_16xe4m3,
|
||||
multimem_ld_reduce_16xe5m2,
|
||||
multimem_st_4xb32,
|
||||
sm_wise_inter_gpu_multimem_barrier,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_smem_capacity_in_bytes",
|
||||
"SmemAllocator",
|
||||
|
||||
179
python/CuTeDSL/cutlass/utils/distributed_helpers.py
Normal file
179
python/CuTeDSL/cutlass/utils/distributed_helpers.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cutlass_dsl import T, dsl_user_op, while_generate
|
||||
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import arith, llvm, nvvm, scf
|
||||
from cutlass._mlir.dialects.nvvm import (
|
||||
MemOrderKind,
|
||||
MemScopeKind,
|
||||
AtomicOpKind,
|
||||
)
|
||||
from cutlass.cute.typing import Pointer, Int32, Boolean
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def atomicAdd(dst_ptr: Pointer, val: Int32, loc=None, ip=None) -> Int32:
|
||||
return nvvm.atomicrmw(
|
||||
T.i32(),
|
||||
AtomicOpKind.ADD,
|
||||
dst_ptr.llvm_ptr,
|
||||
val.ir_value(loc=loc, ip=ip),
|
||||
mem_order=MemOrderKind.RELAXED,
|
||||
syncscope=MemScopeKind.SYS,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def ld_bypass(input_tensor: cute.Tensor):
|
||||
fragment = cute.make_fragment(input_tensor.layout, input_tensor.element_type)
|
||||
copy_atom_load = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
input_tensor.element_type,
|
||||
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
|
||||
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
|
||||
)
|
||||
cute.copy(copy_atom_load, input_tensor, fragment)
|
||||
vals = fragment.load()
|
||||
return vals
|
||||
|
||||
@cute.jit
|
||||
def spin_lock_wait(lock_ptr: Pointer, expect_count: Int32, mem_order : str = "relaxed", mem_scope : str = "gpu", loc=None, ip=None) -> None:
|
||||
"""
|
||||
wait on a spin lock until the expected count is reached.
|
||||
"""
|
||||
res = 0
|
||||
while res != expect_count:
|
||||
res = nvvm.atomicrmw(
|
||||
T.i32(),
|
||||
AtomicOpKind.CAS,
|
||||
lock_ptr.llvm_ptr,
|
||||
Int32(0).ir_value(loc=loc, ip=ip),
|
||||
b=Int32(expect_count).ir_value(loc=loc, ip=ip),
|
||||
mem_order=MemOrderKind.ACQUIRE if mem_order == "acquire" else MemOrderKind.RELAXED,
|
||||
syncscope=MemScopeKind.GPU if mem_scope == "gpu" else MemScopeKind.SYS
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def multimem_red_add_sys_release(mc_ptr: Pointer, loc=None, ip=None) -> None:
|
||||
"""
|
||||
add 1 to the multimem address
|
||||
"""
|
||||
llvm.inline_asm(
|
||||
None,
|
||||
[mc_ptr.toint().ir_value()],
|
||||
"multimem.red.release.sys.global.add.u32 [$0], 1;",
|
||||
"l",
|
||||
has_side_effects=True,
|
||||
asm_dialect=0,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@dsl_user_op
|
||||
def multimem_red_add_gpu_relaxed(mc_ptr: Pointer, loc=None, ip=None) -> None:
|
||||
"""
|
||||
add 1 to the multimem address
|
||||
"""
|
||||
llvm.inline_asm(
|
||||
None,
|
||||
[mc_ptr.toint().ir_value()],
|
||||
"multimem.red.relaxed.gpu.global.add.u32 [$0], 1;",
|
||||
"l",
|
||||
has_side_effects=True,
|
||||
asm_dialect=0,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None:
|
||||
"""
|
||||
arrive a spin lock when the lock_ptr is a multimem address.
|
||||
"""
|
||||
multimem_red_add_gpu_relaxed(lock_ptr, loc=loc, ip=ip)
|
||||
|
||||
|
||||
def sm_wise_inter_gpu_multimem_barrier(barrier : Pointer, barrier_mc : Pointer, num_ranks, loc=None, ip=None) -> None :
|
||||
"""
|
||||
barrier for inter-gpu sm-wise
|
||||
"""
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
bdimx, bdimy, _ = cute.arch.grid_dim()
|
||||
pid = bidx + bidy * bdimx + bidz * bdimx * bdimy
|
||||
multimem_red_add_sys_release(barrier_mc + pid, loc=loc, ip=ip)
|
||||
cute.arch.fence_proxy(cute.arch.ProxyKind.alias)
|
||||
spin_lock_wait(barrier + pid, num_ranks, mem_order="acquire", mem_scope="sys", loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def multimem_ld_reduce_base(
|
||||
mc_ptr: Pointer,
|
||||
*,
|
||||
ptx_string: str = "",
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[Int32, Int32, Int32, Int32]:
|
||||
# ld reduce 8xf16 elts
|
||||
mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value()
|
||||
return_struct = llvm.inline_asm(
|
||||
ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"),
|
||||
[mc_ptr_int],
|
||||
ptx_string,
|
||||
"=r,=r,=r,=r,l",
|
||||
has_side_effects=True,
|
||||
asm_dialect=0,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return_regs = [llvm.extractvalue(T.i32(), return_struct, [i]) for i in range(4)]
|
||||
return return_regs[0], return_regs[1], return_regs[2], return_regs[3]
|
||||
|
||||
|
||||
multimem_ld_reduce_8xf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];")
|
||||
multimem_ld_reduce_4xf32 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];")
|
||||
multimem_ld_reduce_8xbf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];")
|
||||
multimem_ld_reduce_16xe4m3 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];")
|
||||
multimem_ld_reduce_16xe5m2 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];")
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def multimem_st_4xb32(
|
||||
mc_ptr: Pointer,
|
||||
x: Int32,
|
||||
y: Int32,
|
||||
z: Int32,
|
||||
w: Int32,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
# st 4x32 bits of data
|
||||
mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value()
|
||||
llvm.inline_asm(
|
||||
T.i32(),
|
||||
[mc_ptr_int, x, y, z, w],
|
||||
"multimem.st.sys.relaxed.global.v4.f32 [$1], {$2, $3, $4, $5};",
|
||||
"=r,l,r,r,r,r",
|
||||
has_side_effects=True,
|
||||
asm_dialect=0,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -34,18 +34,6 @@ class LayoutEnum(Enum):
|
||||
else warpgroup.OperandMajorMode.MN
|
||||
)
|
||||
|
||||
def is_k_major_a(self):
|
||||
return self == LayoutEnum.ROW_MAJOR
|
||||
|
||||
def is_m_major_a(self):
|
||||
return self == LayoutEnum.COL_MAJOR
|
||||
|
||||
def is_k_major_b(self):
|
||||
return self == LayoutEnum.COL_MAJOR
|
||||
|
||||
def is_n_major_b(self):
|
||||
return self == LayoutEnum.ROW_MAJOR
|
||||
|
||||
def is_n_major_c(self):
|
||||
return self == LayoutEnum.ROW_MAJOR
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
from typing import Type, Union, overload
|
||||
|
||||
from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta
|
||||
from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta, CutlassBaseDSL
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size
|
||||
@@ -40,14 +40,17 @@ class SmemAllocator:
|
||||
"""
|
||||
self._base = get_dyn_smem(Int8, alignment=1024)
|
||||
self._allocated_bytes = 0
|
||||
CutlassBaseDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes)
|
||||
|
||||
@overload
|
||||
def allocate(self, size_or_type: int, byte_alignment: int): ...
|
||||
def allocate(self, size_or_type: int, byte_alignment: int) -> cute.Pointer: ...
|
||||
|
||||
@overload
|
||||
def allocate(self, size_or_type: cute.struct, byte_alignment: int): ...
|
||||
def allocate(
|
||||
self, size_or_type: cute.struct, byte_alignment: int
|
||||
) -> cute.Pointer: ...
|
||||
|
||||
def allocate(self, size_or_type, byte_alignment: int = 1) -> int:
|
||||
def allocate(self, size_or_type, byte_alignment: int = 1) -> cute.Pointer:
|
||||
"""Allocate a block of memory with specified size and alignment.
|
||||
|
||||
This method adjusts the base pointer to ensure proper alignment and updates
|
||||
|
||||
@@ -382,3 +382,5 @@ class StaticPersistentTileScheduler:
|
||||
@property
|
||||
def num_tiles_executed(self) -> Int32:
|
||||
return self._num_tiles_executed
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from ..base_dsl.ast_helpers import (
|
||||
if_executor,
|
||||
while_selector,
|
||||
while_executor,
|
||||
range,
|
||||
range_constexpr,
|
||||
range_dynamic,
|
||||
const_expr,
|
||||
@@ -28,6 +29,8 @@ from ..base_dsl.ast_helpers import (
|
||||
all_executor,
|
||||
range_value_check,
|
||||
range_perf_warning,
|
||||
cf_symbol_check,
|
||||
redirect_builtin_function,
|
||||
)
|
||||
|
||||
from ..base_dsl import *
|
||||
@@ -38,5 +41,4 @@ from ..base_dsl._mlir_helpers.op import dsl_user_op
|
||||
from ..base_dsl.runtime import *
|
||||
from ..base_dsl.runtime import cuda as cuda_helpers
|
||||
from ..base_dsl.compiler import compile
|
||||
from ..base_dsl.runtime.dlpack_runtime import *
|
||||
from ..base_dsl.runtime.jit_arg_adapters import *
|
||||
|
||||
@@ -15,12 +15,14 @@ regarding to that dialect.
|
||||
"""
|
||||
|
||||
# Local module imports
|
||||
from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef
|
||||
from inspect import isclass
|
||||
from itertools import chain
|
||||
from types import GenericAlias, SimpleNamespace, UnionType
|
||||
from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef, Any
|
||||
import functools
|
||||
import pkgutil
|
||||
from dataclasses import is_dataclass
|
||||
from dataclasses import is_dataclass, fields
|
||||
from collections.abc import Sequence
|
||||
import builtins
|
||||
|
||||
from ..base_dsl import *
|
||||
from ..base_dsl import compiler
|
||||
@@ -51,20 +53,15 @@ from ..base_dsl.ast_helpers import (
|
||||
while_selector,
|
||||
while_executor,
|
||||
assert_executor,
|
||||
const_expr,
|
||||
dynamic_expr,
|
||||
bool_cast,
|
||||
compare_executor,
|
||||
any_executor,
|
||||
all_executor,
|
||||
range_value_check,
|
||||
range_perf_warning,
|
||||
)
|
||||
from ..base_dsl.runtime.dlpack_runtime import (
|
||||
get_cute_tensor_c_pointer,
|
||||
get_tensor_desc_shape_all,
|
||||
get_tensor_desc_stride_all,
|
||||
get_tensor_desc_element_type,
|
||||
get_tensor_desc_is_in_device,
|
||||
get_tensor_desc_assumed_align,
|
||||
cf_symbol_check,
|
||||
)
|
||||
|
||||
from .cutlass_ast_decorators import (
|
||||
@@ -73,6 +70,16 @@ from .cutlass_ast_decorators import (
|
||||
_while_execute_dynamic,
|
||||
)
|
||||
|
||||
from .tree_utils import (
|
||||
is_constexpr_field,
|
||||
tree_flatten,
|
||||
tree_unflatten,
|
||||
PyTreeDef,
|
||||
is_frozen_dataclass,
|
||||
DSLTreeFlattenError,
|
||||
)
|
||||
from ..base_dsl.runtime.jit_arg_adapters import JitArgAdapterRegistry
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cutlass DSL Base Abstract Class
|
||||
@@ -125,6 +132,46 @@ def is_cute_algebra_type(arg_spec):
|
||||
return False
|
||||
|
||||
|
||||
def _get_c_pointers_cutlass(obj):
|
||||
"""
|
||||
This is an extended version of `get_c_pointers` that supports dataclasses, SimpleNamespace, and dict.
|
||||
"""
|
||||
if hasattr(obj, "__c_pointers__"):
|
||||
return obj.__c_pointers__()
|
||||
elif isinstance(obj, (tuple, list)):
|
||||
return list(chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj))
|
||||
elif isinstance(obj, SimpleNamespace):
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
_get_c_pointers_cutlass(x) for x in obj.__dict__.values()
|
||||
)
|
||||
)
|
||||
elif isinstance(obj, dict):
|
||||
return list(
|
||||
chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj.values())
|
||||
)
|
||||
elif is_dataclass(obj):
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
_get_c_pointers_cutlass(getattr(obj, f.name))
|
||||
for f in fields(obj)
|
||||
if not is_constexpr_field(f)
|
||||
)
|
||||
)
|
||||
elif isinstance(obj, set):
|
||||
raise DSLRuntimeError(
|
||||
"Sets are not supported in get_c_pointers to ensure order preservation",
|
||||
context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.",
|
||||
suggestion="Consider using a list or tuple instead",
|
||||
)
|
||||
else:
|
||||
# Try get adapter
|
||||
adapter = JitArgAdapterRegistry.get_registered_adapter(type(obj))
|
||||
if adapter is not None:
|
||||
return _get_c_pointers_cutlass(adapter(obj))
|
||||
return []
|
||||
|
||||
|
||||
class CutlassBaseDSL(BaseDSL):
|
||||
"""This abstract class provides a DSL for Cutlass."""
|
||||
|
||||
@@ -137,16 +184,25 @@ class CutlassBaseDSL(BaseDSL):
|
||||
preprocess: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
compiler_provider,
|
||||
pass_sm_arch_name,
|
||||
device_compilation_only,
|
||||
preprocess,
|
||||
name=name,
|
||||
dsl_package_name=["cutlass"],
|
||||
compiler_provider=compiler_provider,
|
||||
pass_sm_arch_name=pass_sm_arch_name,
|
||||
device_compilation_only=device_compilation_only,
|
||||
preprocess=preprocess,
|
||||
)
|
||||
self._smem_usage_tracker: tuple = None
|
||||
|
||||
# this method is not useful for cutlass_dsl, so we only provide a dummy implementation.
|
||||
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
|
||||
return False
|
||||
|
||||
# this method is not useful for cutlass_dsl, so we only provide a dummy implementation.
|
||||
def _handle_tensor_descriptor(
|
||||
self, maybe_tensor, arg_name: str, need_gpu_memory: bool
|
||||
) -> Any:
|
||||
return False
|
||||
|
||||
def _build_gpu_module(self, attrs):
|
||||
self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels"))
|
||||
with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])):
|
||||
@@ -229,8 +285,43 @@ class CutlassBaseDSL(BaseDSL):
|
||||
|
||||
return version_hash
|
||||
|
||||
@staticmethod
|
||||
def track_smem_allocator(allocator, callback):
|
||||
"""
|
||||
Tracks shared memory usage for kernel functions.
|
||||
Find and set allocator to its parent dsl object.
|
||||
"""
|
||||
frame = inspect.currentframe().f_back
|
||||
while frame:
|
||||
obj = frame.f_locals.get("self", None)
|
||||
if obj and isinstance(obj, CutlassBaseDSL):
|
||||
obj._set_smem_tracking(allocator, callback)
|
||||
return
|
||||
frame = frame.f_back
|
||||
warnings.warn("Cannot find parent dsl for allocator!", UserWarning)
|
||||
|
||||
def _set_smem_tracking(self, allocator, callback):
|
||||
# Registers an allocator and callback for current dsl
|
||||
self._smem_usage_tracker = (allocator, callback)
|
||||
|
||||
def _reset_smem_tracking(self):
|
||||
# Clear an allocator and callback for current dsl
|
||||
self._smem_usage_tracker = None
|
||||
|
||||
def _get_smem_usage(self) -> int:
|
||||
# Treat final allocated bytes of allocator as smem usage
|
||||
if not self._smem_usage_tracker:
|
||||
return 0
|
||||
allocator, callback = self._smem_usage_tracker
|
||||
return callback(allocator)
|
||||
|
||||
def _kernel_helper(self, funcBody, *args, **kwargs):
|
||||
class _CutlassIrKernelGenHelper(BaseDSL._KernelGenHelper):
|
||||
def __init__(self, dsl: CutlassBaseDSL):
|
||||
super().__init__()
|
||||
self.dsl = dsl
|
||||
self.dsl._reset_smem_tracking()
|
||||
|
||||
def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None):
|
||||
super().generate_func_op(arg_types, arg_attrs, kernel_name)
|
||||
self.func_op = func.FuncOp(
|
||||
@@ -272,6 +363,17 @@ class CutlassBaseDSL(BaseDSL):
|
||||
if cfg.has_cluster:
|
||||
cfg.cluster = [to_index(size) for size in cfg.cluster]
|
||||
|
||||
smem_usage = self.dsl._get_smem_usage()
|
||||
if any(not isinstance(x, int) for x in [cfg.smem, smem_usage]):
|
||||
pass # cannot compare dynamic value inside kernel to launch op in py
|
||||
elif cfg.auto_smem:
|
||||
cfg.smem = smem_usage
|
||||
elif smem_usage > cfg.smem:
|
||||
warnings.warn(
|
||||
f"Potential error: specified kernel launch smem bytes "
|
||||
f"({cfg.smem}) is smaller than kernel usage ({smem_usage})!",
|
||||
UserWarning,
|
||||
)
|
||||
cfg.smem = const(cfg.smem)
|
||||
|
||||
if not isinstance(cfg.async_deps, (list, tuple)):
|
||||
@@ -295,12 +397,13 @@ class CutlassBaseDSL(BaseDSL):
|
||||
return token if is_async else None
|
||||
|
||||
return KernelLauncher(
|
||||
self, _CutlassIrKernelGenHelper, funcBody, *args, **kwargs
|
||||
self,
|
||||
lambda: _CutlassIrKernelGenHelper(self),
|
||||
funcBody,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _get_module_globals(self):
|
||||
return globals()
|
||||
|
||||
def _preprocess_launch_config_args(self, args, kwargs):
|
||||
"""Helper to preprocess args and kwargs for LaunchConfig"""
|
||||
if "stream" in kwargs:
|
||||
@@ -316,7 +419,10 @@ class CutlassBaseDSL(BaseDSL):
|
||||
Validates if the arg is really of the annotated type.
|
||||
"""
|
||||
|
||||
if is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None):
|
||||
if (
|
||||
is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None)
|
||||
or arg_annotation is Any
|
||||
):
|
||||
pass
|
||||
else:
|
||||
origin = get_origin(arg_annotation)
|
||||
@@ -329,11 +435,12 @@ class CutlassBaseDSL(BaseDSL):
|
||||
f"expects argument #{arg_index+1} ({arg_name}) to be Type[{expected_base}], but got {arg}"
|
||||
)
|
||||
# Handle Union types and generic types
|
||||
elif origin is Union:
|
||||
elif origin is Union or isinstance(arg_annotation, UnionType):
|
||||
# For Union types, check if arg matches any of the allowed types
|
||||
allowed_types = get_args(arg_annotation)
|
||||
if not any(
|
||||
(isinstance(ty, type) and isinstance(arg, ty))
|
||||
(ty is Any)
|
||||
or (isinstance(ty, type) and isinstance(arg, ty))
|
||||
or (get_origin(ty) is tuple and isinstance(arg, tuple))
|
||||
for ty in allowed_types
|
||||
):
|
||||
@@ -381,6 +488,26 @@ class CutlassBaseDSL(BaseDSL):
|
||||
jit_exec_arg.extend(get_c_pointers(arg) if is_host else dyn_vals)
|
||||
else:
|
||||
jit_exec_arg = jit_arg_type = jit_arg_attr = None
|
||||
elif not hasattr(arg, "__extract_mlir_values__") and not hasattr(
|
||||
arg, "__new_from_mlir_values__"
|
||||
):
|
||||
# Try tree_flatten
|
||||
try:
|
||||
dyn_vals, _ = tree_flatten(arg)
|
||||
except DSLTreeFlattenError:
|
||||
# If fails, just return the original arg
|
||||
return jit_exec_arg, jit_arg_type, jit_arg_attr
|
||||
|
||||
if dyn_vals:
|
||||
jit_arg_type.extend([v.type for v in dyn_vals])
|
||||
jit_arg_attr.extend([default_attr] * len(dyn_vals))
|
||||
jit_exec_arg.extend(
|
||||
_get_c_pointers_cutlass(arg) if is_host else dyn_vals
|
||||
)
|
||||
else:
|
||||
# If tree flatten yields empty list, treat it as a constexpr thing
|
||||
# Like a dataclass with all fields are constexpr, or an empty tuple or list
|
||||
jit_exec_arg = jit_arg_type = jit_arg_attr = None
|
||||
return jit_exec_arg, jit_arg_type, jit_arg_attr
|
||||
|
||||
def _generate_execution_arguments_for_known_types(
|
||||
@@ -396,6 +523,17 @@ class CutlassBaseDSL(BaseDSL):
|
||||
blk_args = fop_args[iv_block_args : iv_block_args + n_args]
|
||||
ir_arg.append(new_from_mlir_values(arg, blk_args))
|
||||
iv_block_args += n_args
|
||||
elif not hasattr(arg, "__extract_mlir_values__") and not hasattr(
|
||||
arg, "__new_from_mlir_values__"
|
||||
):
|
||||
# Try tree_unflatten
|
||||
try:
|
||||
dyn_vals, tree_def = tree_flatten(arg)
|
||||
block_args = fop_args[iv_block_args : iv_block_args + len(dyn_vals)]
|
||||
ir_arg.append(tree_unflatten(tree_def, block_args))
|
||||
iv_block_args += len(dyn_vals)
|
||||
except DSLTreeFlattenError:
|
||||
return ir_arg, iv_block_args
|
||||
|
||||
return ir_arg, iv_block_args
|
||||
|
||||
@@ -458,10 +596,7 @@ class KernelLauncher:
|
||||
|
||||
def _check_func_args(self, funcBody, *func_args, **func_kwargs):
|
||||
# Get function signature
|
||||
if isinstance(funcBody, DSLCallable):
|
||||
sig = funcBody.get_signature()
|
||||
else:
|
||||
sig = inspect.signature(funcBody)
|
||||
sig = inspect.signature(funcBody)
|
||||
|
||||
# func_args and func_kwargs should match funcBody's signature,
|
||||
# no extra or missing arguments.
|
||||
@@ -473,6 +608,12 @@ class KernelLauncher:
|
||||
cause=e,
|
||||
)
|
||||
|
||||
def smem_usage(self) -> int:
|
||||
"""
|
||||
Check smem usage for this kernel, only available after `launch`
|
||||
"""
|
||||
return self.dsl._get_smem_usage()
|
||||
|
||||
def launch(self, *args, **kwargs):
|
||||
self.dsl.frame = inspect.currentframe().f_back
|
||||
self.dsl._preprocess_launch_config_args(args, kwargs)
|
||||
@@ -497,134 +638,151 @@ class KernelLauncher:
|
||||
# =============================================================================
|
||||
# Utils
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def is_frozen_dataclass(obj_or_cls) -> bool:
|
||||
def _filter_readonly_frozen_dataclass(
|
||||
iter_args: List[Any], items_to_filter: List[Any], full_write_args_count: int
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Return True if obj_or_cls is a dataclass (class or instance) declared with frozen=True,
|
||||
otherwise False.
|
||||
"""
|
||||
if not isinstance(obj_or_cls, type):
|
||||
# If it's an instance, get its class
|
||||
obj_or_cls = obj_or_cls.__class__
|
||||
Filter items based on whether corresponding iter_args are frozen dataclasses.
|
||||
|
||||
# Must be a dataclass, and __dataclass_params__.frozen must be True
|
||||
return (
|
||||
is_dataclass(obj_or_cls)
|
||||
and getattr(obj_or_cls, "__dataclass_params__", None) is not None
|
||||
and obj_or_cls.__dataclass_params__.frozen
|
||||
This function filters items (which can be values or names) based on the same
|
||||
logic: keep items if they correspond to full-write arguments (index < full_write_args_count)
|
||||
or if the corresponding iter_arg is not a frozen dataclass.
|
||||
|
||||
Args:
|
||||
iter_args: List of arguments to check for frozen dataclass status
|
||||
items_to_filter: List of items to filter (values or names)
|
||||
full_write_args_count: Number of arguments that are always written (not read-only)
|
||||
|
||||
Returns:
|
||||
Filtered list of items
|
||||
|
||||
Examples:
|
||||
# Filter values (original remove_read_only_frozen_dataclass behavior)
|
||||
filtered_values = _filter_readonly_frozen_dataclass(iter_args, iter_args, full_write_args_count)
|
||||
|
||||
# Filter names (original filter_readonly_frozen_dataclass_names behavior)
|
||||
filtered_names = _filter_readonly_frozen_dataclass(iter_args, iter_args_names, full_write_args_count)
|
||||
"""
|
||||
return [
|
||||
item
|
||||
for i, item in enumerate(items_to_filter)
|
||||
if i < full_write_args_count or not is_frozen_dataclass(iter_args[i])
|
||||
]
|
||||
|
||||
|
||||
def remove_read_only_frozen_dataclass(
|
||||
iter_args: List[Any], full_write_args_count: int
|
||||
) -> List[Any]:
|
||||
"""Filter out frozen dataclass arguments that are not full-write arguments."""
|
||||
return _filter_readonly_frozen_dataclass(
|
||||
iter_args, iter_args, full_write_args_count
|
||||
)
|
||||
|
||||
|
||||
def filter_readonly_frozen_dataclass_names(
|
||||
iter_args: List[Any], iter_args_names: List[str], full_write_args_count: int
|
||||
) -> List[str]:
|
||||
"""Filter names based on whether corresponding iter_args are frozen dataclasses."""
|
||||
return _filter_readonly_frozen_dataclass(
|
||||
iter_args, iter_args_names, full_write_args_count
|
||||
)
|
||||
|
||||
|
||||
def insert_read_only_frozen_dataclass(
|
||||
iter_args: List[Any], original_iter_args: List[Any], full_write_args_count: int
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Insert read-only frozen dataclass arguments back into the iteration arguments.
|
||||
|
||||
This function takes the new iteration arguments and the original arguments,
|
||||
and preserves frozen dataclass instances from the original arguments while
|
||||
using the new arguments for non-frozen dataclass instances.
|
||||
|
||||
Args:
|
||||
iter_args: New iteration arguments to use for non-frozen dataclass instances
|
||||
original_iter_args: Original iteration arguments to preserve frozen dataclass instances from
|
||||
full_write_args_count: Number of arguments that are always written (not read-only)
|
||||
|
||||
Returns:
|
||||
List of arguments with frozen dataclass instances preserved from original
|
||||
"""
|
||||
# Take full-write arguments from new iter_args
|
||||
full_write_args = (
|
||||
iter_args[:full_write_args_count] if full_write_args_count > 0 else []
|
||||
)
|
||||
|
||||
# Process remaining arguments: preserve frozen dataclass from original, use new for others
|
||||
remaining_original = original_iter_args[full_write_args_count:]
|
||||
remaining_new = iter_args[full_write_args_count:]
|
||||
|
||||
def process_remaining_arg(original_arg, new_arg_iter):
|
||||
"""Process a single remaining argument, preserving frozen dataclass if present"""
|
||||
return original_arg if is_frozen_dataclass(original_arg) else next(new_arg_iter)
|
||||
|
||||
# Use zip to pair original args with new args, then map the processing function
|
||||
new_arg_iter = iter(remaining_new)
|
||||
processed_remaining = [
|
||||
process_remaining_arg(orig_arg, new_arg_iter) for orig_arg in remaining_original
|
||||
]
|
||||
|
||||
return full_write_args + processed_remaining
|
||||
|
||||
|
||||
def unpack_to_irvalue(
|
||||
mixed_values: List[Any], body_name: str, full_write_args_count: int
|
||||
) -> Tuple[List[ir.Value], PyTreeDef]:
|
||||
log().debug("===--- Values UNPack")
|
||||
for idx, packed in enumerate(mixed_values):
|
||||
log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
|
||||
|
||||
try:
|
||||
unpacked_values, treedef = tree_flatten(
|
||||
remove_read_only_frozen_dataclass(mixed_values, full_write_args_count)
|
||||
)
|
||||
except DSLTreeFlattenError as e:
|
||||
raise DSLRuntimeError(
|
||||
f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression.",
|
||||
context={
|
||||
e.message: (
|
||||
f"All expressions within '{body_name}' must be dynamic expressions, "
|
||||
"mixing Python objects and dynamic expressions is not supported. "
|
||||
"The DSL failed to convert the Python object into dynamic expressions."
|
||||
)
|
||||
},
|
||||
suggestion=(
|
||||
f"Please ensure '{e.type_str}' implements the '{DynamicExpression.__name__}' or mark with `dataclass`, "
|
||||
f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects."
|
||||
),
|
||||
)
|
||||
|
||||
log().debug("------------------ ")
|
||||
for idx, unpacked in enumerate(unpacked_values):
|
||||
log().debug("[%d]: unpacked values: %s", idx, unpacked)
|
||||
log().debug("treedef: %s", treedef)
|
||||
log().debug("------------------ ")
|
||||
|
||||
return unpacked_values, treedef
|
||||
|
||||
|
||||
def pack_from_irvalue(
|
||||
ir_values: List["ir.Value"],
|
||||
indices: Dict[int, Tuple[int, int]],
|
||||
class_types: List[Any],
|
||||
pytree_def: PyTreeDef,
|
||||
mixed_values: List[Any],
|
||||
full_write_args_count: int,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Packs MLIR values into a list of mixed values.
|
||||
"""
|
||||
log().debug("===--- Values Pack (%d)", len(ir_values))
|
||||
for idx, packed in enumerate(ir_values):
|
||||
log().debug("[%d]: will-packed: %s", idx, ir_values)
|
||||
for idx, unpacked in indices.items():
|
||||
log().debug("[%d]: indices: %s", idx, unpacked)
|
||||
for idx, c in enumerate(class_types):
|
||||
log().debug("[%d]: obj-types: %s", idx, type(c))
|
||||
|
||||
mixed_values = [None] * len(indices)
|
||||
for idx, (start, length) in sorted(indices.items()):
|
||||
chunk = ir_values[start : start + length]
|
||||
obj = class_types[idx]
|
||||
if is_frozen_dataclass(obj):
|
||||
mixed_values[idx] = obj
|
||||
elif not isinstance(obj, type) and hasattr(obj, "__new_from_mlir_values__"):
|
||||
mixed_values[idx] = obj.__new_from_mlir_values__(chunk)
|
||||
elif isinstance(chunk, list) and chunk[0] is None:
|
||||
mixed_values[idx] = class_types[idx]
|
||||
else:
|
||||
if len(chunk) == 1:
|
||||
try:
|
||||
mixed_values[idx] = t.as_numeric(chunk[0])
|
||||
except ValueError:
|
||||
# Suppress the conversion error and try new_from_mlir_values below
|
||||
pass
|
||||
|
||||
if mixed_values[idx] is None:
|
||||
mixed_values[idx] = new_from_mlir_values(obj, chunk)
|
||||
|
||||
log().debug("------------------ ")
|
||||
for idx, packed in enumerate(mixed_values):
|
||||
log().debug("[%d]: packed: %s", idx, packed)
|
||||
log().debug("------------------ ")
|
||||
return mixed_values
|
||||
|
||||
|
||||
def unpack_to_irvalue(
|
||||
mixed_values: List[Any], body_name: str
|
||||
) -> Tuple[List[ir.Value], List[Any], Dict[int, Tuple[int, int]], List[Any]]:
|
||||
"""
|
||||
Unpacks mixed values into ir.Value values.
|
||||
"""
|
||||
unpacked_values = []
|
||||
ir_values = []
|
||||
indices = {}
|
||||
class_types = []
|
||||
current_offset = 0
|
||||
|
||||
log().debug("===--- Values UNPack (%d)", len(mixed_values))
|
||||
for idx, packed in enumerate(mixed_values):
|
||||
log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed)
|
||||
for idx, item in enumerate(mixed_values):
|
||||
class_types.append(item)
|
||||
try:
|
||||
if is_frozen_dataclass(item):
|
||||
extracted_vals = [None]
|
||||
else:
|
||||
extracted_vals = extract_mlir_values(item)
|
||||
# it's consexpr (python value), so we create mlir value for it
|
||||
if extracted_vals == []:
|
||||
if item is None:
|
||||
extracted_vals = [None]
|
||||
else:
|
||||
dyn_expr = t.as_numeric(item)
|
||||
extracted_vals = extract_mlir_values(dyn_expr)
|
||||
ir_values.extend(extracted_vals)
|
||||
else:
|
||||
ir_values.extend(extracted_vals)
|
||||
|
||||
unpacked_values.extend(extracted_vals)
|
||||
length = len(extracted_vals)
|
||||
indices[idx] = (current_offset, length)
|
||||
current_offset += length
|
||||
except Exception as e:
|
||||
raise DSLRuntimeError(
|
||||
f"The '{body_name}' statement encountered a user-defined Python object, which cannot be automatically converted into an dynamic expression (aka MLIR value).",
|
||||
context={
|
||||
item: (
|
||||
f"All expressions within '{body_name}' must be dynamic expressions, "
|
||||
"mixing Python objects and dynamic expressions (aka MLIR values) is not supported. "
|
||||
"The DSL failed to convert the Python object into MLIR values."
|
||||
)
|
||||
},
|
||||
suggestion=(
|
||||
f"Please ensure '{item}' implements the '{DynamicExpression.__name__}', "
|
||||
f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects."
|
||||
),
|
||||
) from e
|
||||
|
||||
log().debug("------------------ ")
|
||||
for idx, unpacked in enumerate(unpacked_values):
|
||||
log().debug("[%d]: unpacked values: %s", idx, unpacked)
|
||||
for idx, unpacked in enumerate(ir_values):
|
||||
log().debug("[%d]: unpacked ir_values: %s", idx, unpacked)
|
||||
for idx, unpacked in indices.items():
|
||||
log().debug("[%d]: indices: %s", idx, unpacked)
|
||||
for idx, unpacked in enumerate(class_types):
|
||||
log().debug("[%d]: initial-class-types: %s", idx, unpacked)
|
||||
for idx, value in enumerate(ir_values):
|
||||
log().debug("[%d]: will-packed: %s", idx, value)
|
||||
log().debug("treedef: %s", pytree_def)
|
||||
log().debug("------------------ ")
|
||||
|
||||
return ir_values, unpacked_values, indices, class_types
|
||||
unflattened = tree_unflatten(pytree_def, ir_values)
|
||||
return insert_read_only_frozen_dataclass(
|
||||
unflattened, mixed_values, full_write_args_count
|
||||
)
|
||||
|
||||
|
||||
def to_index(value):
|
||||
@@ -1015,8 +1173,8 @@ def any_(iterable):
|
||||
|
||||
def select_(cond, if_value, else_value):
|
||||
def _as_scalar(value):
|
||||
if const_expr(isinstance(value, list)):
|
||||
if const_expr(len(value) == 1):
|
||||
if isinstance(value, list):
|
||||
if len(value) == 1:
|
||||
return value[0]
|
||||
else:
|
||||
raise DSLRuntimeError(
|
||||
@@ -1024,16 +1182,16 @@ def select_(cond, if_value, else_value):
|
||||
)
|
||||
return value
|
||||
|
||||
if const_expr(not is_dynamic_expression(cond)):
|
||||
if not is_dynamic_expression(cond):
|
||||
raise DSLRuntimeError("Conditional expression must be dynamic")
|
||||
|
||||
# Extract MLIR values
|
||||
cond = extract_mlir_values(cond)
|
||||
if const_expr(is_dynamic_expression(if_value)):
|
||||
if is_dynamic_expression(if_value):
|
||||
if_value = extract_mlir_values(if_value)
|
||||
else:
|
||||
if_value = const(if_value)
|
||||
if const_expr(is_dynamic_expression(else_value)):
|
||||
if is_dynamic_expression(else_value):
|
||||
else_value = extract_mlir_values(else_value)
|
||||
else:
|
||||
else_value = const(else_value)
|
||||
@@ -1089,7 +1247,7 @@ def for_generate(
|
||||
iter_args: Optional[Sequence[ir.Value]] = None,
|
||||
*,
|
||||
unroll: LoopUnroll = None,
|
||||
pipelining=None,
|
||||
prefetch_stages=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
@@ -1127,8 +1285,8 @@ def for_generate(
|
||||
if unroll is not None:
|
||||
for_op.attributes["loop_annotation"] = unroll
|
||||
|
||||
if pipelining is not None:
|
||||
for_op.attributes["cutlass.pipelining"] = _createI32Attr(pipelining)
|
||||
if prefetch_stages is not None:
|
||||
for_op.attributes["cutlass.pipelining"] = _createI32Attr(prefetch_stages)
|
||||
|
||||
iv = for_op.induction_variable
|
||||
new_results = new_from_mlir_values(iter_args, for_op.results)
|
||||
@@ -1155,11 +1313,11 @@ def not_(lhs: Union[ir.Value, bool], *, loc=None, ip=None):
|
||||
"""
|
||||
res = None
|
||||
# Handle Python bool first to prevent infinite recursion
|
||||
if const_expr(type(lhs) == bool):
|
||||
if type(lhs) == bool:
|
||||
res = lhs ^ True
|
||||
elif const_expr(hasattr(lhs, "__dsl_not__")):
|
||||
elif hasattr(lhs, "__dsl_not__"):
|
||||
res = lhs.__dsl_not__(loc=loc, ip=ip)
|
||||
elif const_expr(is_dynamic_expression(lhs)):
|
||||
elif is_dynamic_expression(lhs):
|
||||
# If lhs is MLIR value, compute not using xor
|
||||
res = arith.XOrIOp(lhs, const(1, lhs.type)).result
|
||||
else:
|
||||
@@ -1338,29 +1496,59 @@ def equal(lhs, rhs):
|
||||
return lhs == rhs
|
||||
|
||||
|
||||
def in_(lhs, rhs, op):
|
||||
def not_equal(lhs, rhs):
|
||||
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
|
||||
return lhs != rhs
|
||||
|
||||
# Both sequence
|
||||
if isinstance(lhs, Sequence) and isinstance(rhs, Sequence):
|
||||
# Short-circuit for unequal length
|
||||
if len(lhs) != len(rhs):
|
||||
return True
|
||||
return any_(not_equal(l, r) for l, r in zip(lhs, rhs))
|
||||
|
||||
if hasattr(lhs, "__ne__"):
|
||||
return lhs != rhs
|
||||
elif hasattr(rhs, "__ne__"):
|
||||
return rhs != lhs
|
||||
else:
|
||||
return not_(equal(lhs, rhs))
|
||||
|
||||
|
||||
def in_(lhs, rhs):
|
||||
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
|
||||
return lhs in rhs
|
||||
|
||||
if not isinstance(rhs, Sequence):
|
||||
raise DSLRuntimeError(
|
||||
f"'{op}' not supported between instances of {type(lhs)} and {type(rhs)}"
|
||||
f"'in' not supported between instances of {type(lhs)} and {type(rhs)}"
|
||||
)
|
||||
|
||||
return any_(equal(lhs, r) for r in rhs)
|
||||
|
||||
|
||||
def _lt_gt(lhs, rhs, op):
|
||||
def native_lt_gt(lhs, rhs, op):
|
||||
if op == "<":
|
||||
return lhs < rhs
|
||||
elif op == ">":
|
||||
return lhs > rhs
|
||||
else:
|
||||
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
|
||||
def _lte_gte(lhs, rhs, op):
|
||||
def native_lte_gte(lhs, rhs, op):
|
||||
match op:
|
||||
case "<":
|
||||
return lhs < rhs
|
||||
case "<=":
|
||||
if hasattr(lhs, "__le__"):
|
||||
return lhs <= rhs
|
||||
else:
|
||||
return not_(lhs > rhs)
|
||||
case ">":
|
||||
return lhs > rhs
|
||||
case ">=":
|
||||
if hasattr(lhs, "__ge__"):
|
||||
return lhs >= rhs
|
||||
else:
|
||||
return not_(lhs < rhs)
|
||||
case _:
|
||||
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
|
||||
|
||||
if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs):
|
||||
return native_lt_gt(lhs, rhs, op)
|
||||
return native_lte_gte(lhs, rhs, op)
|
||||
|
||||
# Both sequence, comparisons other than == and != do not allow mixing different types of sequences
|
||||
if (
|
||||
@@ -1375,7 +1563,7 @@ def _lt_gt(lhs, rhs, op):
|
||||
is_equal = equal(l, r)
|
||||
mask.append(not_(or_(is_equal, unequal_found)))
|
||||
unequal_found = not_(is_equal)
|
||||
comp_results.append(_lt_gt(l, r, op))
|
||||
comp_results.append(_lte_gte(l, r, op))
|
||||
|
||||
result = any_(and_(r, m) for r, m in zip(comp_results, mask))
|
||||
|
||||
@@ -1383,62 +1571,126 @@ def _lt_gt(lhs, rhs, op):
|
||||
# Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types
|
||||
# If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one
|
||||
has_valid_mask = any_(mask)
|
||||
if op == "<":
|
||||
length_result = len(lhs) < len(rhs)
|
||||
elif op == ">":
|
||||
length_result = len(lhs) > len(rhs)
|
||||
match op:
|
||||
case "<":
|
||||
length_result = len(lhs) < len(rhs)
|
||||
case ">":
|
||||
length_result = len(lhs) > len(rhs)
|
||||
case "<=":
|
||||
length_result = len(lhs) <= len(rhs)
|
||||
case ">=":
|
||||
length_result = len(lhs) >= len(rhs)
|
||||
if type(has_valid_mask) == bool:
|
||||
return result if has_valid_mask else length_result
|
||||
else:
|
||||
return select_(has_valid_mask, result, length_result)
|
||||
else:
|
||||
return result
|
||||
if op in {"<=", ">="}:
|
||||
# If no unequal, return True
|
||||
return select_(unequal_found, result, True)
|
||||
else:
|
||||
return result
|
||||
else:
|
||||
return native_lt_gt(lhs, rhs, op)
|
||||
return native_lte_gte(lhs, rhs, op)
|
||||
|
||||
|
||||
def greater_than(lhs, rhs):
|
||||
return _lt_gt(lhs, rhs, ">")
|
||||
return _lte_gte(lhs, rhs, ">")
|
||||
|
||||
|
||||
def greater_equal(lhs, rhs):
|
||||
return _lte_gte(lhs, rhs, ">=")
|
||||
|
||||
|
||||
def less_than(lhs, rhs):
|
||||
return _lt_gt(lhs, rhs, "<")
|
||||
return _lte_gte(lhs, rhs, "<")
|
||||
|
||||
|
||||
def less_equal(lhs, rhs):
|
||||
return _lte_gte(lhs, rhs, "<=")
|
||||
|
||||
|
||||
def _compare_dispatch(lhs, rhs, op):
|
||||
"""
|
||||
Dispatches the comparison operation between lhs and rhs based on the given operator.
|
||||
|
||||
:param lhs: The left-hand side operand for the comparison.
|
||||
:param rhs: The right-hand side operand for the comparison.
|
||||
:param op: The comparison operator as a string. Supported operators are:
|
||||
- "is", "is not": Python identity comparisons.
|
||||
- "in", "not in": Membership tests.
|
||||
- "==", "!=": Equality and inequality.
|
||||
- "<", ">", "<=", ">=": Relational comparisons.
|
||||
:return: The result of the comparison, which may be a boolean or a DSL-specific type.
|
||||
:raises DSLRuntimeError: If the operator is not supported.
|
||||
"""
|
||||
match op:
|
||||
# 'is' and 'is not' are pure python operators
|
||||
case "is":
|
||||
return lhs is rhs
|
||||
case "is not":
|
||||
return lhs is not rhs
|
||||
case "in":
|
||||
return in_(lhs, rhs)
|
||||
case "not in":
|
||||
return not_(in_(lhs, rhs))
|
||||
case "==":
|
||||
return equal(lhs, rhs)
|
||||
case "!=":
|
||||
return not_equal(lhs, rhs)
|
||||
case "<":
|
||||
return less_than(lhs, rhs)
|
||||
case ">":
|
||||
return greater_than(lhs, rhs)
|
||||
case ">=":
|
||||
return greater_equal(lhs, rhs)
|
||||
case "<=":
|
||||
return less_equal(lhs, rhs)
|
||||
case _:
|
||||
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
|
||||
|
||||
|
||||
def _compare_executor(left, comparators, ops):
|
||||
result = left
|
||||
# Fast path for single comparison
|
||||
if len(comparators) == 1:
|
||||
return _compare_dispatch(left, comparators[0], ops[0])
|
||||
|
||||
# Chain comparison, dispatch in a loop
|
||||
result = True
|
||||
current = left
|
||||
for comparator, op in zip(comparators, ops):
|
||||
# 'is' and 'is not' are pure python operators
|
||||
if op == "is":
|
||||
result = result is comparator
|
||||
elif op == "is not":
|
||||
result = result is not comparator
|
||||
elif op in ["in", "not in"]:
|
||||
result = in_(left, comparator, op)
|
||||
elif op in ["==", "!="]:
|
||||
result = equal(left, comparator)
|
||||
elif op in ["<", ">="]:
|
||||
result = less_than(left, comparator)
|
||||
elif op in [">", "<="]:
|
||||
result = greater_than(left, comparator)
|
||||
else:
|
||||
raise DSLRuntimeError(f"Unsupported comparison operator: {op}")
|
||||
# Invert the result for NotIn, NotEq, GtE, LtE
|
||||
if op in ["not in", "!=", ">=", "<="]:
|
||||
result = not_(result)
|
||||
cmp_result = _compare_dispatch(current, comparator, op)
|
||||
result = and_(result, cmp_result)
|
||||
current = comparator
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _builtin_redirector(fcn):
|
||||
if fcn == builtins.max:
|
||||
return max
|
||||
elif fcn == builtins.min:
|
||||
return min
|
||||
elif fcn == builtins.any:
|
||||
return any_
|
||||
elif fcn == builtins.all:
|
||||
return all_
|
||||
else:
|
||||
raise DSLRuntimeError(f"Unsupported built-in function: {fcn}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Set the AST decorator
|
||||
# =============================================================================
|
||||
|
||||
# Set the DSL specific functions
|
||||
executor.set_functions(
|
||||
is_dynamic_expression,
|
||||
_loop_execute_range_dynamic,
|
||||
_if_execute_dynamic,
|
||||
_while_execute_dynamic,
|
||||
_compare_executor,
|
||||
any_,
|
||||
all_,
|
||||
is_dynamic_expression=is_dynamic_expression,
|
||||
loop_execute_range_dynamic=_loop_execute_range_dynamic,
|
||||
if_dynamic=_if_execute_dynamic,
|
||||
while_dynamic=_while_execute_dynamic,
|
||||
compare_executor=_compare_executor,
|
||||
any_executor=any_,
|
||||
all_executor=all_,
|
||||
builtin_redirector=_builtin_redirector,
|
||||
)
|
||||
|
||||
@@ -14,13 +14,22 @@ from types import NoneType
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import scf, arith
|
||||
from cutlass._mlir.extras import types as T
|
||||
from collections.abc import Sequence
|
||||
|
||||
from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values
|
||||
from ..base_dsl.dsl import is_dynamic_expression
|
||||
from ..base_dsl.ast_helpers import *
|
||||
from ..base_dsl.utils.logger import log
|
||||
from ..base_dsl import typing as t
|
||||
from ..base_dsl.typing import Int32, Float32, Boolean, Numeric, get_mlir_types
|
||||
from ..base_dsl.typing import (
|
||||
Int32,
|
||||
Float32,
|
||||
Boolean,
|
||||
Numeric,
|
||||
get_mlir_types,
|
||||
as_numeric,
|
||||
)
|
||||
from . import cutlass as cutlass_dsl
|
||||
from .tree_utils import PyTreeDef, check_tree_equal
|
||||
|
||||
# =============================================================================
|
||||
# AST Helpers
|
||||
@@ -57,14 +66,6 @@ class ScfGenerator:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def fill_none(ir_values, unpacked_values):
|
||||
i = 0
|
||||
for idx, item in enumerate(unpacked_values):
|
||||
if item is not None:
|
||||
unpacked_values[idx] = ir_values[i]
|
||||
i += 1
|
||||
|
||||
@staticmethod
|
||||
def _normalize_region_result_to_list(region_result: Any) -> List[Any]:
|
||||
"""
|
||||
@@ -82,34 +83,109 @@ class ScfGenerator:
|
||||
return region_result_list
|
||||
|
||||
@staticmethod
|
||||
def check_region_result(region_values, ir_values):
|
||||
for i, (expected_value, actual_value) in enumerate(
|
||||
zip(ir_values, region_values)
|
||||
def _check_region_result(original_value, region_value, arg_name, op_type_name):
|
||||
"""
|
||||
Validate that a region result maintains the same type as the original value.
|
||||
|
||||
This method checks for type consistency between the original value passed to a dynamic
|
||||
SCF operation (like for, if, while) and the value returned from the operation's region.
|
||||
|
||||
Args:
|
||||
original_value: The value before entering the SCF operation region
|
||||
region_value: The value returned from the SCF operation region
|
||||
arg_name: Name of the argument being checked (for error reporting)
|
||||
op_type_name: Type of SCF operation (e.g., 'for', 'if', 'while') for error reporting
|
||||
|
||||
Raises:
|
||||
DSLRuntimeError: If the region value has a different type than the original value.
|
||||
The error includes suggestions for using compile-time control flow instead.
|
||||
|
||||
Note:
|
||||
This method performs relaxed type checking that allows inheritance relationships.
|
||||
For example, a child class can be returned where a parent class was expected.
|
||||
However, fundamental type changes (like None to non-None, different sequence types,
|
||||
or different numeric types) are not allowed in dynamic SCF operations.
|
||||
"""
|
||||
|
||||
def get_type_name(value):
|
||||
if isinstance(value, NoneType):
|
||||
return "None"
|
||||
elif isinstance(value, Sequence):
|
||||
return f"{type(value).__name__}<{len(value)}>"
|
||||
else:
|
||||
return type(value).__name__
|
||||
|
||||
# Check for type mismatches
|
||||
type_mismatch = False
|
||||
old_type_name = None
|
||||
new_type_name = None
|
||||
|
||||
# Handle None type changes
|
||||
if isinstance(original_value, NoneType) != isinstance(region_value, NoneType):
|
||||
type_mismatch = True
|
||||
old_type_name = get_type_name(original_value)
|
||||
new_type_name = get_type_name(region_value)
|
||||
# Handle sequence type/length changes
|
||||
elif isinstance(original_value, Sequence) and isinstance(
|
||||
region_value, Sequence
|
||||
):
|
||||
expected_value_type = get_mlir_types(expected_value)
|
||||
actual_value_type = get_mlir_types(actual_value)
|
||||
if expected_value_type != actual_value_type:
|
||||
return False, i, expected_value_type, actual_value_type
|
||||
return True, -1, None, None
|
||||
if type(original_value) != type(region_value) or len(original_value) != len(
|
||||
region_value
|
||||
):
|
||||
type_mismatch = True
|
||||
old_type_name = get_type_name(original_value)
|
||||
new_type_name = get_type_name(region_value)
|
||||
# Handle numeric type changes
|
||||
elif isinstance(
|
||||
original_value, (Numeric, ArithValue, ir.Value, int, float, bool)
|
||||
) or isinstance(
|
||||
region_value, (Numeric, ArithValue, ir.Value, int, float, bool)
|
||||
):
|
||||
try:
|
||||
original_numeric = as_numeric(original_value)
|
||||
region_numeric = as_numeric(region_value)
|
||||
if original_numeric.dtype != region_numeric.dtype:
|
||||
type_mismatch = True
|
||||
old_type_name = original_numeric.dtype.__name__
|
||||
new_type_name = region_numeric.dtype.__name__
|
||||
except Exception:
|
||||
pass
|
||||
# Handle general type changes (relaxed for inheritance)
|
||||
elif type(original_value) != type(region_value):
|
||||
old_type = type(original_value)
|
||||
new_type = type(region_value)
|
||||
if not (issubclass(old_type, new_type) or issubclass(new_type, old_type)):
|
||||
type_mismatch = True
|
||||
old_type_name = old_type.__name__
|
||||
new_type_name = new_type.__name__
|
||||
|
||||
if type_mismatch:
|
||||
raise DSLRuntimeError(
|
||||
f"`{arg_name}` is {old_type_name} prior to this `{op_type_name}`, "
|
||||
f"and update to {new_type_name} inside of this `{op_type_name}` is not supported.",
|
||||
suggestion=(
|
||||
f"Please avoid changing type inside a dynamic `{op_type_name}`, "
|
||||
f"or change to compile-time control flow by marking this `{op_type_name}` with "
|
||||
f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`."
|
||||
),
|
||||
)
|
||||
|
||||
def scf_execute_dynamic(
|
||||
self,
|
||||
op_type_name: str,
|
||||
used_args: List[Any],
|
||||
mix_iter_args: List[Any],
|
||||
full_write_args_count: int,
|
||||
mix_iter_arg_names: List[str],
|
||||
create_op_func: Callable[
|
||||
[List[ir.Value], Dict[int, Tuple[int, int]], List[Any]], ir.Operation
|
||||
],
|
||||
create_op_func: Callable[[List[ir.Value]], ir.Operation],
|
||||
region_builders: List[
|
||||
Callable[
|
||||
[
|
||||
"ir.Operation",
|
||||
List["ir.Value"], # block_args
|
||||
List[Any], # used_args
|
||||
List["ir.Value"], # dyn_yield_ops
|
||||
Dict[int, Tuple[int, int]],
|
||||
PyTreeDef,
|
||||
List[Any],
|
||||
int,
|
||||
],
|
||||
Any,
|
||||
]
|
||||
@@ -119,11 +195,11 @@ class ScfGenerator:
|
||||
block_term_op_builder: Dict[Callable, Callable] = {},
|
||||
) -> Any:
|
||||
# 1) Unpack
|
||||
ir_values, dyn_unpacked_values, dyn_indices, dyn_class_types = (
|
||||
cutlass_dsl.unpack_to_irvalue(mix_iter_args, op_type_name)
|
||||
ir_values, pytree_def = cutlass_dsl.unpack_to_irvalue(
|
||||
mix_iter_args, op_type_name, full_write_args_count
|
||||
)
|
||||
# 2) Create the SCF op
|
||||
op = create_op_func(ir_values, dyn_indices, dyn_class_types)
|
||||
op = create_op_func(ir_values)
|
||||
log().debug("Generated scf.%s \n[%s]", op_type_name, op)
|
||||
|
||||
# 3) Build the regions
|
||||
@@ -135,76 +211,61 @@ class ScfGenerator:
|
||||
region_result = builder(
|
||||
op,
|
||||
block_args,
|
||||
used_args,
|
||||
dyn_unpacked_values,
|
||||
dyn_indices,
|
||||
dyn_class_types,
|
||||
ir_values,
|
||||
pytree_def,
|
||||
mix_iter_args,
|
||||
full_write_args_count,
|
||||
)
|
||||
|
||||
# Use custom terminator if provided for this builder, otherwise use default YieldOp
|
||||
if builder in block_term_op_builder:
|
||||
# Use the provided terminator generator
|
||||
block_term_op_builder[builder](region_result)
|
||||
block_term_op_builder[builder](region_result, full_write_args_count)
|
||||
else:
|
||||
# For standard yield op, check result
|
||||
for arg, result, name in zip(
|
||||
mix_iter_args,
|
||||
(
|
||||
region_result
|
||||
if isinstance(region_result, list)
|
||||
else [region_result]
|
||||
),
|
||||
mix_iter_arg_names,
|
||||
):
|
||||
if isinstance(arg, NoneType) and not isinstance(
|
||||
result, NoneType
|
||||
):
|
||||
raise DSLRuntimeError(
|
||||
(
|
||||
f"`{name}` is None prior to this `{op_type_name}`, "
|
||||
f"and update to non-None value inside of this `{op_type_name}` is not supported."
|
||||
),
|
||||
suggestion=(
|
||||
f"Please make sure `{name}` is not None prior to this `{op_type_name}`, "
|
||||
f"or mark this `{op_type_name}` with "
|
||||
f"`{'range' if op_type_name == 'for' else 'const_expr'}`."
|
||||
),
|
||||
)
|
||||
# Normalize region_result
|
||||
region_result_list = ScfGenerator._normalize_region_result_to_list(
|
||||
region_result
|
||||
)
|
||||
# Default behavior - generate YieldOp
|
||||
region_values, unpacked_values, _, _ = (
|
||||
cutlass_dsl.unpack_to_irvalue(region_result_list, op_type_name)
|
||||
)
|
||||
|
||||
is_match, mismatch_idx, expected_type, actual_type = (
|
||||
ScfGenerator.check_region_result(region_values, ir_values)
|
||||
)
|
||||
|
||||
if not is_match:
|
||||
# From unpacked index, we need to find the original index
|
||||
original_idx = -1
|
||||
for unpacked_idx, (original_idx, length) in dyn_indices.items():
|
||||
if (
|
||||
mismatch_idx >= original_idx
|
||||
and mismatch_idx < original_idx + length
|
||||
):
|
||||
original_idx = unpacked_idx
|
||||
break
|
||||
raise DSLRuntimeError(
|
||||
f"`{op_type_name}` expects {expected_type} type for varible `{mix_iter_arg_names[original_idx]}`, but got {actual_type}.",
|
||||
suggestion=f"Please make sure `{mix_iter_arg_names[original_idx]}` type is not changed inside of `{op_type_name}`.",
|
||||
# For standard yield op, check result
|
||||
for arg, result, name in zip(
|
||||
mix_iter_args,
|
||||
region_result_list,
|
||||
mix_iter_arg_names,
|
||||
):
|
||||
ScfGenerator._check_region_result(
|
||||
arg, result, name, op_type_name
|
||||
)
|
||||
|
||||
# Default behavior - generate YieldOp
|
||||
region_values, yield_pytree_def = cutlass_dsl.unpack_to_irvalue(
|
||||
region_result_list, op_type_name, full_write_args_count
|
||||
)
|
||||
|
||||
mismatch = check_tree_equal(pytree_def, yield_pytree_def)
|
||||
if mismatch != -1:
|
||||
# Get arg name
|
||||
filterd_arg_names = (
|
||||
cutlass_dsl.filter_readonly_frozen_dataclass_names(
|
||||
mix_iter_args, mix_iter_arg_names, full_write_args_count
|
||||
)
|
||||
)
|
||||
|
||||
raise DSLRuntimeError(
|
||||
f"`{filterd_arg_names[mismatch]}` is structured different after this `{op_type_name}`.",
|
||||
suggestion=(
|
||||
f"Please avoid changing type structure inside a dynamic `{op_type_name}`, "
|
||||
f"or change to compile-time control flow by marking this `{op_type_name}` with "
|
||||
f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`."
|
||||
),
|
||||
)
|
||||
|
||||
scf.YieldOp(region_values)
|
||||
|
||||
log().debug("Completed scf.%s \n[%s]", op_type_name, op)
|
||||
ScfGenerator.fill_none(op.results, unpacked_values)
|
||||
|
||||
# 4) Pack final results
|
||||
final_results = cutlass_dsl.pack_from_irvalue(
|
||||
unpacked_values, dyn_indices, dyn_class_types
|
||||
op.results, pytree_def, mix_iter_args, full_write_args_count
|
||||
)
|
||||
|
||||
# 5) Return in a nice pattern
|
||||
@@ -215,28 +276,32 @@ class ScfGenerator:
|
||||
return final_results
|
||||
|
||||
|
||||
def _attr_const_check(attr, expected_type, attr_name):
|
||||
# Use strict type equality to prevent `bool` being accepted where `int` is required.
|
||||
if is_dynamic_expression(attr) or type(attr) is not expected_type:
|
||||
raise DSLRuntimeError(
|
||||
f"loop attribute `{attr_name}` must be a Python value of type `{expected_type.__name__}`, got `{type(attr).__name__}`."
|
||||
)
|
||||
|
||||
|
||||
def _loop_execute_range_dynamic(
|
||||
func: Callable,
|
||||
start: Any,
|
||||
stop: Any,
|
||||
step: Any,
|
||||
used_args: List[Any] = [],
|
||||
mix_iter_args: List[Any] = [],
|
||||
full_write_args_count: int = 0,
|
||||
mix_iter_arg_names: List[str] = [],
|
||||
unroll: int = -1,
|
||||
unroll_full: bool = False,
|
||||
pipelining: int = None,
|
||||
prefetch_stages: int = None,
|
||||
):
|
||||
"""
|
||||
Example: build an scf.for with optional unroll, using our universal helper.
|
||||
"""
|
||||
scf_gen = ScfGenerator()
|
||||
|
||||
def create_for_op(
|
||||
dyn_yield_ops: List[ir.Value],
|
||||
dyn_indices: Dict[int, Tuple[int, int]],
|
||||
dyn_class_types: List[Any],
|
||||
):
|
||||
def create_for_op(dyn_yield_ops: List[ir.Value]):
|
||||
for d in dyn_yield_ops:
|
||||
if not isinstance(d, ir.Value):
|
||||
raise DSLRuntimeError(
|
||||
@@ -254,6 +319,10 @@ def _loop_execute_range_dynamic(
|
||||
stop_ = stop_.ir_value()
|
||||
step_ = step_.ir_value()
|
||||
|
||||
# Attributes must be pure Python value, add a check
|
||||
_attr_const_check(unroll, int, "unroll")
|
||||
_attr_const_check(unroll_full, bool, "unroll_full")
|
||||
|
||||
# Possibly attach unroll attributes
|
||||
unroll_attr = None
|
||||
if unroll_full:
|
||||
@@ -262,17 +331,18 @@ def _loop_execute_range_dynamic(
|
||||
unroll_attr = LoopUnroll(count=unroll)
|
||||
log().debug("Unroll attribute: %s", unroll_attr)
|
||||
|
||||
pipelining_attr = None
|
||||
if pipelining is not None:
|
||||
if pipelining >= 0:
|
||||
pipelining_attr = ir.IntegerAttr.get(
|
||||
ir.IntegerType.get_signless(32), pipelining
|
||||
prefetch_stages_attr = None
|
||||
if prefetch_stages is not None:
|
||||
_attr_const_check(prefetch_stages, int, "prefetch_stages")
|
||||
if prefetch_stages >= 0:
|
||||
prefetch_stages_attr = ir.IntegerAttr.get(
|
||||
ir.IntegerType.get_signless(32), prefetch_stages
|
||||
)
|
||||
else:
|
||||
raise DSLRuntimeError(
|
||||
f"Pipelining must be non-negative, got {pipelining}"
|
||||
f"loop attribute `prefetch_stages` must be non-negative, got `{prefetch_stages}`."
|
||||
)
|
||||
log().debug("Pipelining attribute: %s", pipelining_attr)
|
||||
log().debug("prefetch_stages attribute: %s", prefetch_stages_attr)
|
||||
|
||||
log().debug(
|
||||
"Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s",
|
||||
@@ -303,47 +373,48 @@ def _loop_execute_range_dynamic(
|
||||
if unroll_attr is not None:
|
||||
for_op.attributes["loop_annotation"] = unroll_attr
|
||||
|
||||
if pipelining_attr is not None:
|
||||
for_op.attributes["cutlass.pipelining"] = pipelining_attr
|
||||
if prefetch_stages_attr is not None:
|
||||
for_op.attributes["cutlass.pipelining"] = prefetch_stages_attr
|
||||
|
||||
return for_op
|
||||
|
||||
def for_body_builder(
|
||||
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
op,
|
||||
block_args,
|
||||
_,
|
||||
pytree_def,
|
||||
mix_iter_args,
|
||||
full_write_args_count,
|
||||
):
|
||||
# Insert induction variable at the beginning
|
||||
dyn_yield_ops.insert(0, block_args[0])
|
||||
ScfGenerator.fill_none(block_args, dyn_yield_ops)
|
||||
block_args = dyn_yield_ops
|
||||
# scf.ForOp block_args are typically [induction_var, iter_args...]
|
||||
# But MLIR also gives you op.induction_variable
|
||||
iv = t.as_numeric(op.induction_variable)
|
||||
log().debug(
|
||||
"For body builder: %s block_args: %s used_args: %s",
|
||||
"For body builder: %s block_args: %s full_write_args_count: %s",
|
||||
iv,
|
||||
block_args,
|
||||
used_args,
|
||||
full_write_args_count,
|
||||
)
|
||||
if len(block_args) <= 1:
|
||||
# block_args[1:] are iteration variables
|
||||
func_args = []
|
||||
func_args.extend(
|
||||
cutlass_dsl.pack_from_irvalue(
|
||||
block_args[1:], pytree_def, mix_iter_args, full_write_args_count
|
||||
)
|
||||
)
|
||||
if not func_args:
|
||||
# No iteration arguments, or only the induction var
|
||||
func(iv, *used_args)
|
||||
func(iv)
|
||||
return [] # yield nothing
|
||||
else:
|
||||
# block_args[1:] are iteration variables
|
||||
func_args = [*used_args]
|
||||
func_args.extend(
|
||||
cutlass_dsl.pack_from_irvalue(
|
||||
block_args[1:], dyn_indices, dyn_class_types
|
||||
)
|
||||
)
|
||||
updated_func_args = func(iv, *func_args)
|
||||
return updated_func_args
|
||||
|
||||
# Now call the universal SCF executor with a single region builder
|
||||
return scf_gen.scf_execute_dynamic(
|
||||
op_type_name="for",
|
||||
used_args=used_args,
|
||||
mix_iter_args=mix_iter_args,
|
||||
full_write_args_count=full_write_args_count,
|
||||
mix_iter_arg_names=mix_iter_arg_names,
|
||||
create_op_func=create_for_op,
|
||||
region_builders=[for_body_builder],
|
||||
@@ -354,8 +425,8 @@ def _if_execute_dynamic(
|
||||
pred: "ir.Value",
|
||||
then_block: Callable,
|
||||
else_block: Callable = None,
|
||||
used_args: List[Any] = [],
|
||||
mix_yield_args: List[Any] = [],
|
||||
full_write_args_count: int = 0,
|
||||
mix_yield_arg_names: List[str] = [],
|
||||
if_constexpr=None, # ignoring for brevity
|
||||
):
|
||||
@@ -364,11 +435,7 @@ def _if_execute_dynamic(
|
||||
"""
|
||||
scf_gen = ScfGenerator()
|
||||
|
||||
def create_if_op(
|
||||
dyn_yield_ops: List[ir.Value],
|
||||
dyn_indices: Dict[int, Tuple[int, int]],
|
||||
dyn_class_types: List[Any],
|
||||
):
|
||||
def create_if_op(dyn_yield_ops: List[ir.Value]):
|
||||
# Assume final result types match the dynamic yields
|
||||
result_types = [arg.type for arg in dyn_yield_ops]
|
||||
|
||||
@@ -387,11 +454,18 @@ def _if_execute_dynamic(
|
||||
return if_op
|
||||
|
||||
def then_builder(
|
||||
if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
if_op,
|
||||
_,
|
||||
dyn_yield_ops,
|
||||
pytree_def,
|
||||
mix_iter_args,
|
||||
full_write_args_count,
|
||||
):
|
||||
flat_args = [*used_args]
|
||||
flat_args = []
|
||||
flat_args.extend(
|
||||
cutlass_dsl.pack_from_irvalue(dyn_yield_ops, dyn_indices, dyn_class_types)
|
||||
cutlass_dsl.pack_from_irvalue(
|
||||
dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count
|
||||
)
|
||||
)
|
||||
return then_block(*flat_args)
|
||||
|
||||
@@ -400,12 +474,17 @@ def _if_execute_dynamic(
|
||||
if else_block is not None:
|
||||
|
||||
def else_builder(
|
||||
if_op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
if_op,
|
||||
_,
|
||||
dyn_yield_ops,
|
||||
pytree_def,
|
||||
mix_iter_args,
|
||||
full_write_args_count,
|
||||
):
|
||||
flat_args = [*used_args]
|
||||
flat_args = []
|
||||
flat_args.extend(
|
||||
cutlass_dsl.pack_from_irvalue(
|
||||
dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count
|
||||
)
|
||||
)
|
||||
return else_block(*flat_args)
|
||||
@@ -414,8 +493,8 @@ def _if_execute_dynamic(
|
||||
|
||||
return scf_gen.scf_execute_dynamic(
|
||||
op_type_name="if",
|
||||
used_args=used_args,
|
||||
mix_iter_args=mix_yield_args,
|
||||
full_write_args_count=full_write_args_count,
|
||||
mix_iter_arg_names=mix_yield_arg_names,
|
||||
create_op_func=create_if_op,
|
||||
region_builders=region_builders,
|
||||
@@ -425,9 +504,9 @@ def _if_execute_dynamic(
|
||||
def _while_execute_dynamic(
|
||||
while_before_block: Callable,
|
||||
while_after_block: Callable = None,
|
||||
used_args=[],
|
||||
yield_args=[],
|
||||
yield_arg_names=[],
|
||||
write_args=[],
|
||||
full_write_args_count=0,
|
||||
write_args_names=[],
|
||||
):
|
||||
"""
|
||||
Create and return an SCF WhileOp for dynamic loops.
|
||||
@@ -436,8 +515,7 @@ def _while_execute_dynamic(
|
||||
Args:
|
||||
while_before_block: Function that returns (condition, updated_values)
|
||||
while_after_block: Function that returns updated values
|
||||
used_args: Additional arguments used in the loop body
|
||||
yield_args: Values that are updated in the loop
|
||||
write_args: Values that are updated in the loop
|
||||
|
||||
See create_while_function in ast_preprocessor.py for details on the input structure.
|
||||
"""
|
||||
@@ -445,11 +523,7 @@ def _while_execute_dynamic(
|
||||
while_op_type_name = "while"
|
||||
scf_gen = ScfGenerator()
|
||||
|
||||
def create_while_op(
|
||||
dyn_yield_ops: List[ir.Value],
|
||||
dyn_indices: Dict[int, Tuple[int, int]],
|
||||
dyn_class_types: List[Any],
|
||||
):
|
||||
def create_while_op(dyn_yield_ops: List[ir.Value]):
|
||||
# Create the while operation with the types from yield_args
|
||||
result_types = [arg.type for arg in dyn_yield_ops]
|
||||
try:
|
||||
@@ -468,14 +542,19 @@ def _while_execute_dynamic(
|
||||
) from e
|
||||
|
||||
def before_block_builder(
|
||||
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
op,
|
||||
block_args,
|
||||
_,
|
||||
pytree_def,
|
||||
mix_iter_args,
|
||||
full_write_args_count,
|
||||
):
|
||||
# Build the before (condition) block
|
||||
ScfGenerator.fill_none(block_args, dyn_yield_ops)
|
||||
block_args = dyn_yield_ops
|
||||
flat_args = [*used_args]
|
||||
flat_args = []
|
||||
flat_args.extend(
|
||||
cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
|
||||
cutlass_dsl.pack_from_irvalue(
|
||||
block_args, pytree_def, mix_iter_args, full_write_args_count
|
||||
)
|
||||
)
|
||||
|
||||
log().debug("before block args: %s", flat_args)
|
||||
@@ -493,18 +572,15 @@ def _while_execute_dynamic(
|
||||
|
||||
return cond, before_results
|
||||
|
||||
def before_block_terminator(cond_and_results):
|
||||
def before_block_terminator(cond_and_results, full_write_args_count):
|
||||
# Generate a condition op instead of yield op
|
||||
cond = cond_and_results[0]
|
||||
before_result_list = ScfGenerator._normalize_region_result_to_list(
|
||||
cond_and_results[1]
|
||||
)
|
||||
ir_cond_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
|
||||
[cond], while_op_type_name
|
||||
)
|
||||
ir_cond = ir_cond_list[0]
|
||||
ir_results_list, _, _, _ = cutlass_dsl.unpack_to_irvalue(
|
||||
before_result_list, while_op_type_name
|
||||
ir_cond = as_numeric(cond).ir_value()
|
||||
ir_results_list, pytree_def = cutlass_dsl.unpack_to_irvalue(
|
||||
before_result_list, while_op_type_name, full_write_args_count
|
||||
)
|
||||
log().debug(
|
||||
"creating scf.ConditionOp with [%s], [%s]",
|
||||
@@ -514,14 +590,19 @@ def _while_execute_dynamic(
|
||||
scf.ConditionOp(ir_cond, ir_results_list)
|
||||
|
||||
def after_block_builder(
|
||||
op, block_args, used_args, dyn_yield_ops, dyn_indices, dyn_class_types
|
||||
op,
|
||||
block_args,
|
||||
_,
|
||||
pytree_def,
|
||||
mix_iter_args,
|
||||
full_write_args_count,
|
||||
):
|
||||
# Build the after (body) block
|
||||
ScfGenerator.fill_none(block_args, dyn_yield_ops)
|
||||
block_args = dyn_yield_ops
|
||||
flat_args = [*used_args]
|
||||
flat_args = []
|
||||
flat_args.extend(
|
||||
cutlass_dsl.pack_from_irvalue(block_args, dyn_indices, dyn_class_types)
|
||||
cutlass_dsl.pack_from_irvalue(
|
||||
block_args, pytree_def, mix_iter_args, full_write_args_count
|
||||
)
|
||||
)
|
||||
|
||||
log().debug("after block args: %s", flat_args)
|
||||
@@ -541,9 +622,9 @@ def _while_execute_dynamic(
|
||||
# Call the universal SCF executor with two region builders
|
||||
return scf_gen.scf_execute_dynamic(
|
||||
op_type_name=while_op_type_name,
|
||||
used_args=used_args,
|
||||
mix_iter_args=yield_args,
|
||||
mix_iter_arg_names=yield_arg_names,
|
||||
mix_iter_args=write_args,
|
||||
full_write_args_count=full_write_args_count,
|
||||
mix_iter_arg_names=write_args_names,
|
||||
create_op_func=create_while_op,
|
||||
region_builders=[before_block_builder, after_block_builder],
|
||||
block_term_op_builder={
|
||||
|
||||
763
python/CuTeDSL/cutlass_dsl/tree_utils.py
Normal file
763
python/CuTeDSL/cutlass_dsl/tree_utils.py
Normal file
@@ -0,0 +1,763 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import Callable, Any, Iterable, Iterator, NamedTuple, Union, get_origin
|
||||
import dataclasses
|
||||
import itertools as it
|
||||
from types import SimpleNamespace
|
||||
|
||||
from ..base_dsl.typing import as_numeric, Numeric, Constexpr
|
||||
from ..base_dsl._mlir_helpers.arith import ArithValue
|
||||
from ..base_dsl.common import DSLBaseError
|
||||
from .._mlir import ir
|
||||
|
||||
# =============================================================================
|
||||
# Tree Utils
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DSLTreeFlattenError(DSLBaseError):
|
||||
"""Exception raised when tree flattening fails due to unsupported types."""
|
||||
|
||||
def __init__(self, msg: str, type_str: str):
|
||||
super().__init__(msg)
|
||||
self.type_str = type_str
|
||||
|
||||
|
||||
def unzip2(pairs: Iterable[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]:
|
||||
"""Unzip a sequence of pairs into two lists."""
|
||||
lst1, lst2 = [], []
|
||||
for x1, x2 in pairs:
|
||||
lst1.append(x1)
|
||||
lst2.append(x2)
|
||||
return lst1, lst2
|
||||
|
||||
|
||||
def get_fully_qualified_class_name(x: Any) -> str:
|
||||
"""
|
||||
Get the fully qualified class name of an object.
|
||||
|
||||
Args:
|
||||
x: Any object
|
||||
|
||||
Returns:
|
||||
str: Fully qualified class name in format 'module.class_name'
|
||||
|
||||
Example:
|
||||
>>> get_fully_qualified_class_name([1, 2, 3])
|
||||
'builtins.list'
|
||||
"""
|
||||
return f"{x.__class__.__module__}.{x.__class__.__qualname__}"
|
||||
|
||||
|
||||
def is_frozen_dataclass(obj_or_cls: Any) -> bool:
|
||||
"""
|
||||
Check if an object or class is a frozen dataclass.
|
||||
|
||||
Args:
|
||||
obj_or_cls: Either a dataclass instance or class
|
||||
|
||||
Returns:
|
||||
bool: True if the object/class is a dataclass declared with frozen=True,
|
||||
False otherwise
|
||||
|
||||
Example:
|
||||
>>> from dataclasses import dataclass
|
||||
>>> @dataclass(frozen=True)
|
||||
... class Point:
|
||||
... x: int
|
||||
... y: int
|
||||
>>> is_frozen_dataclass(Point)
|
||||
True
|
||||
>>> is_frozen_dataclass(Point(1, 2))
|
||||
True
|
||||
"""
|
||||
cls = obj_or_cls if isinstance(obj_or_cls, type) else obj_or_cls.__class__
|
||||
|
||||
return (
|
||||
dataclasses.is_dataclass(cls)
|
||||
and getattr(cls, "__dataclass_params__", None) is not None
|
||||
and cls.__dataclass_params__.frozen
|
||||
)
|
||||
|
||||
|
||||
def is_dynamic_expression(x: Any) -> bool:
|
||||
"""
|
||||
Check if an object implements the DynamicExpression protocol.
|
||||
|
||||
Objects implementing this protocol must have both `__extract_mlir_values__`
|
||||
and `__new_from_mlir_values__` methods.
|
||||
|
||||
Args:
|
||||
x: Any object to check
|
||||
|
||||
Returns:
|
||||
bool: True if the object implements the DynamicExpression protocol,
|
||||
False otherwise
|
||||
"""
|
||||
return all(
|
||||
hasattr(x, attr)
|
||||
for attr in ("__extract_mlir_values__", "__new_from_mlir_values__")
|
||||
)
|
||||
|
||||
|
||||
def is_constexpr_field(field: dataclasses.Field) -> bool:
|
||||
"""
|
||||
Check if a field is a constexpr field.
|
||||
"""
|
||||
if field.type is Constexpr:
|
||||
return True
|
||||
elif get_origin(field.type) is Constexpr:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PyTreeDef
|
||||
# =============================================================================
|
||||
|
||||
class NodeType(NamedTuple):
|
||||
"""
|
||||
Represents a node in a pytree structure.
|
||||
|
||||
Attributes:
|
||||
name: String representation of the node type
|
||||
to_iterable: Function to convert node to iterable form
|
||||
from_iterable: Function to reconstruct node from iterable form
|
||||
"""
|
||||
name: str
|
||||
to_iterable: Callable
|
||||
from_iterable: Callable
|
||||
|
||||
|
||||
class PyTreeDef(NamedTuple):
|
||||
"""
|
||||
Represents the structure definition of a pytree.
|
||||
|
||||
Attributes:
|
||||
node_type: The type of this node
|
||||
node_metadata: SimpleNamespace metadata associated with this node
|
||||
child_treedefs: Tuple of child tree definitions
|
||||
"""
|
||||
node_type: NodeType
|
||||
node_metadata: SimpleNamespace
|
||||
child_treedefs: tuple["PyTreeDef", ...]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Leaf:
|
||||
"""
|
||||
Represents a leaf node in a pytree structure.
|
||||
|
||||
Attributes:
|
||||
is_numeric: Whether this leaf contains a `Numeric` value
|
||||
is_none: Whether this leaf represents None
|
||||
node_metadata: SimpleNamespace metadata associated with this leaf
|
||||
ir_type_str: String representation of the IR type
|
||||
"""
|
||||
is_numeric: bool = False
|
||||
is_none: bool = False
|
||||
node_metadata: SimpleNamespace = None
|
||||
ir_type_str: str = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Default to_iterable and from_iterable
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any]]:
|
||||
"""
|
||||
Extract non-method, non-function attributes from a dataclass instance.
|
||||
|
||||
Args:
|
||||
x: A dataclass instance
|
||||
|
||||
Returns:
|
||||
tuple: (field_names, field_values) lists
|
||||
"""
|
||||
fields = [field.name for field in dataclasses.fields(x)]
|
||||
|
||||
# If the dataclass has extra fields, raise an error
|
||||
for k in x.__dict__.keys():
|
||||
if k not in fields:
|
||||
raise DSLTreeFlattenError(
|
||||
f"`{x}` has extra field `{k}`",
|
||||
type_str=get_fully_qualified_class_name(x),
|
||||
)
|
||||
|
||||
if not fields:
|
||||
return [], []
|
||||
|
||||
# record constexpr fields
|
||||
members = []
|
||||
constexpr_fields = []
|
||||
for field in dataclasses.fields(x):
|
||||
if is_constexpr_field(field):
|
||||
constexpr_fields.append(field.name)
|
||||
fields.remove(field.name)
|
||||
v = getattr(x, field.name)
|
||||
if is_dynamic_expression(v):
|
||||
raise DSLTreeFlattenError(
|
||||
f"`{x}` has dynamic expression field `{field.name}` with a Constexpr type annotation `{field.type}`",
|
||||
type_str=get_fully_qualified_class_name(x),
|
||||
)
|
||||
else:
|
||||
members.append(getattr(x, field.name))
|
||||
|
||||
return fields, members, constexpr_fields
|
||||
|
||||
|
||||
def default_dataclass_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Convert a dataclass instance into (metadata, children) form for flattening.

    The metadata records the fully qualified type name, the dynamic field
    names, the constexpr field names, and a reference to the original object
    so that the instance can be rebuilt during unflattening.

    Args:
        x: A dataclass instance

    Returns:
        tuple: (metadata, members) where metadata is a SimpleNamespace
        describing the instance and members are the dynamic field values.
    """
    dyn_fields, dyn_values, constexpr = extract_dataclass_members(x)

    info = SimpleNamespace(
        type_str=get_fully_qualified_class_name(x),
        fields=dyn_fields,
        constexpr_fields=constexpr,
        original_obj=x,
    )
    return info, dyn_values
|
||||
|
||||
|
||||
def set_dataclass_attributes(
    instance: Any,
    fields: list[str],
    values: Iterable[Any],
    constexpr_fields: list[str],
) -> Any:
    """
    Produce a copy of a dataclass instance with the given fields replaced.

    Constexpr fields are carried over from the original instance unchanged.
    Frozen dataclasses are handled too, since `dataclasses.replace` builds a
    new instance instead of mutating in place.

    Args:
        instance: The dataclass instance to update
        fields: Names of the dynamic fields to replace
        values: Replacement values, in the same order as `fields`
        constexpr_fields: Names of fields copied from the original instance

    Returns:
        The updated instance (the original one when `fields` is empty)
    """
    if not fields:
        return instance

    replacements = {name: value for name, value in zip(fields, values)}
    # Constexpr fields keep their original values.
    replacements.update(
        (name, getattr(instance, name)) for name in constexpr_fields
    )
    return dataclasses.replace(instance, **replacements)
|
||||
|
||||
def default_dataclass_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a dataclass instance from (metadata, children) form.

    The dynamic field values in `children` are applied on top of the original
    instance recorded in the metadata; constexpr fields keep their original
    values. The metadata is updated to reference the rebuilt instance.

    Args:
        metadata: Metadata produced by `default_dataclass_to_iterable`
        children: Replacement values for the dynamic fields

    Returns:
        The reconstructed dataclass instance
    """
    rebuilt = set_dataclass_attributes(
        metadata.original_obj,
        metadata.fields,
        children,
        metadata.constexpr_fields,
    )
    # Keep the metadata pointing at the most recent reconstruction.
    metadata.original_obj = rebuilt
    return rebuilt
|
||||
|
||||
|
||||
def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Convert a dynamic expression object into (metadata, children) form.

    The MLIR values are obtained through the object's
    `__extract_mlir_values__` protocol method; the metadata flags the node
    as a dynamic expression and keeps the original object for later
    reconstruction.

    Args:
        x: An object implementing the DynamicExpression protocol

    Returns:
        tuple: (metadata, mlir_values)
    """
    info = SimpleNamespace(is_dynamic_expression=1, original_obj=x)
    return info, x.__extract_mlir_values__()
|
||||
|
||||
|
||||
def dynamic_expression_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a dynamic expression object from (metadata, children) form.

    Delegates to the original object's `__new_from_mlir_values__` protocol
    method with the materialized list of MLIR values.

    Args:
        metadata: Metadata holding the original object
        children: MLIR values to rebuild from

    Returns:
        The reconstructed dynamic expression object
    """
    values = list(children)
    return metadata.original_obj.__new_from_mlir_values__(values)
|
||||
|
||||
|
||||
def default_dict_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Convert a dict or SimpleNamespace into (metadata, children) form.

    The metadata records the key order and the original container so the
    values can be written back during unflattening.

    Args:
        x: A dict or SimpleNamespace instance

    Returns:
        tuple: (metadata, values) in the container's key order
    """
    # SimpleNamespace stores its entries in __dict__; plain dicts are used as-is.
    source = x.__dict__ if isinstance(x, SimpleNamespace) else x
    keys = list(source.keys())
    values = list(source.values())

    info = SimpleNamespace(
        type_str=get_fully_qualified_class_name(x), original_obj=x, fields=keys
    )
    return info, values
|
||||
|
||||
|
||||
def default_dict_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a dict or SimpleNamespace from (metadata, children) form.

    Writes the child values back into the original container in place,
    pairing them with the recorded key order.

    Args:
        metadata: Metadata produced by `default_dict_to_iterable`
        children: Replacement values, in the recorded key order

    Returns:
        The original container, updated in place
    """
    target = metadata.original_obj
    if isinstance(target, SimpleNamespace):
        for key, value in zip(metadata.fields, children):
            setattr(target, key, value)
    else:
        for key, value in zip(metadata.fields, children):
            target[key] = value

    return target
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Register pytree nodes
|
||||
# =============================================================================
|
||||
|
||||
# Registry mapping a Python type to the NodeType handlers used to flatten
# and unflatten its instances; populated lazily and by default registration.
_node_types: dict[type, NodeType] = {}
|
||||
|
||||
|
||||
def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable) -> NodeType:
    """
    Register handlers that let a type participate in pytree operations.

    Args:
        ty: The type to register
        to_iter: Converts an instance of `ty` into (metadata, children) form
        from_iter: Rebuilds an instance of `ty` from (metadata, children) form

    Returns:
        NodeType: The newly created and registered NodeType
    """
    node_type = NodeType(str(ty), to_iter, from_iter)
    _node_types[ty] = node_type
    return node_type
|
||||
|
||||
|
||||
def register_default_node_types() -> None:
    """Register the built-in container types (tuple, list, dict, SimpleNamespace)."""
    # Sequences record only their length; children are the elements in order.
    register_pytree_node(
        tuple,
        lambda seq: (SimpleNamespace(length=len(seq)), list(seq)),
        lambda _, xs: tuple(xs),
    )
    register_pytree_node(
        list,
        lambda seq: (SimpleNamespace(length=len(seq)), list(seq)),
        lambda _, xs: list(xs),
    )
    # Mappings share one pair of handlers that preserves key order.
    register_pytree_node(dict, default_dict_to_iterable, default_dict_from_iterable)
    register_pytree_node(
        SimpleNamespace, default_dict_to_iterable, default_dict_from_iterable
    )
|
||||
|
||||
|
||||
# Initialize default registrations
# Run at import time so built-in containers are always flattenable.
register_default_node_types()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# tree_flatten and tree_unflatten
|
||||
# =============================================================================
|
||||
|
||||
"""
|
||||
Behavior of tree_flatten and tree_unflatten, for example:
|
||||
|
||||
```python
|
||||
a = (1, 2, 3)
|
||||
b = MyClass(a=1, b =[1,2,3])
|
||||
```
|
||||
|
||||
yields the following tree:
|
||||
|
||||
```python
|
||||
tree_a = PyTreeDef(type = 'tuple',
|
||||
metadata = {length = 3},
|
||||
children = [
|
||||
Leaf(type = int),
|
||||
Leaf(type = int),
|
||||
Leaf(type = int),
|
||||
],
|
||||
)
|
||||
flattened_a = [1, 2, 3]
|
||||
tree_b = PyTreeDef(type = 'MyClass',
|
||||
metadata = {fields = ['a','b']},
|
||||
children = [
|
||||
PyTreeDef(type = `list`,
|
||||
metadata = {length = 3},
|
||||
children = [
|
||||
Leaf(type=`int`),
|
||||
Leaf(type=`int`),
|
||||
Leaf(type=`int`),
|
||||
],
|
||||
),
|
||||
Leaf(type=int),
|
||||
],
|
||||
)
|
||||
flattened_b = [1, 1, 2, 3]
|
||||
```
|
||||
|
||||
Pass the flattened values and the PyTreeDef to tree_unflatten to reconstruct the original structure.
|
||||
|
||||
``` python
|
||||
unflattened_a = tree_unflatten(tree_a, flattened_a)
|
||||
unflattened_b = tree_unflatten(tree_b, flattened_b)
|
||||
```
|
||||
|
||||
yields the following structure:
|
||||
|
||||
``` python
|
||||
unflattened_a = (1, 2, 3)
|
||||
unflattened_b = MyClass(a=1, b =[1,2,3])
|
||||
```
|
||||
|
||||
unflattened_a should be structurally identical to a, and unflattened_b should be structurally identical to b.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
    """
    Flatten a nested structure into a list of leaf values plus a tree definition.

    Recursively walks nested containers and registered node types, collecting
    leaf values in order while recording the structure in a PyTreeDef that
    tree_unflatten can use to rebuild the original object.

    Args:
        x: The nested structure to flatten

    Returns:
        tuple: (flat_values, treedef) where flat_values is the list of leaf
            values and treedef describes the tree structure

    Raises:
        DSLTreeFlattenError: If the structure contains unsupported types

    Example:
        >>> tree_flatten([1, [2, 3], 4])
        ([1, 2, 3, 4], PyTreeDef(...))
    """
    leaves, treedef = _tree_flatten(x)
    # _tree_flatten may return a lazy iterable; materialize it for callers.
    return list(leaves), treedef
|
||||
|
||||
|
||||
def get_registered_node_types_or_insert(x: Any) -> NodeType | None:
    """
    Look up the node type for an object, auto-registering it when possible.

    Resolution order:
    - an already registered type is returned as-is;
    - types implementing the DynamicExpression protocol are registered with
      the dynamic expression handlers (checked before the dataclass case);
    - dataclasses are registered with the default dataclass handlers;
    - anything else yields None.

    Args:
        x: The object whose type should be resolved

    Returns:
        NodeType or None: The (possibly newly registered) node type, or None
        when the type cannot be handled as a pytree node.
    """
    existing = _node_types.get(type(x))
    if existing:
        return existing

    if is_dynamic_expression(x):
        # DynamicExpression protocol takes precedence over the plain
        # dataclass registration.
        return register_pytree_node(
            type(x), dynamic_expression_to_iterable, dynamic_expression_from_iterable
        )
    if dataclasses.is_dataclass(x):
        return register_pytree_node(
            type(x), default_dataclass_to_iterable, default_dataclass_from_iterable
        )
    return None
|
||||
|
||||
|
||||
def create_leaf_for_value(
    x: Any,
    is_numeric: bool = False,
    is_none: bool = False,
    node_metadata: SimpleNamespace = None,
    ir_type_str: str = None,
) -> Leaf:
    """
    Build a Leaf node describing a single flattened value.

    When no explicit IR type string is given, one is derived from `x.type`
    if the value exposes that attribute; otherwise it is left as None.

    Args:
        x: The value the leaf stands for
        is_numeric: Whether the value is numeric
        is_none: Whether the leaf represents None
        node_metadata: Optional metadata to attach to the leaf
        ir_type_str: Optional explicit IR type string

    Returns:
        Leaf: The constructed leaf node
    """
    if not ir_type_str and hasattr(x, "type"):
        ir_type_str = str(x.type)
    return Leaf(
        is_numeric=is_numeric,
        is_none=is_none,
        node_metadata=node_metadata,
        ir_type_str=ir_type_str,
    )
|
||||
|
||||
|
||||
def _tree_flatten(x: Any) -> tuple[Iterable[Any], PyTreeDef | Leaf]:
    """
    Internal function to flatten a tree structure.

    This is the core implementation of tree flattening that handles different
    types of objects including None, ArithValue, ir.Value, Numeric types,
    and registered pytree node types. Case order matters: the guarded
    ArithValue case must precede the plain ArithValue case.

    Args:
        x: The object to flatten

    Returns:
        tuple: (flattened_values, treedef) where flattened_values is an iterable
            of leaf values and treedef is the tree structure

    Raises:
        DSLTreeFlattenError: If the object type is not supported
    """
    match x:
        # None is a structural leaf that carries no flattened values.
        case None:
            return [], create_leaf_for_value(x, is_none=True)

        # ArithValue implementing the DynamicExpression protocol: flatten via
        # its extracted MLIR values and remember the original object so it can
        # be rebuilt through __new_from_mlir_values__ during unflattening.
        case ArithValue() if is_dynamic_expression(x):
            v = x.__extract_mlir_values__()
            return v, create_leaf_for_value(
                x,
                node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x),
                ir_type_str=str(v[0].type),
            )

        # Plain ArithValue: a single numeric leaf.
        case ArithValue():
            return [x], create_leaf_for_value(x, is_numeric=True)

        # Raw MLIR value: a single opaque (non-numeric) leaf.
        case ir.Value():
            return [x], create_leaf_for_value(x)

        # Numeric: flattened through its MLIR values like a dynamic
        # expression, keeping the original object for reconstruction.
        case Numeric():
            v = x.__extract_mlir_values__()
            return v, create_leaf_for_value(
                x,
                node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x),
                ir_type_str=str(v[0].type),
            )

        case _:
            # Containers/dataclasses: use the registered (or auto-registered)
            # node type handlers and recurse into the children.
            node_type = get_registered_node_types_or_insert(x)
            if node_type:
                node_metadata, children = node_type.to_iterable(x)
                children_flat, child_trees = unzip2(map(_tree_flatten, children))
                # Lazily concatenate the children's flattened value streams.
                flattened = it.chain.from_iterable(children_flat)
                return flattened, PyTreeDef(
                    node_type, node_metadata, tuple(child_trees)
                )

            # Try to convert to numeric
            try:
                nval = as_numeric(x).ir_value()
                return [nval], create_leaf_for_value(nval, is_numeric=True)
            except Exception:
                # Not convertible: the object type is unsupported.
                raise DSLTreeFlattenError(
                    "Flatten Error", get_fully_qualified_class_name(x)
                )
|
||||
|
||||
|
||||
def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
    """
    Rebuild a nested structure from flat values and a tree definition.

    Inverse of tree_flatten: consumes the flat leaf values in order, guided
    by the structure recorded in `treedef`.

    Args:
        treedef: The tree structure definition from tree_flatten
        xs: Flat leaf values, in flattening order

    Returns:
        The reconstructed nested structure

    Example:
        >>> flat_values, treedef = tree_flatten([1, [2, 3], 4])
        >>> tree_unflatten(treedef, flat_values)
        [1, [2, 3], 4]
    """
    value_stream = iter(xs)
    return _tree_unflatten(treedef, value_stream)
|
||||
|
||||
|
||||
def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any:
    """
    Internal function to reconstruct a tree structure.

    This is the core implementation of tree unflattening that handles
    different types of tree definitions including Leaf nodes and PyTreeDef
    nodes. Values are consumed from `xs` in the same order _tree_flatten
    produced them; case order matters (guarded cases come first).

    Args:
        treedef: The tree structure definition
        xs: Iterator of flat values to reconstruct from

    Returns:
        The reconstructed object
    """
    match treedef:
        # None leaf: consumes no values from the stream.
        case Leaf(is_none=True):
            return None

        # Dynamic expression leaf: rebuild the original object from a single
        # MLIR value via its __new_from_mlir_values__ protocol method.
        case Leaf(
            node_metadata=metadata
        ) if metadata and metadata.is_dynamic_expression:
            return metadata.original_obj.__new_from_mlir_values__([next(xs)])

        # Numeric leaf: wrap the raw value back into a Numeric.
        case Leaf(is_numeric=True):
            return as_numeric(next(xs))

        # Opaque leaf: pass the next value through unchanged.
        case Leaf():
            return next(xs)

        # Internal node: rebuild the children lazily (sharing `xs`), then let
        # the node type's from_iterable reconstruct the container.
        case PyTreeDef():
            children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
            return treedef.node_type.from_iterable(treedef.node_metadata, children)
|
||||
|
||||
|
||||
def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool:
    """
    Recursively compare two tree definitions for structural equality.

    Two leaves match when they agree on None-ness and IR type string. Two
    internal nodes match when they share node type and field metadata and
    have pairwise-equal children. Mixed leaf/node pairs never match.

    Args:
        lhs: Left tree definition (PyTreeDef or Leaf)
        rhs: Right tree definition (PyTreeDef or Leaf)

    Returns:
        bool: True if the trees are structurally equal, False otherwise
    """
    if isinstance(lhs, Leaf) and isinstance(rhs, Leaf):
        return lhs.is_none == rhs.is_none and lhs.ir_type_str == rhs.ir_type_str

    if isinstance(lhs, PyTreeDef) and isinstance(rhs, PyTreeDef):
        if lhs.node_type != rhs.node_type:
            return False
        # Metadata may lack field information (e.g. sequence nodes), so
        # fall back to empty lists on either side.
        lhs_fields = getattr(lhs.node_metadata, "fields", [])
        rhs_fields = getattr(rhs.node_metadata, "fields", [])
        if lhs_fields != rhs_fields:
            return False
        lhs_constexpr = getattr(lhs.node_metadata, "constexpr_fields", [])
        rhs_constexpr = getattr(rhs.node_metadata, "constexpr_fields", [])
        if lhs_constexpr != rhs_constexpr:
            return False
        if len(lhs.child_treedefs) != len(rhs.child_treedefs):
            return False
        return all(
            _check_tree_equal(left, right)
            for left, right in zip(lhs.child_treedefs, rhs.child_treedefs)
        )

    return False
|
||||
|
||||
|
||||
def check_tree_equal(lhs: PyTreeDef, rhs: PyTreeDef) -> int:
    """
    Compare two tree definitions child by child.

    Walks the children of both definitions in parallel and reports where
    they first diverge structurally.

    Args:
        lhs: Left tree definition
        rhs: Right tree definition

    Returns:
        int: Index of the first differing child, or -1 if all children match

    Example:
        >>> treedef1 = tree_flatten([1, [2, 3]])[1]
        >>> treedef2 = tree_flatten([1, [2, 4]])[1]
        >>> check_tree_equal(treedef1, treedef2)
        1  # The second child differs
    """
    assert len(lhs.child_treedefs) == len(rhs.child_treedefs)

    for index, (left, right) in enumerate(
        zip(lhs.child_treedefs, rhs.child_treedefs)
    ):
        if not _check_tree_equal(left, right):
            return index
    return -1
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl==4.1.0
|
||||
nvidia-cutlass-dsl==4.2.0
|
||||
|
||||
Reference in New Issue
Block a user