v4.3.4 update. (#2892)

This commit is contained in:
Junkai-Wu
2025-12-22 00:49:12 +08:00
committed by GitHub
parent 331e2f451c
commit 7f5fe3edf1
31 changed files with 839 additions and 240 deletions

View File

@@ -470,7 +470,56 @@ class DSLPreprocessor(ast.NodeTransformer):
names=[ast.alias(name=f"{top_module_name}.base_dsl", asname="_dsl_")]
)
)
transformed_tree.body = import_stmts + transformed_tree.body
assert len(transformed_tree.body) == 1
assert isinstance(transformed_tree.body[0], ast.FunctionDef)
transformed_tree.body[0].body = import_stmts + transformed_tree.body[0].body
# Remove all decorators from top level function
transformed_tree.body[0].decorator_list = []
# Step 4. Wrap the function with nonlocal captures, if has any
# if the function has a nonlocal variable, wrap it in a function and return the function
# pseudo code:
# def foo():
# nonlocal_var_0 = None
# nonlocal_var_1 = None
# def foo(args):
# ...
# return foo
# foo = foo()
nonlocals = {v: None for v in function_pointer.__code__.co_freevars}
if len(nonlocals) > 0:
assignments = []
for n, _ in nonlocals.items():
assignments.append(
ast.Assign(
targets=[ast.Name(id=n, ctx=ast.Store())],
value=ast.Constant(value=None),
)
)
return_expr = [ast.Return(value=ast.Name(id=func_name, ctx=ast.Load()))]
wrapper_fcn = ast.FunctionDef(
name=func_name,
args=ast.arguments(
posonlyargs=[],
args=[],
kwonlyargs=[],
kw_defaults=[],
defaults=[],
),
body=assignments + transformed_tree.body + return_expr,
decorator_list=[],
)
invoke = ast.Call(
func=ast.Name(id=func_name, ctx=ast.Load()), args=[], keywords=[]
)
assign = ast.Assign(
targets=[ast.Name(id=func_name, ctx=ast.Store())], value=invoke
)
transformed_tree.body = [wrapper_fcn, assign]
# Step 4. Import cutlass and base_dsl
ast.fix_missing_locations(transformed_tree)
@@ -1521,6 +1570,15 @@ class DSLPreprocessor(ast.NodeTransformer):
self.scope_manager.add_to_scope(node.name)
for arg in node.args.args:
self.scope_manager.add_to_scope(arg.arg)
arg.annotation = None
for arg in node.args.kwonlyargs:
self.scope_manager.add_to_scope(arg.arg)
arg.annotation = None
for arg in node.args.posonlyargs:
self.scope_manager.add_to_scope(arg.arg)
arg.annotation = None
self.generic_visit(node)

View File

@@ -622,18 +622,14 @@ class CompileCallable:
func,
)
# If it's a wrapped function created by jit decorator, get the original function
if hasattr(func, "__wrapped__"):
# If it's a wrapped function created by decorators, get the original function
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
# Lazy initialization of DSL object if has not been initialized
# Use local import to avoid circular import
from .dsl import BaseDSL
BaseDSL._lazy_initialize_dsl(func)
if not hasattr(func, "_dsl_object"):
raise DSLRuntimeError("Function is not decorated with jit decorator.")
raise DSLRuntimeError(
f"Function {func} is not decorated with jit decorator."
)
# process compile options, extract the options and remove them from the kwargs
options = kwargs.pop("options", None)
@@ -645,8 +641,4 @@ class CompileCallable:
else:
compile_options = self._compile_options
func._dsl_object.compile_options = compile_options
fcn_ptr = func._dsl_object._preprocess_and_execute(func)
if hasattr(func, "_decorator_frame"):
kwargs["_decorator_frame"] = func._decorator_frame
return func._dsl_object._func(fcn_ptr, *args, **kwargs)
return func._dsl_object._func(func, *args, **kwargs)

View File

