mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-12 09:15:56 +00:00
v4.1 release update v2. (#2481)
This commit is contained in:
@@ -300,6 +300,21 @@ def if_executor(
|
||||
|
||||
|
||||
class range:
|
||||
"""
|
||||
A range-like object for dynamic loop iteration in the DSL.
|
||||
|
||||
This class provides a range interface similar to Python's built-in range,
|
||||
but is designed to be preprocessed into constructs for dynamic
|
||||
loop execution.
|
||||
|
||||
The class supports both single-argument (stop) and three-argument
|
||||
(start, stop, step) constructors with additional parameters for loop
|
||||
optimization:
|
||||
|
||||
- unroll: Number of iterations to unroll (0 or 1 = no unrolling)
|
||||
- unroll_full: Whether to fully unroll the loop
|
||||
- pipelining: Compiler generated pipeline configuration
|
||||
"""
|
||||
@overload
|
||||
def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None):
|
||||
pass
|
||||
@@ -460,7 +475,31 @@ def range_value_check(*args):
|
||||
Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
|
||||
"""
|
||||
try:
|
||||
return tuple(arg.__index__() for arg in args)
|
||||
args = tuple(arg.__index__() for arg in args)
|
||||
|
||||
# Compute range size and warn if it's too large
|
||||
start = 0
|
||||
end = 0
|
||||
step = 1
|
||||
if len(args) == 1:
|
||||
end = args[0]
|
||||
elif len(args) == 2:
|
||||
start = args[0]
|
||||
end = args[1]
|
||||
elif len(args) == 3:
|
||||
start = args[0]
|
||||
end = args[1]
|
||||
step = args[2]
|
||||
|
||||
range_length = (abs(end - start) - 1) // abs(step) + 1
|
||||
if range_length >= 64:
|
||||
warnings.warn(
|
||||
f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return (start, end, step)
|
||||
except:
|
||||
raise DSLRuntimeError(
|
||||
"`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
|
||||
@@ -477,8 +516,8 @@ def range_perf_warning(filename, lineno, *args):
|
||||
if not has_dynamic_expr:
|
||||
warnings.warn_explicit(
|
||||
(
|
||||
"The loop was previously unrolled in Python, but now it may not unroll in IR. This may cause performance regression."
|
||||
"If you want to unroll the loop in Python, please use `range_constexpr` instead of `range`."
|
||||
"This loop is no longer unrolled and may cause performance regression. "
|
||||
"Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
|
||||
),
|
||||
category=UserWarning,
|
||||
filename=filename,
|
||||
|
||||
@@ -102,6 +102,8 @@ class ScopeManager:
|
||||
return cls([])
|
||||
|
||||
def add_to_scope(self, name: str) -> None:
|
||||
if name == "_":
|
||||
return
|
||||
self.scopes[-1].add(name)
|
||||
|
||||
def get_active_symbols(self) -> List[Set[str]]:
|
||||
@@ -361,13 +363,13 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
isinstance(func, ast.Name)
|
||||
and func.id in self.SUPPORTED_FOR_RANGE_STATEMENTS
|
||||
):
|
||||
return func.id, True
|
||||
return func.id, True, len(iter_node.keywords) != 0
|
||||
if (
|
||||
isinstance(func, ast.Attribute)
|
||||
and func.attr in self.SUPPORTED_FOR_RANGE_STATEMENTS
|
||||
):
|
||||
return func.attr, False
|
||||
return None, None
|
||||
return func.attr, False, len(iter_node.keywords) != 0
|
||||
return None, None, None
|
||||
|
||||
def transform(self, original_function, exec_globals):
|
||||
"""
|
||||
@@ -378,6 +380,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
transformed_tree = self.transform_function(
|
||||
original_function.__name__, original_function
|
||||
)
|
||||
self.function_globals = None
|
||||
unified_tree = ast.Module(body=transformed_tree, type_ignores=[])
|
||||
unified_tree = ast.fix_missing_locations(unified_tree)
|
||||
|
||||
@@ -731,7 +734,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
self.scope_manager.add_to_scope(node.target.id)
|
||||
|
||||
# For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop.
|
||||
range_kind, is_builtin_range = self._get_range_kind(node.iter)
|
||||
range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter)
|
||||
if range_kind == "range_constexpr" or range_kind == None:
|
||||
self.generic_visit(node)
|
||||
if range_kind == "range_constexpr":
|
||||
@@ -752,7 +755,7 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
warnings.simplefilter("default", DeprecationWarning) # reset filter
|
||||
|
||||
warning_call = None
|
||||
if range_kind == "range" and is_builtin_range:
|
||||
if range_kind == "range" and is_builtin_range and not has_keyword:
|
||||
# Warn about possible performance regression due to behavior change
|
||||
warning_call = ast.Expr(
|
||||
ast.Call(
|
||||
@@ -1109,6 +1112,12 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
self.generic_visit(node)
|
||||
return node
|
||||
|
||||
def visit_Name(self, node):
|
||||
self.generic_visit(node)
|
||||
if node.id == "_" and isinstance(node.ctx, ast.Load):
|
||||
raise DSLAstPreprocessorError("Read '_' is not allowed")
|
||||
return node
|
||||
|
||||
def check_decorator(self, node: ast.AST) -> bool:
|
||||
"""
|
||||
Check if the function has the correct decorator for preprocessing.
|
||||
|
||||
@@ -19,7 +19,9 @@ from typing import Sequence, Optional, Tuple
|
||||
import os
|
||||
import sys
|
||||
import inspect
|
||||
import argparse
|
||||
from .common import DSLRuntimeError
|
||||
from .utils.logger import log
|
||||
|
||||
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(_SCRIPT_PATH)
|
||||
@@ -182,7 +184,67 @@ class Compiler:
|
||||
return self.jit(module, opt_level, shared_libs)
|
||||
|
||||
|
||||
class CompileOptions:
|
||||
def __init__(self, options: str = ""):
|
||||
"""
|
||||
This class encapsulates all compilation options relevant to function compilation.
|
||||
It provides a convenient way to manage and pass compilation options,
|
||||
particularly for controlling compilation settings.
|
||||
By centralizing these options, it ensures consistent and flexible configuration of
|
||||
compilation parameters such as optimization level, debugging control, etc.
|
||||
|
||||
:param options: The options for the function. Will be parsed by argparse.
|
||||
:type options: str
|
||||
"""
|
||||
if not isinstance(options, str):
|
||||
raise DSLRuntimeError(
|
||||
f"Invalid compilation `options`: {options}, it should be a string"
|
||||
)
|
||||
self._parser = argparse.ArgumentParser()
|
||||
self._parser.add_argument("--opt-level", nargs="?", type=int, default=3)
|
||||
self._parser.add_argument(
|
||||
"--enable-device-assertions", action="store_true", default=False
|
||||
)
|
||||
try:
|
||||
self._options = self._parser.parse_args(options.split())
|
||||
except SystemExit as e:
|
||||
# catch argparse error and raise as DSLRuntimeError
|
||||
raise DSLRuntimeError(
|
||||
f"Invalid compile options: '{options}'. Please check the option values and format."
|
||||
)
|
||||
log().info("`cute.compile` CompileOptions: options=" + options)
|
||||
|
||||
def to_str(self):
|
||||
"""
|
||||
Generate a string representation of all compilation options
|
||||
which will be used in pipeline options.
|
||||
"""
|
||||
option_strings = []
|
||||
for key, value in vars(self._options).items():
|
||||
hyphen_key = key.replace("_", "-")
|
||||
if isinstance(value, bool):
|
||||
formatted_value = "true" if value else "false"
|
||||
else:
|
||||
formatted_value = str(value)
|
||||
option_strings.append(f"{hyphen_key}={formatted_value}")
|
||||
|
||||
return " ".join(option_strings)
|
||||
|
||||
|
||||
def compile(func, *args, **kwargs):
|
||||
"""
|
||||
This function is used to compile a `cute.jit` decorated function.
|
||||
It will process the compile options and input parameters, do explicit compilation and return the jit executor.
|
||||
|
||||
:param func: The function to compile. It can be a regular function, a method or a class instance.
|
||||
:param args: The arguments to pass to the function.
|
||||
:param kwargs: The keyword arguments to pass to the function. It can contain `options` like
|
||||
`opt_level` to control the compilation flags.
|
||||
|
||||
:return: The jit executor.
|
||||
|
||||
:raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable.
|
||||
"""
|
||||
if func is None:
|
||||
raise DSLRuntimeError("Function is not set or invalid.")
|
||||
|
||||
@@ -217,5 +279,8 @@ def compile(func, *args, **kwargs):
|
||||
if not hasattr(func, "_dsl_object"):
|
||||
raise DSLRuntimeError("Function is not decorated with jit decorator.")
|
||||
|
||||
# process compile options, extract the options and remove them from the kwargs
|
||||
options = kwargs.pop("options", "")
|
||||
func._dsl_object.compile_options = CompileOptions(options)
|
||||
fcn_ptr = func._dsl_object._preprocess_and_execute(func)
|
||||
return func._dsl_object._func(fcn_ptr, *args, **kwargs)
|
||||
|
||||
@@ -38,6 +38,7 @@ import warnings
|
||||
|
||||
from . import typing as t
|
||||
from .env_manager import EnvironmentVarManager
|
||||
from .compiler import CompileOptions
|
||||
|
||||
# =============================================================================
|
||||
# CUDA Python
|
||||
@@ -232,6 +233,50 @@ def new_from_mlir_values(obj, values):
|
||||
return obj
|
||||
|
||||
|
||||
class DSLCallable:
|
||||
"""
|
||||
Wrapper class for a callable object used within the DSL.
|
||||
|
||||
DSLCallable is designed to wrap a function and provide additional
|
||||
introspection utilities such as retrieving the argument specification
|
||||
and signature. It ensures that the wrapped function can only be called
|
||||
once, after which the reference to the function is cleared to prevent
|
||||
further invocations. This is useful in scenarios where a function should
|
||||
only be executed a single time within the DSL's execution model.
|
||||
|
||||
Attributes:
|
||||
func (callable): The function to be wrapped and managed.
|
||||
|
||||
Methods:
|
||||
__call__(*args, **kwargs): Calls the wrapped function and clears it.
|
||||
get_arg_spec(): Returns the argument specification of the function.
|
||||
get_signature(): Returns the signature of the function.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
ret = self.__func__(*args, **kwargs)
|
||||
self.func = None
|
||||
return ret
|
||||
|
||||
@property
|
||||
def __func__(self):
|
||||
assert self.func is not None, "DSLCallable is already called"
|
||||
return self.func
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
return self.__func__.__name__
|
||||
|
||||
def get_arg_spec(self):
|
||||
return inspect.getfullargspec(self.__func__)
|
||||
|
||||
def get_signature(self):
|
||||
return inspect.signature(self.__func__)
|
||||
|
||||
|
||||
class BaseDSL:
|
||||
gpu_module = None
|
||||
|
||||
@@ -306,6 +351,8 @@ class BaseDSL:
|
||||
self.kernel_symbols = []
|
||||
# used to generate unique name for gpu.launch
|
||||
self.launch_inner_count = 0
|
||||
# initialize default compile options
|
||||
self.compile_options = CompileOptions()
|
||||
|
||||
if preprocess:
|
||||
self.preprocessor = DSLPreprocessor()
|
||||
@@ -392,26 +439,24 @@ class BaseDSL:
|
||||
if hasattr(func, "_transformed_ast"):
|
||||
# If the function ptr is already materialized, use the existing one
|
||||
func._dsl_object.frame = func._decorator_frame
|
||||
|
||||
if func._transformed_ast is None:
|
||||
func._transformed_ast = func._dsl_object.run_preprocessor(func)
|
||||
if func._transformed_ast is None:
|
||||
del func._decorator_frame
|
||||
del func._transformed_ast
|
||||
func._dsl_object.frame = None
|
||||
return func
|
||||
|
||||
fcn_ptr = func._dsl_object.get_function_ptr(func, func._transformed_ast)
|
||||
fcn_ptr = func._dsl_object.get_function_ptr(func)
|
||||
# If the function is decorated, de-decorate it
|
||||
fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
|
||||
return fcn_ptr
|
||||
func._dsl_object.frame = None
|
||||
return DSLCallable(fcn_ptr)
|
||||
return func
|
||||
|
||||
def jit_runner(self, frame, executor, *dargs, **dkwargs):
|
||||
def jit_runner(self, executor, frame, *dargs, **dkwargs):
|
||||
"""
|
||||
Decorator to mark a function for JIT compilation.
|
||||
"""
|
||||
# Set the frame, that can be used AST preprocessor
|
||||
self.frame = frame
|
||||
log().info("jit_runner")
|
||||
|
||||
def jit_runner_decorator(func):
|
||||
@@ -444,7 +489,7 @@ class BaseDSL:
|
||||
frame = inspect.currentframe().f_back
|
||||
# Instantiate the DSL Class
|
||||
main_dsl = cls._get_dsl()
|
||||
return main_dsl.jit_runner(frame, main_dsl._func, *dargs, **dkwargs)
|
||||
return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs)
|
||||
|
||||
@classmethod
|
||||
def kernel(cls, *dargs, **dkwargs):
|
||||
@@ -454,7 +499,7 @@ class BaseDSL:
|
||||
frame = inspect.currentframe().f_back
|
||||
# Instantiate the DSL Class
|
||||
main_dsl = cls._get_dsl()
|
||||
return main_dsl.jit_runner(frame, main_dsl._kernel_helper, *dargs, **dkwargs)
|
||||
return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _kernel_helper(self, func, *args, **kwargs):
|
||||
@@ -627,6 +672,12 @@ class BaseDSL:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_module_globals(self):
|
||||
"""
|
||||
Get the module's globals.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _get_globals(self):
|
||||
"""
|
||||
Combines global and local variables from the current context and the
|
||||
@@ -639,7 +690,11 @@ class BaseDSL:
|
||||
AST preprocessor generates a new python code, so the resulting globals
|
||||
dictionary is used to execute the python code.
|
||||
"""
|
||||
pass
|
||||
all_globals = self._get_module_globals().copy()
|
||||
if self.frame:
|
||||
all_globals.update(self.frame.f_globals)
|
||||
all_globals.update(self.frame.f_locals)
|
||||
return all_globals
|
||||
|
||||
def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool:
|
||||
return isinstance(
|
||||
@@ -881,20 +936,15 @@ class BaseDSL:
|
||||
Get python location information and generate MLIR location
|
||||
"""
|
||||
|
||||
frame = self.frame
|
||||
if frame is None:
|
||||
print("Frame is None")
|
||||
if self.frame is None:
|
||||
log().debug("Frame is None")
|
||||
return None
|
||||
|
||||
file_loc = ir.Location.file(frame.f_code.co_filename, frame.f_lineno, 0)
|
||||
file_loc = ir.Location.file(
|
||||
self.frame.f_code.co_filename, self.frame.f_lineno, 0
|
||||
)
|
||||
|
||||
def print_all_frames():
|
||||
for i, frame in enumerate(inspect.stack()):
|
||||
print(
|
||||
f"Frame {i}: {frame.function} in {frame.filename}, line {frame.lineno}"
|
||||
)
|
||||
|
||||
loc = ir.Location.name(frame.f_code.co_name, childLoc=file_loc)
|
||||
loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc)
|
||||
return loc
|
||||
|
||||
def compile_and_jit(self, module, pipeline, shared_libs, function_name=""):
|
||||
@@ -992,6 +1042,8 @@ class BaseDSL:
|
||||
for attr, value in self.envar.__dict__.items():
|
||||
if value is not None:
|
||||
s.write(str(value).encode())
|
||||
# Add compile options to the hash
|
||||
s.write(self.compile_options.to_str().encode())
|
||||
module_hash = self.get_version().copy()
|
||||
module_hash.update(s.getvalue())
|
||||
module_hash = module_hash.hexdigest()
|
||||
@@ -1145,6 +1197,8 @@ class BaseDSL:
|
||||
self.launch_inner_count = 0
|
||||
# reset num_kernels to 0 for next compilation.
|
||||
self.num_kernels = 0
|
||||
# reset the compile options after the compilation is done.
|
||||
self.compile_options = CompileOptions()
|
||||
|
||||
def generate_mlir(
|
||||
self,
|
||||
@@ -1226,9 +1280,11 @@ class BaseDSL:
|
||||
return transformed_ast
|
||||
return None
|
||||
|
||||
def get_function_ptr(self, original_function, transformed_ast):
|
||||
def get_function_ptr(self, original_function):
|
||||
file_name = inspect.getsourcefile(original_function)
|
||||
code_object = compile(transformed_ast, filename=file_name, mode="exec")
|
||||
code_object = compile(
|
||||
original_function._transformed_ast, filename=file_name, mode="exec"
|
||||
)
|
||||
return self.preprocessor.exec(
|
||||
original_function.__name__,
|
||||
original_function,
|
||||
@@ -1236,10 +1292,6 @@ class BaseDSL:
|
||||
self._get_globals(),
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _get_function_signature(self, func):
|
||||
return inspect.signature(func)
|
||||
|
||||
def _get_function_bound_args(self, sig, func_name, *args, **kwargs):
|
||||
"""
|
||||
Binds provided arguments to a function's signature and applies default values.
|
||||
@@ -1260,12 +1312,11 @@ class BaseDSL:
|
||||
)
|
||||
return bound_args
|
||||
|
||||
def _canonicalize_args(self, *args, **kwargs):
|
||||
def _canonicalize_args(self, sig, *args, **kwargs):
|
||||
"""
|
||||
Canonicalize the input arguments so that returned args only contain
|
||||
positional arguments and kwargs only contain keyword arguments.
|
||||
"""
|
||||
sig = self._get_function_signature(self.funcBody)
|
||||
function_name = self.funcBody.__name__
|
||||
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
|
||||
canonicalized_args = bound_args.args
|
||||
@@ -1276,8 +1327,11 @@ class BaseDSL:
|
||||
if not self.funcBody:
|
||||
raise DSLRuntimeError("Function body is not set.")
|
||||
|
||||
# Pass the actual function object to _get_function_signature.
|
||||
sig = self._get_function_signature(self.funcBody)
|
||||
# Pass the actual function object to inspect.signature to get the signature.
|
||||
if isinstance(self.funcBody, DSLCallable):
|
||||
sig = self.funcBody.get_signature()
|
||||
else:
|
||||
sig = inspect.signature(self.funcBody)
|
||||
function_name = self.funcBody.__name__
|
||||
|
||||
bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs)
|
||||
@@ -1292,6 +1346,8 @@ class BaseDSL:
|
||||
f"Missing required argument in `{function_name}`: '{param.name}'"
|
||||
)
|
||||
|
||||
return sig
|
||||
|
||||
def _func(self, funcBody, *args, **kwargs):
|
||||
"""Decorator for MLIR functions.
|
||||
It cuts the boilerplate code, does the following:
|
||||
@@ -1324,13 +1380,16 @@ class BaseDSL:
|
||||
self.print_warning("Cache is disabled as user wants to compile only.")
|
||||
|
||||
# Check the number of arguments
|
||||
self._check_arg_count(*args, **kwargs)
|
||||
sig = self._check_arg_count(*args, **kwargs)
|
||||
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
if isinstance(funcBody, DSLCallable):
|
||||
args_spec = funcBody.get_arg_spec()
|
||||
else:
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
|
||||
# Canonicalize the input arguments
|
||||
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
|
||||
*args, **kwargs
|
||||
sig, *args, **kwargs
|
||||
)
|
||||
|
||||
# Simple name mangling
|
||||
@@ -1528,7 +1587,10 @@ class BaseDSL:
|
||||
kernelGenHelper = dkwargs.get("kernelGenHelper", None)
|
||||
|
||||
kernel_name = funcBody.__name__
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
if isinstance(funcBody, DSLCallable):
|
||||
args_spec = funcBody.get_arg_spec()
|
||||
else:
|
||||
args_spec = inspect.getfullargspec(funcBody)
|
||||
self.funcBody = funcBody
|
||||
|
||||
# Give each kernel a unique name. (The same kernel may be
|
||||
@@ -1568,11 +1630,11 @@ class BaseDSL:
|
||||
), "kernelGenHelper should be explicitly specified!"
|
||||
|
||||
# check arguments
|
||||
self._check_arg_count(*args, **kwargs)
|
||||
sig = self._check_arg_count(*args, **kwargs)
|
||||
|
||||
# Canonicalize the input arguments
|
||||
canonicalized_args, canonicalized_kwargs = self._canonicalize_args(
|
||||
*args, **kwargs
|
||||
sig, *args, **kwargs
|
||||
)
|
||||
|
||||
kernel_operands, kernel_types, kernel_arg_attrs = (
|
||||
|
||||
@@ -527,7 +527,16 @@ class IntegerMeta(NumericMeta):
|
||||
return 2**cls.width - 1
|
||||
|
||||
def recast_width(cls, width):
|
||||
return eval(f"Int{width}")
|
||||
type_map = {
|
||||
8: Int8,
|
||||
16: Int16,
|
||||
32: Int32,
|
||||
64: Int64,
|
||||
128: Int128,
|
||||
}
|
||||
if width not in type_map:
|
||||
raise TypeError(f"Unsupported width: {width}")
|
||||
return type_map[width]
|
||||
|
||||
|
||||
class FloatMeta(NumericMeta):
|
||||
@@ -603,7 +612,14 @@ class FloatMeta(NumericMeta):
|
||||
return cls._mantissa_width
|
||||
|
||||
def recast_width(cls, width):
|
||||
return eval(f"Float{width}")
|
||||
type_map = {
|
||||
16: Float16,
|
||||
32: Float32,
|
||||
64: Float64,
|
||||
}
|
||||
if width not in type_map:
|
||||
raise TypeError(f"Unsupported width: {width}")
|
||||
return type_map[width]
|
||||
|
||||
|
||||
def _arith_signless_to_int(a, target_type):
|
||||
|
||||
Reference in New Issue
Block a user