# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

"""
This module provides helper functions that are generated by the preprocessor.
The preprocessor reads through Python's AST and changes the input code.
"""

from typing import Callable, Iterator, Optional, overload

from .utils.logger import log
from .common import *

from ._mlir_helpers.arith import ArithValue


def _args_or_empty(args):
    """Return ``args`` unchanged, or a fresh empty list when it is ``None``.

    Used to normalise optional argument lists so that no mutable default
    object is shared between calls.
    """
    return [] if args is None else args


class Executor:
    """
    The Executor class handles dynamic and compile-time (constexpr) execution
    of "for" loops, "while" loops and "if-else-elif" statements.

    Methods:
        set_functions: Assigns the functions for checking loop bounds and
            conditional evaluation.

        for_dynamic: Generates MLIR for OP
        for_constexpr: Executes a for loop at JIT compile-time
        for_execute: Decides whether to execute the loop at compile-time or
            generate MLIR for OP based on the provided bounds.

        if_dynamic: Generates MLIR if OP
        if_constexpr: Executes a if at JIT compile-time by python interpreter
        if_execute: Decides whether to execute the if statement at
            compile-time or generate MLIR if OP based on the predicate.

        while_dynamic: Delegates a while loop to the registered dynamic callback
        while_constexpr: Executes a while loop with the python interpreter
        while_execute: Runs the loop with the python interpreter when
            explicitly requested, otherwise delegates to `while_dynamic`.
    """

    def __init__(self):
        # All callbacks stay None until `set_functions` is called; the
        # `*_execute` entry points assert that they have been registered.
        self._is_dynamic_expression = None
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None

    def set_functions(
        self,
        is_dynamic_expression: Callable,
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
    ):
        """Register the callbacks used to classify values and to emit dynamic ops.

        :param is_dynamic_expression: Predicate called with a single value; a
            truthy result marks the value as dynamic (not constexpr).
        :param loop_execute_range_dynamic: Callback invoked by `for_dynamic`.
        :param if_dynamic: Callback invoked by `if_dynamic`.
        :param while_dynamic: Callback invoked by `while_dynamic`.
        """
        self._is_dynamic_expression = is_dynamic_expression
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic

    @staticmethod
    def convert_to_list(x):
        """This function is used to convert x to a list.
        If x is None, return an empty list.
        If x is not a list, return a list containing x.
        Otherwise, return x itself.
        """
        if x is None:
            return []
        if not isinstance(x, list):
            return [x]
        return x

    @staticmethod
    def converge_ret_val(res):
        """This function is used to converge res (the return value) of the function.
        If res is None, return None.
        If res is a list and has only one element, return the element.
        Otherwise, return res itself.
        """
        if res is None:
            return res
        elif isinstance(res, list) and len(res) == 1:
            return res[0]
        return res

    def for_dynamic(
        self,
        func: Callable,
        start,
        stop,
        step,
        used_args: list,
        iter_args: list,
        iter_arg_names: list,
        unroll: int = -1,
        unroll_full: bool = False,
    ):
        """Forward a for loop to the registered dynamic-range callback.

        All arguments are passed through unchanged to
        `_loop_execute_range_dynamic`, whose result is returned.

        The `unroll` / `unroll_full` defaults mirror `for_execute`; they were
        previously the type objects `bool` / `int`, which were forwarded as
        values whenever the caller omitted them.
        """
        log().info("start [%s] stop [%s] step [%s]", start, stop, step)
        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
        )

    @staticmethod
    def for_constexpr(
        func: Callable,
        start: int,
        stop: int,
        step: int,
        used_args: list,
        iter_args: list,
    ):
        """Run the loop body with the python interpreter over `range(start, stop, step)`.

        `func` is called as `func(i, *used_args, *loop_results)`, where
        `loop_results` starts as `iter_args` and is replaced by the
        (list-normalised) return value of each iteration.

        :return: The final loop results passed through `converge_ret_val`.
        """
        log().info("start [%s] stop [%s] step [%s]", start, stop, step)
        loop_results = iter_args
        log().debug("iter_args [%s]", iter_args)
        for i in range(start, stop, step):
            log().debug("i [%s] iter_args [%s]", i, iter_args)
            loop_results = func(i, *used_args, *loop_results)
            log().debug("loop_results [%s]", loop_results)
            # The body may return None, a single value or a list; the next
            # iteration always needs a list to unpack.
            loop_results = Executor.convert_to_list(loop_results)

        log().debug("done loop_results [%s]", loop_results)
        return Executor.converge_ret_val(loop_results)

    def for_execute(
        self,
        func,
        start,
        stop,
        step,
        used_args: Optional[list] = None,
        iter_args: Optional[list] = None,
        iter_arg_names: Optional[list] = None,
        unroll=-1,
        unroll_full=False,
        is_range_constexpr=None,
    ):
        """Dispatch a for loop to `for_constexpr` or `for_dynamic`.

        :param is_range_constexpr: ``None`` infers the mode from the bounds
            (constexpr only when none of start/stop/step is dynamic); a truthy
            value forces constexpr execution; any other value forces the
            dynamic path.
        :raises DSLRuntimeError: If constexpr execution is forced but one of
            the bounds is a dynamic expression.
        """
        assert (
            self._loop_execute_range_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."
        used_args = _args_or_empty(used_args)
        iter_args = _args_or_empty(iter_args)
        iter_arg_names = _args_or_empty(iter_arg_names)
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)

        any_dynamic_expression = (
            self._is_dynamic_expression(start)
            or self._is_dynamic_expression(stop)
            or self._is_dynamic_expression(step)
        )

        # No explicit request: run at compile time only when every bound is
        # a compile-time constant.
        if is_range_constexpr is None:
            is_range_constexpr = not any_dynamic_expression

        if is_range_constexpr:
            # Ensure bounds are compile-time constants for constexpr execution
            if any_dynamic_expression:
                raise DSLRuntimeError(
                    "Loop bounds must be constexpr (compile-time constants)"
                )
            return self.for_constexpr(func, start, stop, step, used_args, iter_args)

        # MLIR generation
        return self.for_dynamic(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
        )

    def if_dynamic(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
        yield_arg_names: Optional[list] = None,
    ):
        """Forward an if statement to the registered `_if_dynamic` callback."""
        return self._if_dynamic(
            pred,
            then_block,
            else_block,
            _args_or_empty(used_args),
            _args_or_empty(yield_args),
            _args_or_empty(yield_arg_names),
        )

    @staticmethod
    def if_constexpr(
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
    ):
        """Evaluate `pred` with the python interpreter and run the matching block.

        The selected block is called as `block(*used_args, *yield_args)`.

        :return: The block result passed through `converge_ret_val`, or
            ``None`` when `pred` is falsy and no `else_block` is given.
        """
        used_args = _args_or_empty(used_args)
        yield_args = _args_or_empty(yield_args)
        if pred:
            log().debug(" running then block [%s]", yield_args)
            res = then_block(*used_args, *yield_args)
            log().debug("result [%s]", res)
            return Executor.converge_ret_val(res)
        elif else_block is not None:
            log().debug("running else [%s]", yield_args)
            res = else_block(*used_args, *yield_args)
            log().debug("result [%s]", res)
            return Executor.converge_ret_val(res)
        return None

    def if_execute(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
        yield_arg_names: Optional[list] = None,
        if_constexpr=None,
    ):
        """Dispatch an if statement to `if_constexpr` or `if_dynamic`.

        :param if_constexpr: ``None`` infers the mode from `pred` (constexpr
            when `pred` is not a dynamic expression); a truthy value forces
            constexpr execution; any other value forces the dynamic path.
        :raises DSLRuntimeError: If constexpr execution is forced but `pred`
            is a dynamic expression.
        """
        assert (
            self._if_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."
        used_args = _args_or_empty(used_args)
        yield_args = _args_or_empty(yield_args)
        yield_arg_names = _args_or_empty(yield_arg_names)

        is_if_constexpr = not self._is_dynamic_expression(pred)
        if if_constexpr is None:
            if_constexpr = is_if_constexpr

        if if_constexpr:
            # Ensure the predicate is a compile-time constant for constexpr
            # execution
            if not is_if_constexpr:
                raise DSLRuntimeError(
                    "If predicate must be constexpr (compile-time constants)"
                )
            return self.if_constexpr(
                pred, then_block, else_block, used_args, yield_args
            )

        # MLIR generation
        return self.if_dynamic(
            pred, then_block, else_block, used_args, yield_args, yield_arg_names
        )

    def while_dynamic(
        self,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
        yield_arg_names: Optional[list] = None,
    ):
        """Forward a while loop to the registered `_while_dynamic` callback."""
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            _args_or_empty(used_args),
            _args_or_empty(yield_args),
            _args_or_empty(yield_arg_names),
        )

    @staticmethod
    def while_constexpr(
        while_before_block,
        while_after_block,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
    ):
        """Run a while loop with the python interpreter.

        `while_before_block(*used_args, *values)` must return a
        `(cond, results)` pair; while `cond` is truthy,
        `while_after_block(*used_args, *results)` produces the values fed to
        the next `while_before_block` call.

        :return: The last `results` passed through `converge_ret_val`.
        """
        used_args = _args_or_empty(used_args)
        yield_args = _args_or_empty(yield_args)
        log().debug(
            "while_constexpr begin %s", while_before_block.__qualname__
        )
        cond, loop_results = while_before_block(*used_args, *yield_args)
        while cond:
            loop_results = Executor.convert_to_list(loop_results)
            log().debug(
                "calling while_after [%s], [%s]",
                used_args,
                loop_results,
            )
            loop_results = while_after_block(*used_args, *loop_results)
            log().debug(
                "while after [%s]", loop_results
            )
            loop_results = Executor.convert_to_list(loop_results)
            log().debug(
                "calling while_before [%s], [%s]",
                used_args,
                loop_results,
            )
            cond, loop_results = while_before_block(*used_args, *loop_results)
            log().debug(
                "while_before cond, results [%s], [%s]",
                cond,
                loop_results,
            )
        log().debug(
            "while_constexpr results %s", loop_results
        )
        return Executor.converge_ret_val(loop_results)

    def while_execute(
        self,
        pred,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args: Optional[list] = None,
        yield_args: Optional[list] = None,
        yield_arg_names: Optional[list] = None,
        while_constexpr=None,
    ):
        """Dispatch a while loop to `while_constexpr` or `while_dynamic`.

        Unlike `for_execute` / `if_execute`, the mode is not inferred here:
        only a truthy `while_constexpr` selects interpreter execution, and
        ``None`` takes the dynamic path.

        :raises DSLRuntimeError: If `while_constexpr` is truthy but `pred` is
            a dynamic expression.
        """
        assert (
            self._while_dynamic and self._is_dynamic_expression
        ), "Functions must be set before execution."
        used_args = _args_or_empty(used_args)
        yield_args = _args_or_empty(yield_args)
        yield_arg_names = _args_or_empty(yield_arg_names)

        is_while_constexpr = not self._is_dynamic_expression(pred)

        # Ensure the predicate is a compile-time constant for constexpr
        # execution
        if while_constexpr:
            if not is_while_constexpr:
                raise DSLRuntimeError(
                    "While predicate must be constexpr (compile-time constants)"
                )
            return self.while_constexpr(
                while_before_block, while_after_block, used_args, yield_args
            )

        # MLIR generation
        return self.while_dynamic(
            while_before_block,
            while_after_block,
            used_args,
            yield_args,
            yield_arg_names,
        )