@@ -31,9 +31,10 @@ import weakref
from functools import lru_cache, wraps
from collections import namedtuple, OrderedDict
from abc import ABC, abstractmethod
from typing import Any, Callable, List
from typing import Any, Callable, List, ClassVar
from types import SimpleNamespace
import warnings
import threading
from . import typing as t
from .env_manager import EnvironmentVarManager
@@ -228,49 +229,92 @@ def new_from_mlir_values(obj, values):
assert len(values) == 0, f"{obj} expects 0 values, but got {values}"
return obj
class DSLCallable:
@dataclass(frozen=True)
class DSLLocation:
"""
Wrapper class for a callable object used within the DSL.
DSLCallable is designed to wrap a function and provide additional
introspection utilities such as retrieving the argument specification
and signature. It ensures that the wrapped function can only be called
once, after which the reference to the function is cleared to prevent
further invocations. This is useful in scenarios where a function should
only be executed a single time within the DSL's execution model.
Represents Python source location information for MLIR DSL code.
Attributes:
func (callable): The function to be wrapped and managed.
filename (str): Name of the Python source file.
lineno (int): Line number in the source file.
col_offset (int): Column offset in the source line.
function_name (str): Name of the function in which the location occurs.
Methods:
__call__(*args, **kwargs): Calls the wrapped function and clears it.
This is used primarily to annotate or trace locations in generated MLIR IR
back to the original Python code for better diagnostic and debugging.
"""
def __init__(self, func):
self.func = func
self.name = func.__name__
def __call__(self, *args, **kwargs):
ret = self.__func__(*args, **kwargs)
self.func = None
return ret
@property
def __func__(self):
assert self.func is not None, "DSLCallable is already called"
return self.func
@property
def __signature__(self):
return inspect.signature(self.__func__)
@property
def __name__(self):
return self.name
filename: str
lineno: int
col_offset: int
function_name: str
class BaseDSL:
@dataclass
class PreprocessSessionData:
"""
Holds metadata and transformed AST related to a DSL preprocessing session.
Attributes:
decorator_globals (dict): The global variables from the decorator's environment,
captured for possible AST or code evaluation during preprocessing.
"""
decorator_globals: dict
class DSLSingletonMeta(type):
"""
Metaclass implementing the Singleton pattern for DSL classes.
The DSLSingletonMeta ensures that only one instance of a derived DSL class exists at any time.
When a class is called, it checks if an instance already exists in the `_instances` dictionary.
- If requesting `BaseDSL` itself, it asserts that a concrete subclass has been initialized,
and returns the first available singleton instance among subclasses.
- If requesting a concrete subclass, it creates a new instance if none exists, or returns
the already created instance.
This metaclass is useful for maintaining global state and configuration across the DSL system,
ensuring that all parts of the application operate on the same DSL instance.
Attributes:
_instances (dict): Maps DSL classes to their singleton instances.
Example:
class MyDSL(BaseDSL): ...
dsl1 = MyDSL()
dsl2 = MyDSL()
assert dsl1 is dsl2 # Singleton property
"""
_instances: ClassVar[dict] = {}
_lock: ClassVar[threading.Lock] = threading.Lock()
def __call__(cls, *args, **kwargs):
with cls._lock:
log().info(f"DSLSingletonMeta __call__ for {cls}")
if cls is BaseDSL:
# If one is querying BaseDSL, which is abstract, returning an arbitrary instance of a concrete subclass should be fine.
# Here we just return the first instance of a concrete subclass.
assert cls._instances, (
"Need to initialize a concrete subclass of BaseDSL first"
)
return next(iter(cls._instances.values()))
elif cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
log().info(f"Active DSL singleton instances: {cls._instances}")
return cls._instances[cls]
def clear_instances(cls):
log().info(
f"Clearing DSL singleton instances for {cls}, current instances: {cls._instances}"
)
if cls in cls._instances:
del cls._instances[cls]
log().info(f"DSL singleton instances after clearing: {cls._instances}")
class BaseDSL(metaclass=DSLSingletonMeta):
gpu_module = None
_env_class = EnvironmentVarManager
@@ -310,7 +354,8 @@ class BaseDSL:
self.name = name
self.compiler_provider = compiler_provider
self.pass_sm_arch_name = pass_sm_arch_name
self.frame = None
self.preprocess_session_data = None
self.decorator_location = None
self.no_cache = False
self.device_compilation_only = device_compilation_only
self.num_kernels = 0
@@ -379,7 +424,6 @@ class BaseDSL:
warnings.warn(message, UserWarning)
@classmethod
@lru_cache(maxsize=1)
def _get_dsl(cls):
# Instantiate the DSL Class once
main_dsl = cls()
@@ -414,38 +458,22 @@ class BaseDSL:
return fcn_ptr
@staticmethod
def _preprocess_and_execute(func):
def _preprocess_and_replace_code(func):
"""
Run ast transformation and return the materialized function pointer
"""
# Lazy initialization of DSL object if has not been initialized
if not hasattr(func, "_dsl_object"):
func._dsl_object = func._dsl_cls._get_dsl()
delattr(func, "_dsl_cls")
if not func._dsl_object.enable_preprocessor:
if hasattr(func, "_decorator_frame"):
delattr(func, "_decorator_frame")
if hasattr(func, "_transformed_ast"):
delattr(func, "_transformed_ast")
return func
if hasattr(func, "_transformed_ast"):
if hasattr(func, "_preprocess_session_data"):
# If the function ptr is already materialized, use the existing one
func._dsl_object.frame = func._decorator_frame
if func._transformed_ast is None:
func._transformed_ast = func._dsl_object.run_preprocessor(func)
if func._transformed_ast is None:
del func._transformed_ast
func._dsl_object.frame = None
return func
fcn_ptr = func._dsl_object.get_function_ptr(func)
# If the function is decorated, de-decorate it
fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
func._dsl_object.frame = None
return DSLCallable(fcn_ptr)
func._dsl_object.preprocess_session_data = func._preprocess_session_data
func._dsl_object.decorator_location = func._decorator_location
transformed_ast = func._dsl_object.run_preprocessor(func)
fcn_ptr = func._dsl_object.get_function_ptr(func, transformed_ast)
func.__code__ = (
fcn_ptr.__code__
if not isinstance(fcn_ptr, staticmethod)
else fcn_ptr.__func__.__code__
)
return func
@staticmethod
@@ -457,20 +485,27 @@ class BaseDSL:
def jit_runner_decorator(func):
# Run preprocessor that alters AST
func._dsl_cls = cls
if BaseDSL._can_preprocess(**dkwargs):
func._dsl_object = cls._get_dsl()
func._decorator_location = BaseDSL.get_location_from_frame(frame)
if (
func._dsl_object.enable_preprocessor
and func._dsl_object._can_preprocess(**dkwargs)
):
# For an annotated function, add some DSL attributes
# When materializing the AST, we need decorator's frame
func._decorator_frame = frame
# No transformed ast at this point
func._transformed_ast = None
func._preprocess_session_data = PreprocessSessionData(
decorator_globals=frame.f_globals,
)
BaseDSL._preprocess_and_replace_code(func)
@wraps(func)
def jit_wrapper(*args, **kwargs):
func_ptr = BaseDSL._preprocess_and_execute(func)
return getattr(func._dsl_object, executor_name)(
func_ptr, *args, **kwargs
)
return getattr(func._dsl_object, executor_name)(func, *args, **kwargs)
def set_name_prefix(name: str):
jit_wrapper._name_prefix = name
jit_wrapper.set_name_prefix = set_name_prefix
return jit_wrapper
@@ -479,15 +514,6 @@ class BaseDSL:
else:
return jit_runner_decorator
@staticmethod
def _lazy_initialize_dsl(func):
"""
Lazy initialization of DSL object if has not been initialized
"""
if hasattr(func, "_dsl_cls"):
func._dsl_object = func._dsl_cls._get_dsl()
delattr(func, "_dsl_cls")
@classmethod
def jit(cls, *dargs, **dkwargs):
"""
@@ -516,6 +542,7 @@ class BaseDSL:
"""
Build the module op that contains the kernels.
"""
log().info(f"[abstract] Building GPU module for {self.name}")
pass
@abstractmethod
@@ -688,9 +715,11 @@ class BaseDSL:
dictionary is used to execute the python code.
"""
all_globals = {}
if self.frame:
all_globals.update(self.frame.f_globals)
all_globals.update(self.frame.f_locals)
if (
self.preprocess_session_data
and self.preprocess_session_data.decorator_globals
):
all_globals.update(self.preprocess_session_data.decorator_globals)
return all_globals
@abstractmethod
@@ -955,25 +984,40 @@ class BaseDSL:
else:
ir._GlobalDebug.set_types(f"diagnostic-{args.diagnostic}")
def get_location(self, frame=None):
"""
Get python location information and generate MLIR location
"""
frame = self.frame if frame is None else frame
frame = inspect.currentframe().f_back if frame is None else frame
@staticmethod
def get_location_from_frame(frame):
frameInfo = inspect.getframeinfo(frame)
file_loc = ir.Location.file(
frame.f_code.co_filename,
frame.f_lineno,
frameInfo.positions.col_offset if hasattr(frameInfo, "positions") else 0,
)
loc = ir.Location.name(
(
return DSLLocation(
filename=frameInfo.filename,
lineno=frameInfo.lineno,
col_offset=(
frameInfo.positions.col_offset if hasattr(frameInfo, "positions") else 0
),
function_name=(
"".join([c.strip() for c in frameInfo.code_context])
if frameInfo.code_context
else frameInfo.function
),
)
def get_ir_location(self, location: DSLLocation = None):
"""
Get python location information and generate MLIR location
"""
if location is None:
if self.decorator_location:
location = self.decorator_location
if location is None:
return ir.Location.unknown()
file_loc = ir.Location.file(
location.filename,
location.lineno,
location.col_offset,
)
loc = ir.Location.name(
(location.function_name),
childLoc=file_loc,
)
return loc
@@ -1140,10 +1184,10 @@ class BaseDSL:
gpu_module_attrs,
args,
args_spec,
frame=None,
location=None,
):
def build_ir_module():
loc = self.get_location(frame)
loc = self.get_ir_location(location)
module = ir.Module.create(loc=loc)
unit_attr = ir.UnitAttr.get()
module.operation.attributes["gpu.container_module"] = unit_attr
@@ -1308,6 +1352,10 @@ class BaseDSL:
self.num_kernels = 0
# reset the compile options after the compilation is done.
self.compile_options = CompileOptions()
# reset preprocess session data after the compilation is done.
self.preprocess_session_data = None
# reset decorator location after the compilation is done.
self.decorator_location = None
def extract_dynamic_args(self, funcBody, args, kwargs, args_spec):
"""This function is used to extract the original dynamic arguments for AOT C header generation.
@@ -1348,11 +1396,10 @@ class BaseDSL:
pipeline,
no_cache,
compile_only,
loc=None,
frame=None,
location=None,
):
"""Generate MLIR module and compile iself.T_provider."""
with ir.Context(), self.get_location(frame):
with ir.Context(), self.get_ir_location(location):
try:
# Convert input arguments to MLIR arguments
exe_args, func_types, adapted_args = self.generate_mlir_function_types(
@@ -1374,7 +1421,7 @@ class BaseDSL:
gpu_module_attrs,
args,
args_spec,
frame=frame,
location=location,
)
# dryrun is used to only generate IR
@@ -1437,11 +1484,14 @@ class BaseDSL:
return transformed_ast
return None
def get_function_ptr(self, original_function):
def get_function_ptr(self, original_function, transformed_ast):
file_name = inspect.getsourcefile(original_function)
code_object = compile(
original_function._transformed_ast, filename=file_name, mode="exec"
transformed_ast,
filename=file_name,
mode="exec",
)
return self.preprocessor.exec(
original_function.__name__,
original_function,
@@ -1523,7 +1573,7 @@ class BaseDSL:
pipeline = kwargs.pop("pipeline", None)
gpu_module_attrs = kwargs.pop("gpu_module_attrs", {})
decorator_frame = kwargs.pop("_decorator_frame", None)
self.decorator_location = getattr(funcBody, "_decorator_location", None)
# Disable cache
no_cache = kwargs.pop("no_cache", False)
@@ -1556,7 +1606,7 @@ class BaseDSL:
function_name = self.mangle_name(function_name, canonicalized_args, args_spec)
self.compile_options.apply_envar_settings(self.envar, function_name)
if not self.compile_options.generate_line_info:
decorator_frame = None
self.decorator_location = None
# Generate MLIR Context and start generating IR
log().debug(f"Generating MLIR for function '{function_name}'")
@@ -1570,7 +1620,7 @@ class BaseDSL:
pipeline,
no_cache,
compile_only,
frame=decorator_frame,
location=self.decorator_location,
)
return result
@@ -1679,8 +1729,7 @@ class BaseDSL:
"""
ret = None
with ir.Context(), self.get_location():
loc = self.get_location()
with ir.Context(), self.get_ir_location() as loc:
module = ir.Module.create(loc=loc)
unit_attr = ir.UnitAttr.get()
module.operation.attributes["gpu.container_module"] = unit_attr
@@ -1819,7 +1868,7 @@ class BaseDSL:
)
)
loc = self.get_location()
loc = self.get_ir_location()
with self._enter_gpu_module():
log().debug("Generating device kernel")
if self.device_compilation_only:

