v4.1 release update v2. (#2481)

This commit is contained in:
Junkai-Wu
2025-07-22 10:03:55 +08:00
committed by GitHub
parent 9baa06dd57
commit fd6cfe1ed0
179 changed files with 7878 additions and 1286 deletions

View File

@@ -300,6 +300,21 @@ def if_executor(
class range:
"""
A range-like object for dynamic loop iteration in the DSL.
This class provides a range interface similar to Python's built-in range,
but is designed to be preprocessed into constructs for dynamic
loop execution.
The class supports both single-argument (stop) and three-argument
(start, stop, step) constructors with additional parameters for loop
optimization:
- unroll: Number of iterations to unroll (0 or 1 = no unrolling)
- unroll_full: Whether to fully unroll the loop
- pipelining: Compiler generated pipeline configuration
"""
@overload
def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None):
pass
@@ -460,7 +475,31 @@ def range_value_check(*args):
Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
"""
try:
return tuple(arg.__index__() for arg in args)
args = tuple(arg.__index__() for arg in args)
# Compute range size and warn if it's too large
start = 0
end = 0
step = 1
if len(args) == 1:
end = args[0]
elif len(args) == 2:
start = args[0]
end = args[1]
elif len(args) == 3:
start = args[0]
end = args[1]
step = args[2]
range_length = (abs(end - start) - 1) // abs(step) + 1
if range_length >= 64:
warnings.warn(
f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
category=UserWarning,
stacklevel=2,
)
return (start, end, step)
except:
raise DSLRuntimeError(
"`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
@@ -477,8 +516,8 @@ def range_perf_warning(filename, lineno, *args):
if not has_dynamic_expr:
warnings.warn_explicit(
(
"The loop was previously unrolled in Python, but now it may not unroll in IR. This may cause performance regression."
"If you want to unroll the loop in Python, please use `range_constexpr` instead of `range`."
"This loop is no longer unrolled and may cause performance regression. "
"Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
),
category=UserWarning,
filename=filename,

View File

@@ -102,6 +102,8 @@ class ScopeManager:
return cls([])
def add_to_scope(self, name: str) -> None:
if name == "_":
return
self.scopes[-1].add(name)
def get_active_symbols(self) -> List[Set[str]]:
@@ -361,13 +363,13 @@ class DSLPreprocessor(ast.NodeTransformer):
isinstance(func, ast.Name)
and func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS
):
return func.id, True
return func.id, True, len(iter_node.keywords) != 0
if (
isinstance(func, ast.Attribute)
and func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS
):
return func.attr, False
return None, None
return func.attr, False, len(iter_node.keywords) != 0
return None, None, None
def transform(self, original_function, exec_globals):
"""
@@ -378,6 +380,7 @@ class DSLPreprocessor(ast.NodeTransformer):
transformed_tree = self.transform_function(
original_function.__name__, original_function
)
self.function_globals = None
unified_tree = ast.Module(body=transformed_tree, type_ignores=[])
unified_tree = ast.fix_missing_locations(unified_tree)
@@ -731,7 +734,7 @@ class DSLPreprocessor(ast.NodeTransformer):
self.scope_manager.add_to_scope(node.target.id)
# For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop.
range_kind, is_builtin_range = self._get_range_kind(node.iter)
range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter)
if range_kind == "range_constexpr" or range_kind == None:
self.generic_visit(node)
if range_kind == "range_constexpr":
@@ -752,7 +755,7 @@ class DSLPreprocessor(ast.NodeTransformer):
warnings.simplefilter("default", DeprecationWarning) # reset filter
warning_call = None
if range_kind == "range" and is_builtin_range:
if range_kind == "range" and is_builtin_range and not has_keyword:
# Warn about possible performance regression due to behavior change
warning_call = ast.Expr(
ast.Call(
@@ -1109,6 +1112,12 @@ class DSLPreprocessor(ast.NodeTransformer):
self.generic_visit(node)
return node
def visit_Name(self, node):
self.generic_visit(node)
if node.id == "_" and isinstance(node.ctx, ast.Load):
raise DSLAstPreprocessorError("Read '_' is not allowed")
return node
def check_decorator(self, node: ast.AST) -> bool:
"""
Check if the function has the correct decorator for preprocessing.

View File

@@ -19,7 +19,9 @@ from typing import Sequence, Optional, Tuple
import os
import sys
import inspect
import argparse
from .common import DSLRuntimeError
from .utils.logger import log
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
@@ -182,7 +184,67 @@ class Compiler:
return self.jit(module, opt_level, shared_libs)
class CompileOptions:
def __init__(self, options: str = ""):
"""
This class encapsulates all compilation options relevant to function compilation.
It provides a convenient way to manage and pass compilation options,
particularly for controlling compilation settings.
By centralizing these options, it ensures consistent and flexible configuration of
compilation parameters such as optimization level, debugging control, etc.
:param options: The options for the function. Will be parsed by argparse.
:type options: str
"""
if not isinstance(options, str):
raise DSLRuntimeError(
f"Invalid compilation `options`: {options}, it should be a string"
)
self._parser = argparse.ArgumentParser()
self._parser.add_argument("--opt-level", nargs="?", type=int, default=3)
self._parser.add_argument(
"--enable-device-assertions", action="store_true", default=False
)
try:
self._options = self._parser.parse_args(options.split())
except SystemExit as e:
# catch argparse error and raise as DSLRuntimeError
raise DSLRuntimeError(
f"Invalid compile options: '{options}'. Please check the option values and format."
)
log().info("`cute.compile` CompileOptions: options=" + options)
def to_str(self):
"""
Generate a string representation of all compilation options
which will be used in pipeline options.
"""
option_strings = []
for key, value in vars(self._options).items():
hyphen_key = key.replace("_", "-")
if isinstance(value, bool):
formatted_value = "true" if value else "false"
else:
formatted_value = str(value)
option_strings.append(f"{hyphen_key}={formatted_value}")
return " ".join(option_strings)
def compile(func, *args, **kwargs):
"""
This function is used to compile a `cute.jit` decorated function.
It will process the compile options and input parameters, do explicit compilation and return the jit executor.
:param func: The function to compile. It can be a regular function, a method or a class instance.
:param args: The arguments to pass to the function.
:param kwargs: The keyword arguments to pass to the function. It can contain `options` like
`opt_level` to control the compilation flags.
:return: The jit executor.
:raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable.
"""
if func is None:
raise DSLRuntimeError("Function is not set or invalid.")
@@ -217,5 +279,8 @@ def compile(func, *args, **kwargs):
if not hasattr(func, "_dsl_object"):
raise DSLRuntimeError("Function is not decorated with jit decorator.")
# process compile options, extract the options and remove them from the kwargs
options = kwargs.pop("options", "")
func._dsl_object.compile_options = CompileOptions(options)
fcn_ptr = func._dsl_object._preprocess_and_execute(func)
return func._dsl_object._func(fcn_ptr, *args, **kwargs)

View File

@@ -38,6 +38,7 @@ import warnings
from . import typing as t
from .env_manager import EnvironmentVarManager
from .compiler import CompileOptions
# =============================================================================
# CUDA Python
@@ -232,6 +233,50 @@ def new_from_mlir_values(obj, values):
return obj
class DSLCallable:
"""
Wrapper class for a callable object used within the DSL.
DSLCallable is designed to wrap a function and provide additional
introspection utilities such as retrieving the argument specification
and signature. It ensures that the wrapped function can only be called
once, after which the reference to the function is cleared to prevent
further invocations. This is useful in scenarios where a function should
only be executed a single time within the DSL's execution model.
Attributes:
func (callable): The function to be wrapped and managed.
Methods:
__call__(*args, **kwargs): Calls the wrapped function and clears it.
get_arg_spec(): Returns the argument specification of the function.
get_signature(): Returns the signature of the function.
"""
def __init__(self, func):
self.func = func
def __call__(self, *args, **kwargs):
ret = self.__func__(*args, **kwargs)
self.func = None
return ret
@property
def __func__(self):
assert self.func is not None, "DSLCallable is already called"
return self.func
@property
def __name__(self):
return self.__func__.__name__
def get_arg_spec(self):
return inspect.getfullargspec(self.__func__)
def get_signature(self):
return inspect.signature(self.__func__)
class BaseDSL:
gpu_module = None
@@ -306,6 +351,8 @@ class BaseDSL:
self.kernel_symbols = []
# used to generate unique name for gpu.launch
self.launch_inner_count = 0
# initialize default compile options
self.compile_options = CompileOptions()
if preprocess:
self.preprocessor = DSLPreprocessor()
@@ -392,26 +439,24 @@ class BaseDSL:
if hasattr(func, "_transformed_ast"):
# If the function ptr is already materialized, use the existing one
func._dsl_object.frame = func._decorator_frame
if func._transformed_ast is None:
func._transformed_ast = func._dsl_object.run_preprocessor(func)
if func._transformed_ast is None:
del func._decorator_frame
del func._transformed_ast
func._dsl_object.frame = None
return func
fcn_ptr = func._dsl_object.get_function_ptr(func, func._transformed_ast)
fcn_ptr = func._dsl_object.get_function_ptr(func)
# If the function is decorated, de-decorate it
fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
return fcn_ptr
func._dsl_object.frame = None
return DSLCallable(fcn_ptr)
return func
def jit_runner(self, frame, executor, *dargs, **dkwargs):
def jit_runner(self, executor, frame, *dargs, **dkwargs):
"""
Decorator to mark a function for JIT compilation.
"""
# Set the frame, that can be used AST preprocessor
self.frame = frame
log().info("jit_runner")
def jit_runner_decorator(func):
@@ -444,7 +489,7 @@ class BaseDSL:
frame = inspect.currentframe().f_back
# Instantiate the DSL Class
main_dsl = cls._get_dsl()
return main_dsl.jit_runner(frame, main_dsl._func, *dargs, **dkwargs)
return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs)
@classmethod
def kernel(cls, *dargs, **dkwargs):
@@ -454,7 +499,7 @@ class BaseDSL:
frame = inspect.currentframe().f_back
# Instantiate the DSL Class
main_dsl = cls._get_dsl()
return main_dsl.jit_runner(frame, main_dsl._kernel_helper, *dargs, **dkwargs)
return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs)
@abstractmethod
def _kernel_helper(self, func, *args, **kwargs):
@@ -627,6 +672,12 @@ class BaseDSL:
pass
@abstractmethod
def _get_module_globals(self):
"""
Get the module's globals.
"""
pass
def _get_globals(self):
"""
Combines global and local variables from the current context and the
@@ -639,7 +690,11 @@ class BaseDSL:
AST preprocessor generates a new python code, so the resulting globals
dictionary is used to execute the python code.
"""
pass
all_globals = self._get_module_globals().copy()
if self.frame:
all_globals.update(self.frame.f_globals)
all_globals.update(self.frame.f_locals)
return all_globals
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
return isinstance(
@@ -881,20 +936,15 @@ class BaseDSL:
Get python location information and generate MLIR location
"""
frame = self.frame
if frame is None:
print("Frame is None")
if self.frame is None:
log().debug("Frame is None")
return None
file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0)
file_loc = ir.Location.file(
self.frame.f_code.co_filename, self.frame.f_lineno, 0
)
def print_all_frames():
for i, frame in enumerate(inspect.stack()):
print(
f"Frame {i}: {frame.function} in {frame.filename}, line {frame.lineno}"
)
loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc)
loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc)
return loc
def compile_and_jit(self, module, pipeline, shared_libs, function_name=""):
@@ -992,6 +1042,8 @@ class BaseDSL:
for attr, value in self.envar.__dict__.items():
if value is not None:
s.write(str(value).encode())
# Add compile options to the hash
s.write(self.compile_options.to_str().encode())
module_hash = self.get_version().copy()
module_hash.update(s.getvalue())
module_hash = module_hash.hexdigest()
@@ -1145,6 +1197,8 @@ class BaseDSL:
self.launch_inner_count = 0
# reset num_kernels to 0 for next compilation.
self.num_kernels = 0
# reset the compile options after the compilation is done.
self.compile_options = CompileOptions()
def generate_mlir(
self,
@@ -1226,9 +1280,11 @@ class BaseDSL:
return transformed_ast
return None
def get_function_ptr(self, original_function, transformed_ast):
def get_function_ptr(self, original_function):
file_name = inspect.getsourcefile(original_function)
code_object = compile(transformed_ast, filename=file_name, mode="exec")
code_object = compile(
original_function._transformed_ast, filename=file_name, mode="exec"
)
return self.preprocessor.exec(
original_function.__name__,
original_function,
@@ -1236,10 +1292,6 @@ class BaseDSL:
self._get_globals(),
)
@lru_cache(maxsize=None)
def _get_function_signature(self, func):
return inspect.signature(func)
def _get_function_bound_args(self, sig, func_name, *args, **kwargs):
"""
Binds provided arguments to a function's signature and applies default values.
@@ -1260,12 +1312,11 @@ class BaseDSL:
)
return bound_args
def _canonicalize_args(self, *args, **kwargs):
def _canonicalize_args(self, sig, *args, **kwargs):
"""
Canonicalize the input arguments so that returned args only contain
positional arguments and kwargs only contain keyword arguments.
"""
sig = self._get_function_signature(self.funcBody)
function_name = self.funcBody.__name__
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
canonicalized_args = bound_args.args
@@ -1276,8 +1327,11 @@ class BaseDSL:
if not self.funcBody:
raise DSLRuntimeError("Function body is not set.")
# Pass the actual function object to _get_function_signature.
sig = self._get_function_signature(self.funcBody)
# Pass the actual function object to inspect.signature to get the signature.
if isinstance(self.funcBody, DSLCallable):
sig = self.funcBody.get_signature()
else:
sig = inspect.signature(self.funcBody)
function_name = self.funcBody.__name__
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
@@ -1292,6 +1346,8 @@ class BaseDSL:
f"Missing required argument in `{function_name}`: '{param.name}'"
)
return sig
def _func(self, funcBody, *args, **kwargs):
"""Decorator for MLIR functions.
It cuts the boilerplate code, does the following:
@@ -1324,13 +1380,16 @@ class BaseDSL:
self.print_warning("Cache is disabled as user wants to compile only.")
# Check the number of arguments
self._check_arg_count(*args, **kwargs)
sig = self._check_arg_count(*args, **kwargs)
args_spec = inspect.getfullargspec(funcBody)
if isinstance(funcBody, DSLCallable):
args_spec = funcBody.get_arg_spec()
else:
args_spec = inspect.getfullargspec(funcBody)
# Canonicalize the input arguments
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
*args, **kwargs
sig, *args, **kwargs
)
# Simple name mangling
@@ -1528,7 +1587,10 @@ class BaseDSL:
kernelGenHelper = dkwargs.get("kernelGenHelper", None)
kernel_name = funcBody.__name__
args_spec = inspect.getfullargspec(funcBody)
if isinstance(funcBody, DSLCallable):
args_spec = funcBody.get_arg_spec()
else:
args_spec = inspect.getfullargspec(funcBody)
self.funcBody = funcBody
# Give each kernel a unique name. (The same kernel may be
@@ -1568,11 +1630,11 @@ class BaseDSL:
), "kernelGenHelper should be explicitly specified!"
# check arguments
self._check_arg_count(*args, **kwargs)
sig = self._check_arg_count(*args, **kwargs)
# Canonicalize the input arguments
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
*args, **kwargs
sig, *args, **kwargs
)
kernel_operands, kernel_types, kernel_arg_attrs = (

View File

@@ -527,7 +527,16 @@ class IntegerMeta(NumericMeta):
return 2**cls.width - 1
def recast_width(cls, width):
return eval(f"Int{width}")
type_map = {
8: Int8,
16: Int16,
32: Int32,
64: Int64,
128: Int128,
}
if width not in type_map:
raise TypeError(f"Unsupported width: {width}")
return type_map[width]
class FloatMeta(NumericMeta):
@@ -603,7 +612,14 @@ class FloatMeta(NumericMeta):
return cls._mantissa_width
def recast_width(cls, width):
return eval(f"Float{width}")
type_map = {
16: Float16,
32: Float32,
64: Float64,
}
if width not in type_map:
raise TypeError(f"Unsupported width: {width}")
return type_map[width]
def _arith_signless_to_int(a, target_type):