# =============================================================================
# Decorator
# =============================================================================

# Module-wide Executor instance used by every helper below.
executor = Executor()


def loop_selector(
    start,
    stop,
    step,
    used_args: Optional[list] = None,
    iter_args: Optional[list] = None,
    iter_arg_names: Optional[list] = None,
    unroll=-1,
    unroll_full=False,
    constexpr=None,
):
    """Build a decorator that runs the decorated loop body via `executor.for_execute`.

    `Integer` bounds are first replaced by their `ir_value()`.

    :return: A function taking the loop body `func` and returning the result
        of `executor.for_execute(func, ...)`.
    """
    used_args = _args_or_empty(used_args)
    iter_args = _args_or_empty(iter_args)
    iter_arg_names = _args_or_empty(iter_arg_names)
    log().info(
        "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]",
        start,
        stop,
        step,
        used_args,
        iter_args,
        unroll,
        unroll_full,
        constexpr,
    )
    from .typing import Integer, Numeric

    def _maybe_upcast(value):
        if isinstance(value, Integer):
            value = value.ir_value()

        return value

    start = _maybe_upcast(start)
    stop = _maybe_upcast(stop)
    step = _maybe_upcast(step)

    def ir_loop(func):
        return executor.for_execute(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
            constexpr,
        )

    return ir_loop