View File

@@ -138,6 +138,7 @@ class MLIRBuilder(MLIRTypeBuilder):
super().__init__()
self.module: Optional[ir.Module] = None
self.const_str_table: dict[str, ir.Value] = {}
self.const_func_ptr_table: dict[str, ir.Value] = {}
self.get_element_extra_kwargs: dict[str, Any] = {}
# create constants
@@ -368,6 +369,64 @@ class MLIRBuilder(MLIRTypeBuilder):
self.const_str_table[content] = symbol
return symbol
def get_or_load_global_func_ptr_from_text(
self,
current_block: ir.Block,
function_name: str,
) -> ir.Value:
"""Get or create a function pointer global in .text section and load it.
This creates a constant global function pointer in the .text section
(for AArch64 ADRP range compatibility) and performs a volatile load
to prevent optimization.
This forces the function pointer to be local to the code, bypassing GOT entry
ADRP lookup issues on AArch64 when GOT and .text section are more than 4GB
apart which can happen when ASLR is applied.
"""
# Check if we've already created this global
if function_name not in self.const_func_ptr_table:
symbol = f"__func_ptr_{function_name}"
module_body = self.module.body
with ir.InsertionPoint(module_body):
# 1. Create the global constant
# We use 'private' linkage so it doesn't conflict across modules
global_ptr = llvm.GlobalOp(
self.ptr_type,
symbol,
ir.Attribute.parse("#llvm.linkage<private>"),
# Initialization via block below
)
# 2. Set the necessary attributes for JIT safety and AArch64 range
# We use 'constant' to mark it as immutable
# We use 'section = ".text"' to force it into the code block
global_ptr.attributes["constant"] = ir.UnitAttr.get()
global_ptr.attributes["section"] = ir.StringAttr.get(".text")
# 3. Add a constructor block to the GlobalOp to initialize it
# with the address of the target function
initializer_block = global_ptr.initializer.blocks.append()
with ir.InsertionPoint(initializer_block):
# Get the address of the external function
func_addr = llvm.AddressOfOp(self.ptr_type, function_name).res
# Return the address as the initial value of the global
llvm.return_(arg=func_addr)
self.const_func_ptr_table[function_name] = symbol
else:
symbol = self.const_func_ptr_table[function_name]
# Load it with volatile semantics in the current block
with ir.InsertionPoint(current_block):
symbol_addr = self.address_of(symbol, self.ptr_type)
# Perform a volatile load to prevent optimization
load_op = llvm.load(self.ptr_type, symbol_addr)
# Set volatile attribute to prevent optimization
load_op.owner.attributes["volatile_"] = ir.UnitAttr.get()
return load_op
# function
def function(
self,

View File

@@ -210,7 +210,7 @@ EnableTVMFFI = _dsl.EnableTVMFFI
# attach the TVM FFI ABI interface postprocessor to the DSL
from . import _tvm_ffi_args_spec_converter
_tvm_ffi_args_spec_converter.attach_args_spec_converter()
_tvm_ffi_args_spec_converter.attach_args_spec_converter(_dsl.CuTeDSL._get_dsl())
# Explicitly export all symbols for documentation generation
__all__ = [

View File

@@ -395,8 +395,6 @@ def _tvm_ffi_args_spec_converter(
return params, kwargs_wrapper_spec
def attach_args_spec_converter():
"""Attach TVM FFI ABI interface postprocessor to the DSL."""
from .. import cutlass_dsl as _dsl
_dsl.CuTeDSL._get_dsl()._tvm_ffi_args_spec_converter = _tvm_ffi_args_spec_converter
def attach_args_spec_converter(dsl):
"""Attach TVM FFI ABI interface postprocessor to the DSL instance."""
dsl._tvm_ffi_args_spec_converter = _tvm_ffi_args_spec_converter

View File

@@ -10,7 +10,7 @@
# is strictly prohibited.
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op
from cutlass.cutlass_dsl import BaseDSL, T, dsl_user_op
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir.dialects import nvvm, scf
@@ -69,7 +69,7 @@ def elect_one(*, loc=None, ip=None) -> IfOpRegion:
# Only one thread in the warp executes the code in this context
pass
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
is_thread_leader = nvvm.elect_sync(T.bool())
if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip)
return IfOpRegion(if_op.then_block, loc=loc, ip=ip)

View File

@@ -11,7 +11,7 @@
from typing import Optional
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op
from cutlass.cutlass_dsl import BaseDSL, T, if_generate, dsl_user_op
from cutlass._mlir.dialects import nvvm
@@ -44,7 +44,7 @@ def mbarrier_init_fence(*, loc=None, ip=None) -> None:
"""
A fence operation that applies to the mbarrier initializations.
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
nvvm.fence_mbarrier_init(loc=loc, ip=ip)
@@ -63,7 +63,7 @@ def mbarrier_arrive_and_expect_tx(
the mbarrier is converted to a remote address in the peer CTA's
SMEM.
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
mbar_llvm_ptr = mbar_ptr.llvm_ptr
if peer_cta_rank_in_cluster is not None:
@@ -103,7 +103,7 @@ def mbarrier_expect_tx(
the mbarrier is converted to a remote address in the peer CTA's
SMEM.
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
mbar_llvm_ptr = mbar_ptr.llvm_ptr
if peer_cta_rank_in_cluster is not None:
@@ -138,7 +138,7 @@ def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
:param phase: The phase to wait for (either 0 or 1)
:type phase: Int
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
timeout_ns = 10000000
# This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX
@@ -164,7 +164,7 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo
:return: A boolean value indicating whether the wait operation was successful
:rtype: Boolean
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
return Boolean(
nvvm.mbarrier_wait_parity(
@@ -193,7 +193,7 @@ def mbarrier_conditional_try_wait(
:return: A boolean value indicating whether the wait operation was successful
:rtype: Boolean
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
return if_generate(
cond,
lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip),
@@ -225,7 +225,7 @@ def mbarrier_arrive(
"""
mbar_llvm_ptr = mbar_ptr.llvm_ptr
if peer_cta_rank_in_cluster is not None:
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
mbar_llvm_ptr = nvvm.mapa_shared_cluster(
mbar_llvm_ptr.type,
@@ -259,7 +259,7 @@ def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> N
:param mbar_ptr: A pointer to the mbarrier in SMEM
:type mbar_ptr: Pointer
"""
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
mbar_llvm_ptr = mbar_ptr.llvm_ptr
nvvm.cp_async_mbarrier_arrive_shared(

View File

@@ -12,7 +12,8 @@
from cutlass.base_dsl.arch import Arch
from cutlass.base_dsl.common import DSLRuntimeError
from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op
from cutlass.cutlass_dsl import BaseDSL, dsl_user_op
from cutlass._mlir import ir
from cutlass._mlir.dialects import builtin, arith, llvm, vector
@@ -53,7 +54,7 @@ def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None):
:return: The output 1D vector of bfloat16 with the same length as the input vector.
:rtype: 1D vector of bfloat16
"""
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch in cvt_i8_bf16_intrinsic.supported_archs:
raise DSLRuntimeError(f"cvt_i8_bf16_intrinsic is not supported on {arch}")
src_pos = 0
@@ -130,7 +131,7 @@ def cvt_i4_bf16_intrinsic(vec_i4, length, *, loc=None, ip=None):
:return: The output 1D vector of bfloat16 with the same length as the input vector.
:rtype: 1D vector of bfloat16
"""
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch in cvt_i4_bf16_intrinsic.supported_archs:
raise DSLRuntimeError(f"cvt_i4_bf16_intrinsic is not supported on {arch}")
src_pos = 0

View File

@@ -1305,6 +1305,46 @@ def exp_packed_f32x2(
return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip)
@dsl_user_op
def griddepcontrol_wait(*, loc=None, ip=None) -> None:
"""
This instruction is used to wait for the previous kernel's grid to end
(all blocks of the previous kernel have finished and flushed memory), i.e.,
the instruction after this one will not be issued until the previous
grid has finished.
"""
llvm.inline_asm(
res=None,
operands_=[],
asm_string="griddepcontrol.wait;",
constraints="",
has_side_effects=True,
asm_dialect=llvm.AsmDialect.AD_ATT,
loc=loc,
ip=ip,
)
@dsl_user_op
def griddepcontrol_launch_dependents(*, loc=None, ip=None) -> None:
"""
Issuing the launch_dependents instruction hints that a dependent kernel may be launched earlier.
launch_dependents does not impact functionality, only performance:
launching a dependent kernel too early can compete with the current kernel,
while launching it too late can lead to long latency.
"""
llvm.inline_asm(
res=None,
operands_=[],
asm_string="griddepcontrol.launch_dependents;",
constraints="",
has_side_effects=True,
asm_dialect=llvm.AsmDialect.AD_ATT,
loc=loc,
ip=ip,
)
@dsl_user_op
def cvt_f4e2m1_f16(src, *, loc=None, ip=None):

View File

@@ -15,7 +15,7 @@ from typing import Optional, Type
from cutlass import cute
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL
from cutlass.cutlass_dsl import BaseDSL
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
@@ -146,7 +146,7 @@ class CopyBulkTensorTileG2SOp(TmaCopyOp):
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
)
# Arch verification
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,
@@ -263,7 +263,7 @@ class CopyBulkTensorTileG2SMulticastOp(TmaCopyOp):
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
)
# Arch verification
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,
@@ -386,7 +386,7 @@ class CopyBulkTensorTileS2GOp(TmaCopyOp):
def __post_init__(self):
# Arch verification
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,
@@ -561,7 +561,7 @@ class CopyBulkG2SOp(CopyOp):
def __post_init__(self) -> None:
# Arch verification
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,
@@ -646,7 +646,7 @@ class CopyBulkG2SMulticastOp(CopyOp):
def __post_init__(self) -> None:
# Arch verification
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,
@@ -740,7 +740,7 @@ class CopyBulkS2GOp(CopyOp):
def __post_init__(self) -> None:
# Arch verification
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
if not arch >= Arch.sm_90:
raise OpError(
self,

View File

@@ -15,7 +15,7 @@ from typing import Type
from cutlass import cute
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL
from cutlass.cutlass_dsl import BaseDSL
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
from cutlass._mlir import ir
@@ -113,7 +113,7 @@ class _LdBase(CopyOp):
:raises OpError: If pack parameter is not a Pack instance
"""
# Arch verification
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,
@@ -416,7 +416,7 @@ class _StBase(CopyOp):
def __post_init__(self) -> None:
# Arch verification
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,
@@ -625,7 +625,7 @@ class _S2TCopyBase(CopyOp):
def __post_init__(self) -> None:
# Arch verification
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch.is_family_of(Arch.sm_100f):
raise OpError(
self,

View File

@@ -15,7 +15,7 @@ from typing import Type, Any
from cutlass import cute
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL, T
from cutlass.cutlass_dsl import BaseDSL, T
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
@@ -162,7 +162,7 @@ class MmaOp(Tcgen05MmaOp):
def __post_init__(self) -> None:
# Verify arch
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,
@@ -314,7 +314,7 @@ class BlockScaledMmaOp(Tcgen05MmaOp):
def __post_init__(self) -> None:
# Verify arch
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,
@@ -471,7 +471,7 @@ class SparseMmaOp(Tcgen05MmaOp):
def __post_init__(self) -> None:
# Verify arch
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,

View File

@@ -15,7 +15,7 @@ from typing import Type, Any
from cutlass import cute
from cutlass.base_dsl.arch import Arch
from cutlass.cutlass_dsl import CuTeDSL, T
from cutlass.cutlass_dsl import BaseDSL, T
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
@@ -130,7 +130,7 @@ class MmaOp(WarpGroupMmaOp):
def __post_init__(self) -> None:
# Verify arch
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if not arch == Arch.sm_90a:
raise OpError(
self,

View File

@@ -925,7 +925,17 @@ def load_module(file_path: str, *, enable_tvm_ffi: bool = True):
if enable_tvm_ffi:
import tvm_ffi
return tvm_ffi.load_module(file_path)
try:
# keep_module_alive=False means the module will be unloaded
# after the returned module goes out of scope; this is useful
# for frequent loading and unloading of modules. The only requirement
# is that the module does not return objects that have a deleter in the
# module while the returned object lives longer than the module.
# DSL functions do not have such an issue, so it is desirable to set this to False.
return tvm_ffi.load_module(file_path, keep_module_alive=False)
except TypeError:
# compatible with tvm-ffi < 0.1.6
return tvm_ffi.load_module(file_path)
else:
raise DSLRuntimeError(
"Unimplemented, please load the module with enable_tvm_ffi=True."

View File

@@ -20,7 +20,7 @@ from cutlass.cutlass_dsl import (
T,
cutlass_arith,
_binary_op_type_promote,
CuTeDSL,
BaseDSL,
)
from cutlass._mlir import ir
import cutlass._mlir.dialects.cute as _cute_ir
@@ -1776,7 +1776,7 @@ class TensorSSA(cutlass_arith.ArithValue):
fast_cvt_func = cvt_i8_bf16_intrinsic
elif src_dtype == Int4 and dtype == BFloat16:
fast_cvt_func = cvt_i4_bf16_intrinsic
arch = CuTeDSL._get_dsl().get_arch_enum()
arch = BaseDSL._get_dsl().get_arch_enum()
if fast_cvt_func is not None and arch in fast_cvt_func.supported_archs:
res_vect = fast_cvt_func(src, size(self.shape), loc=loc, ip=ip)
else:

View File

@@ -407,7 +407,7 @@ def benchmark(
To use CUDA graphs, the callable must be a compiled @cute.jit annotated function.
When using CUDA graphs, the kernel must be launched in a non-default stream.
:param callable: The function to benchmark
:param callable: The function to benchmark. For a jit function, it must be a compiled function.
:type callable: Callable
:param warmup_iterations: Number of warmup iterations, defaults to 10
:type warmup_iterations: int, optional
@@ -475,15 +475,6 @@ def benchmark(
elapsed_time = float("nan")
if use_cuda_graphs:
# Check if the callable is a JitCompiledFunction or JitExecutor
# These are functions that can be called to launch kernels
compiled_types = (
cutlass.base_dsl.jit_executor.JitCompiledFunction,
cutlass.base_dsl.jit_executor.JitExecutor,
)
if not isinstance(callable, compiled_types):
raise TypeError("Function must be precompiled to be used with CUDA Graphs")
# Check if the stream is a non-default stream
if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT):
raise ValueError(

View File

@@ -247,7 +247,10 @@ class CutlassBaseDSL(BaseDSL):
return False
def _build_gpu_module(self, attrs, loc=None):
log().info(f"self : {self}")
log().info(f"Building GPU module for {self.name}")
self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels"), loc=loc)
log().info(f"GPU module: {self.gpu_module}")
with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])):
pass
@@ -275,6 +278,9 @@ class CutlassBaseDSL(BaseDSL):
return pipeline
def _enter_gpu_module(self):
log().info(f"self: {self}")
log().info(f"Entering GPU module for {self.name}")
log().info(f"GPU module: {self.gpu_module}")
return ir.InsertionPoint(self.gpu_module.bodyRegion.blocks[0])
def _generate_kernel_attrs(self, config: BaseDSL.LaunchConfig) -> dict:

View File

@@ -126,16 +126,22 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
)
context.module.body.append(parsed_op)
with ir.InsertionPoint(current_block):
cuda_global_state_ptr = self.address_of(
self.cuda_global_state_symbol, self.ptr_type
)
cuda_init_ptr = self.address_of("cuda_init", self.ptr_type)
cuda_load_to_device_ptr = self.address_of("cuda_load_to_device", self.ptr_type)
set_error_ptr = self.address_of(
"TVMFFIErrorSetRaisedFromCStr", self.ptr_type
)
cuda_init_ptr = context.builder.get_or_load_global_func_ptr_from_text(
current_block, "cuda_init"
)
cuda_load_to_device_ptr = context.builder.get_or_load_global_func_ptr_from_text(
current_block, "cuda_load_to_device"
)
set_error_ptr = context.builder.get_or_load_global_func_ptr_from_text(
current_block, "TVMFFIErrorSetRaisedFromCStr"
)
with ir.InsertionPoint(current_block):
# Call the callback function with the loaded ptr value
init_result = llvm.call(
result=self.i32_type, # function returns i32
@@ -495,6 +501,13 @@ class TVMFFIJitCompiledFunctionBase(CudaDialectJitCompiledFunction):
"""Create the tvm_ffi.Function from the current execution engine.
"""
if self.engine is not None:
# trigger eager compile of init callbacks
cuda_init = self.engine.raw_lookup("cuda_init")
cuda_load_to_device = self.engine.raw_lookup("cuda_load_to_device")
if cuda_init is None:
raise DSLRuntimeError("cuda_init not found")
if cuda_load_to_device is None:
raise DSLRuntimeError("cuda_load_to_device not found")
tvm_ffi_function_ptr = self.engine.raw_lookup(
"__tvm_ffi_" + self.function_name
)

View File

@@ -261,7 +261,7 @@ def make_smem_layout_a(
a_smem_layout_staged = cute.tile_to_shape(
a_smem_layout_atom,
cute.append(a_smem_shape, num_stages),
order=(0, 1, 2) if is_k_major else (0, 1, 2),
order=(0, 1, 2) if is_k_major else (1, 0, 2),
loc=loc,
ip=ip,
)

View File

@@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.3.3
nvidia-cutlass-dsl==4.3.4