def if_selector(pred, used_args: Optional[list] = None, yield_args: Optional[list] = None):
    """Build a decorator that calls `func(pred, *used_args, *yield_args)`.

    A `Numeric` predicate is replaced by its `.value` first.
    """
    used_args = _args_or_empty(used_args)
    yield_args = _args_or_empty(yield_args)
    log().info("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)
    # Handle Numeric types here?

    from .typing import Numeric

    if isinstance(pred, Numeric):
        pred = pred.value

    def ir_loop(func):
        return func(pred, *used_args, *yield_args)

    return ir_loop


def while_selector(pred, used_args: Optional[list] = None, yield_args: Optional[list] = None):
    """Build a decorator that calls `func(pred, *used_args, *yield_args)`."""
    used_args = _args_or_empty(used_args)
    yield_args = _args_or_empty(yield_args)

    def ir_while_loop(func):
        return func(pred, *used_args, *yield_args)

    return ir_while_loop


def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
    used_args: Optional[list] = None,
    yield_args: Optional[list] = None,
    yield_arg_names: Optional[list] = None,
    constexpr=None,
):
    """Module-level entry point forwarding to `executor.while_execute`."""
    return executor.while_execute(
        pred,
        while_before_block,
        while_after_block,
        _args_or_empty(used_args),
        _args_or_empty(yield_args),
        _args_or_empty(yield_arg_names),
        constexpr,
    )


def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
    used_args: Optional[list] = None,
    yield_args: Optional[list] = None,
    yield_arg_names: Optional[list] = None,
    constexpr=None,
):
    """Module-level entry point forwarding to `executor.if_execute`."""
    return executor.if_execute(
        pred,
        then_block,
        else_block,
        _args_or_empty(used_args),
        _args_or_empty(yield_args),
        _args_or_empty(yield_arg_names),
        constexpr,
    )


# =============================================================================
# Range
# =============================================================================


class range_dynamic:
    """Marker type; constructing it at runtime always raises `DSLRuntimeError`."""

    @overload
    def __new__(cls, stop, unroll=0, unroll_full=False):
        pass

    @overload
    def __new__(cls, start, stop, step, unroll=0, unroll_full=False):
        pass

    def __new__(cls, *args, **kwargs):
        raise DSLRuntimeError("range_dynamic should be always preprocessed to IR")


class range_constexpr:
    """Iterable over a compile-time range given as (stop), (start, stop) or
    (start, stop, step).

    :raises DSLRuntimeError: If the argument count is not 1-3 or if any
        argument is a dynamic expression.
    """

    def __init__(self, *args):
        if len(args) == 1:
            self.start = 0
            self.stop = args[0]
            self.step = 1
        elif len(args) == 2:
            self.start, self.stop = args
            self.step = 1
        elif len(args) == 3:
            self.start, self.stop, self.step = args
        else:
            raise DSLRuntimeError(
                "range_constexpr supports up to 3 arguments (start, stop, step)"
            )
        # Ensure the arguments are compile-time constants (if required)
        for arg_name, arg_value in [
            ("step", self.step),
            ("start", self.start),
            ("stop", self.stop),
        ]:
            if executor._is_dynamic_expression(arg_value):
                raise DSLRuntimeError(
                    f"`range_constexpr` requires `constexpr` (non-IR Values) for all arguments, "
                    f"but `{arg_name}` is not. If the arguments are dynamic, use `range`; the DSL "
                    f"will handle them during runtime. ",
                    suggestion="Use `range` instead of `range_constexpr`.",
                )

    def __iter__(self) -> Iterator[int]:
        """Yield values from `start` towards `stop` (exclusive) in `step` increments.

        A negative step counts down, matching the builtin `range` that
        `Executor.for_constexpr` uses.

        :raises DSLRuntimeError: If `step` is zero, which could never reach
            `stop`.
        """
        if self.step == 0:
            raise DSLRuntimeError("`range_constexpr` step must not be zero")

        current = self.start
        if self.step > 0:
            while current < self.stop:
                yield current
                current += self.step
        else:
            while current > self.stop:
                yield current
                current += self.step


# =============================================================================
# If expressions
# =============================================================================


def const_expr(expression):
    """Return `expression` unchanged after checking it is not dynamic.

    :raises DSLRuntimeError: If `expression` is a dynamic expression.
    """
    if executor._is_dynamic_expression(expression):
        raise DSLRuntimeError(
            f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
            context={
                "const_expr": "Accepts only constexpr (compile-time constant)",
                "If your expression depends on dynamic values": "Avoid marking it as `const_expr()`",
                "If the expression could be either dynamic or constexpr": "Omit explicit `const_expr()` marker; the DSL will infer the correct handling automatically",
            },
        )
    return expression


def dynamic_expr(expression):
    """Marker function; calling it at runtime always raises `DSLRuntimeError`."""
    raise DSLRuntimeError("dynamic_expr should be always preprocessed to IR")


# =============================================================================
# Assertion & casting
# =============================================================================


def assert_executor(test, msg=None):
    """Execute `assert test, msg` for a constexpr `test`.

    A dynamic `Numeric` is first converted with `test.to(bool)`; if that
    conversion fails, or `test` is any other dynamic expression, the
    assertion cannot be evaluated here.

    :raises AssertionError: If the (converted) `test` is falsy.
    :raises DSLRuntimeError: If `test` is a dynamic expression that cannot be
        converted to bool.
    """
    from .typing import Numeric

    fail = False
    # Implicit convert dynamic expression to bool is not allowed
    # So here explicitly do a None check
    if test is not None and executor._is_dynamic_expression(test):
        if isinstance(test, Numeric):
            try:
                test = test.to(bool)
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt and
                # SystemExit are no longer swallowed.
                fail = True
        else:
            fail = True

    if not fail:
        assert test, msg
    else:
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            suggestion = "Please replace with runtime assert."
        )


def bool_cast(value):
    """Return `bool(value)` for a constexpr `value`.

    :raises DSLRuntimeError: If `value` is a dynamic expression.
    """
    if executor._is_dynamic_expression(value):
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            suggestion = "Please explicitly convert to boolean with expressions like comparision."
        )
    return bool(value)