mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
v4.3.4 update. (#2892)
This commit is contained in:
@@ -470,7 +470,56 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
names=[ast.alias(name=f"{top_module_name}.base_dsl", asname="_dsl_")]
|
||||
)
|
||||
)
|
||||
transformed_tree.body = import_stmts + transformed_tree.body
|
||||
|
||||
assert len(transformed_tree.body) == 1
|
||||
assert isinstance(transformed_tree.body[0], ast.FunctionDef)
|
||||
transformed_tree.body[0].body = import_stmts + transformed_tree.body[0].body
|
||||
# Remove all decorators from top level function
|
||||
transformed_tree.body[0].decorator_list = []
|
||||
|
||||
# Step 4. Wrap the function with nonlocal captures, if has any
|
||||
# if the function has a nonlocal variable, wrap it in a function and return the function
|
||||
# pseudo code:
|
||||
# def foo():
|
||||
# nonlocal_var_0 = None
|
||||
# nonlocal_var_1 = None
|
||||
# def foo(args):
|
||||
# ...
|
||||
# return foo
|
||||
# foo = foo()
|
||||
nonlocals = {v: None for v in function_pointer.__code__.co_freevars}
|
||||
|
||||
if len(nonlocals) > 0:
|
||||
assignments = []
|
||||
for n, _ in nonlocals.items():
|
||||
assignments.append(
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id=n, ctx=ast.Store())],
|
||||
value=ast.Constant(value=None),
|
||||
)
|
||||
)
|
||||
|
||||
return_expr = [ast.Return(value=ast.Name(id=func_name, ctx=ast.Load()))]
|
||||
|
||||
wrapper_fcn = ast.FunctionDef(
|
||||
name=func_name,
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
defaults=[],
|
||||
),
|
||||
body=assignments + transformed_tree.body + return_expr,
|
||||
decorator_list=[],
|
||||
)
|
||||
invoke = ast.Call(
|
||||
func=ast.Name(id=func_name, ctx=ast.Load()), args=[], keywords=[]
|
||||
)
|
||||
assign = ast.Assign(
|
||||
targets=[ast.Name(id=func_name, ctx=ast.Store())], value=invoke
|
||||
)
|
||||
transformed_tree.body = [wrapper_fcn, assign]
|
||||
|
||||
# Step 4. Import cutlass and base_dsl
|
||||
ast.fix_missing_locations(transformed_tree)
|
||||
@@ -1521,6 +1570,15 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
self.scope_manager.add_to_scope(node.name)
|
||||
for arg in node.args.args:
|
||||
self.scope_manager.add_to_scope(arg.arg)
|
||||
arg.annotation = None
|
||||
|
||||
for arg in node.args.kwonlyargs:
|
||||
self.scope_manager.add_to_scope(arg.arg)
|
||||
arg.annotation = None
|
||||
|
||||
for arg in node.args.posonlyargs:
|
||||
self.scope_manager.add_to_scope(arg.arg)
|
||||
arg.annotation = None
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
@@ -622,18 +622,14 @@ class CompileCallable:
|
||||
func,
|
||||
)
|
||||
|
||||
# If it's a wrapped function created by jit decorator, get the original function
|
||||
if hasattr(func, "__wrapped__"):
|
||||
# If it's a wrapped function created by decorators, get the original function
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
|
||||
# Lazy initialization of DSL object if has not been initialized
|
||||
# Use local import to avoid circular import
|
||||
from .dsl import BaseDSL
|
||||
|
||||
BaseDSL._lazy_initialize_dsl(func)
|
||||
|
||||
if not hasattr(func, "_dsl_object"):
|
||||
raise DSLRuntimeError("Function is not decorated with jit decorator.")
|
||||
raise DSLRuntimeError(
|
||||
f"Function {func} is not decorated with jit decorator."
|
||||
)
|
||||
|
||||
# process compile options, extract the options and remove them from the kwargs
|
||||
options = kwargs.pop("options", None)
|
||||
@@ -645,8 +641,4 @@ class CompileCallable:
|
||||
else:
|
||||
compile_options = self._compile_options
|
||||
func._dsl_object.compile_options = compile_options
|
||||
fcn_ptr = func._dsl_object._preprocess_and_execute(func)
|
||||
|
||||
if hasattr(func, "_decorator_frame"):
|
||||
kwargs["_decorator_frame"] = func._decorator_frame
|
||||
return func._dsl_object._func(fcn_ptr, *args, **kwargs)
|
||||
return func._dsl_object._func(func, *args, **kwargs)
|
||||
|
||||
@@ -31,9 +31,10 @@ import weakref
|
||||
from functools import lru_cache, wraps
|
||||
from collections import namedtuple, OrderedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, List
|
||||
from typing import Any, Callable, List, ClassVar
|
||||
from types import SimpleNamespace
|
||||
import warnings
|
||||
import threading
|
||||
|
||||
from . import typing as t
|
||||
from .env_manager import EnvironmentVarManager
|
||||
@@ -228,49 +229,92 @@ def new_from_mlir_values(obj, values):
|
||||
assert len(values) == 0, f"{obj} expects 0 values, but got {values}"
|
||||
return obj
|
||||
|
||||
|
||||
class DSLCallable:
|
||||
@dataclass(frozen=True)
|
||||
class DSLLocation:
|
||||
"""
|
||||
Wrapper class for a callable object used within the DSL.
|
||||
|
||||
DSLCallable is designed to wrap a function and provide additional
|
||||
introspection utilities such as retrieving the argument specification
|
||||
and signature. It ensures that the wrapped function can only be called
|
||||
once, after which the reference to the function is cleared to prevent
|
||||
further invocations. This is useful in scenarios where a function should
|
||||
only be executed a single time within the DSL's execution model.
|
||||
Represents Python source location information for MLIR DSL code.
|
||||
|
||||
Attributes:
|
||||
func (callable): The function to be wrapped and managed.
|
||||
filename (str): Name of the Python source file.
|
||||
lineno (int): Line number in the source file.
|
||||
col_offset (int): Column offset in the source line.
|
||||
function_name (str): Name of the function in which the location occurs.
|
||||
|
||||
Methods:
|
||||
__call__(*args, **kwargs): Calls the wrapped function and clears it.
|
||||
This is used primarily to annotate or trace locations in generated MLIR IR
|
||||
back to the original Python code for better diagnostic and debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
self.name = func.__name__
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
ret = self.__func__(*args, **kwargs)
|
||||
self.func = None
|
||||
return ret
|
||||
|
||||
@property
|
||||
def __func__(self):
|
||||
assert self.func is not None, "DSLCallable is already called"
|
||||
return self.func
|
||||
|
||||
@property
|
||||
def __signature__(self):
|
||||
return inspect.signature(self.__func__)
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
return self.name
|
||||
filename: str
|
||||
lineno: int
|
||||
col_offset: int
|
||||
function_name: str
|
||||
|
||||
|
||||
class BaseDSL:
|
||||
@dataclass
|
||||
class PreprocessSessionData:
|
||||
"""
|
||||
Holds metadata and transformed AST related to a DSL preprocessing session.
|
||||
|
||||
Attributes:
|
||||
decorator_globals (dict): The global variables from the decorator's environment,
|
||||
captured for possible AST or code evaluation during preprocessing.
|
||||
"""
|
||||
decorator_globals: dict
|
||||
|
||||
|
||||
class DSLSingletonMeta(type):
|
||||
"""
|
||||
Metaclass implementing the Singleton pattern for DSL classes.
|
||||
|
||||
The DSLSingletonMeta ensures that only one instance of a derived DSL class exists at any time.
|
||||
When a class is called, it checks if an instance already exists in the `_instances` dictionary.
|
||||
- If requesting `BaseDSL` itself, it asserts that a concrete subclass has been initialized,
|
||||
and returns the first available singleton instance among subclasses.
|
||||
- If requesting a concrete subclass, it creates a new instance if none exists, or returns
|
||||
the already created instance.
|
||||
|
||||
This metaclass is useful for maintaining global state and configuration across the DSL system,
|
||||
ensuring that all parts of the application operate on the same DSL instance.
|
||||
|
||||
Attributes:
|
||||
_instances (dict): Maps DSL classes to their singleton instances.
|
||||
|
||||
Example:
|
||||
class MyDSL(BaseDSL): ...
|
||||
dsl1 = MyDSL()
|
||||
dsl2 = MyDSL()
|
||||
assert dsl1 is dsl2 # Singleton property
|
||||
"""
|
||||
|
||||
_instances: ClassVar[dict] = {}
|
||||
_lock: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
with cls._lock:
|
||||
log().info(f"DSLSingletonMeta __call__ for {cls}")
|
||||
if cls is BaseDSL:
|
||||
# If one is querying a BaseDSL which is abstract, returns an arbitrary instance of a concrete subclass should be fine.
|
||||
# Here we just return the first instance of a concrete subclass.
|
||||
assert cls._instances, (
|
||||
"Need to initialize a concrete subclass of BaseDSL first"
|
||||
)
|
||||
return next(iter(cls._instances.values()))
|
||||
elif cls not in cls._instances:
|
||||
instance = super().__call__(*args, **kwargs)
|
||||
cls._instances[cls] = instance
|
||||
log().info(f"Active DSL singleton instances: {cls._instances}")
|
||||
return cls._instances[cls]
|
||||
|
||||
def clear_instances(cls):
|
||||
log().info(
|
||||
f"Clearing DSL singleton instances for {cls}, current instances: {cls._instances}"
|
||||
)
|
||||
if cls in cls._instances:
|
||||
del cls._instances[cls]
|
||||
log().info(f"DSL singleton instances after clearing: {cls._instances}")
|
||||
|
||||
|
||||
class BaseDSL(metaclass=DSLSingletonMeta):
|
||||
gpu_module = None
|
||||
_env_class = EnvironmentVarManager
|
||||
|
||||
@@ -310,7 +354,8 @@ class BaseDSL:
|
||||
self.name = name
|
||||
self.compiler_provider = compiler_provider
|
||||
self.pass_sm_arch_name = pass_sm_arch_name
|
||||
self.frame = None
|
||||
self.preprocess_session_data = None
|
||||
self.decorator_location = None
|
||||
self.no_cache = False
|
||||
self.device_compilation_only = device_compilation_only
|
||||
self.num_kernels = 0
|
||||
@@ -379,7 +424,6 @@ class BaseDSL:
|
||||
warnings.warn(message, UserWarning)
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_dsl(cls):
|
||||
# Instantiate the DSL Class once
|
||||
main_dsl = cls()
|
||||
@@ -414,38 +458,22 @@ class BaseDSL:
|
||||
return fcn_ptr
|
||||
|
||||
@staticmethod
|
||||
def _preprocess_and_execute(func):
|
||||
def _preprocess_and_replace_code(func):
|
||||
"""
|
||||
Run ast transformation and return the materialized function pointer
|
||||
"""
|
||||
|
||||
# Lazy initialization of DSL object if has not been initialized
|
||||
if not hasattr(func, "_dsl_object"):
|
||||
func._dsl_object = func._dsl_cls._get_dsl()
|
||||
delattr(func, "_dsl_cls")
|
||||
|
||||
if not func._dsl_object.enable_preprocessor:
|
||||
if hasattr(func, "_decorator_frame"):
|
||||
delattr(func, "_decorator_frame")
|
||||
if hasattr(func, "_transformed_ast"):
|
||||
delattr(func, "_transformed_ast")
|
||||
return func
|
||||
|
||||
if hasattr(func, "_transformed_ast"):
|
||||
if hasattr(func, "_preprocess_session_data"):
|
||||
# If the function ptr is already materialized, use the existing one
|
||||
func._dsl_object.frame = func._decorator_frame
|
||||
if func._transformed_ast is None:
|
||||
func._transformed_ast = func._dsl_object.run_preprocessor(func)
|
||||
if func._transformed_ast is None:
|
||||
del func._transformed_ast
|
||||
func._dsl_object.frame = None
|
||||
return func
|
||||
|
||||
fcn_ptr = func._dsl_object.get_function_ptr(func)
|
||||
# If the function is decorated, de-decorate it
|
||||
fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__)
|
||||
func._dsl_object.frame = None
|
||||
return DSLCallable(fcn_ptr)
|
||||
func._dsl_object.preprocess_session_data = func._preprocess_session_data
|
||||
func._dsl_object.decorator_location = func._decorator_location
|
||||
transformed_ast = func._dsl_object.run_preprocessor(func)
|
||||
fcn_ptr = func._dsl_object.get_function_ptr(func, transformed_ast)
|
||||
func.__code__ = (
|
||||
fcn_ptr.__code__
|
||||
if not isinstance(fcn_ptr, staticmethod)
|
||||
else fcn_ptr.__func__.__code__
|
||||
)
|
||||
return func
|
||||
|
||||
@staticmethod
|
||||
@@ -457,20 +485,27 @@ class BaseDSL:
|
||||
|
||||
def jit_runner_decorator(func):
|
||||
# Run preprocessor that alters AST
|
||||
func._dsl_cls = cls
|
||||
if BaseDSL._can_preprocess(**dkwargs):
|
||||
func._dsl_object = cls._get_dsl()
|
||||
func._decorator_location = BaseDSL.get_location_from_frame(frame)
|
||||
if (
|
||||
func._dsl_object.enable_preprocessor
|
||||
and func._dsl_object._can_preprocess(**dkwargs)
|
||||
):
|
||||
# For an annotated function, add some DSL attributes
|
||||
# When materializing the AST, we need decorator's frame
|
||||
func._decorator_frame = frame
|
||||
# No transformed ast at this point
|
||||
func._transformed_ast = None
|
||||
func._preprocess_session_data = PreprocessSessionData(
|
||||
decorator_globals=frame.f_globals,
|
||||
)
|
||||
BaseDSL._preprocess_and_replace_code(func)
|
||||
|
||||
@wraps(func)
|
||||
def jit_wrapper(*args, **kwargs):
|
||||
func_ptr = BaseDSL._preprocess_and_execute(func)
|
||||
return getattr(func._dsl_object, executor_name)(
|
||||
func_ptr, *args, **kwargs
|
||||
)
|
||||
return getattr(func._dsl_object, executor_name)(func, *args, **kwargs)
|
||||
|
||||
def set_name_prefix(name: str):
|
||||
jit_wrapper._name_prefix = name
|
||||
|
||||
jit_wrapper.set_name_prefix = set_name_prefix
|
||||
|
||||
return jit_wrapper
|
||||
|
||||
@@ -479,15 +514,6 @@ class BaseDSL:
|
||||
else:
|
||||
return jit_runner_decorator
|
||||
|
||||
@staticmethod
|
||||
def _lazy_initialize_dsl(func):
|
||||
"""
|
||||
Lazy initialization of DSL object if has not been initialized
|
||||
"""
|
||||
if hasattr(func, "_dsl_cls"):
|
||||
func._dsl_object = func._dsl_cls._get_dsl()
|
||||
delattr(func, "_dsl_cls")
|
||||
|
||||
@classmethod
|
||||
def jit(cls, *dargs, **dkwargs):
|
||||
"""
|
||||
@@ -516,6 +542,7 @@ class BaseDSL:
|
||||
"""
|
||||
Build the module op that contains the kernels.
|
||||
"""
|
||||
log().info(f"[abstract] Building GPU module for {self.name}")
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -688,9 +715,11 @@ class BaseDSL:
|
||||
dictionary is used to execute the python code.
|
||||
"""
|
||||
all_globals = {}
|
||||
if self.frame:
|
||||
all_globals.update(self.frame.f_globals)
|
||||
all_globals.update(self.frame.f_locals)
|
||||
if (
|
||||
self.preprocess_session_data
|
||||
and self.preprocess_session_data.decorator_globals
|
||||
):
|
||||
all_globals.update(self.preprocess_session_data.decorator_globals)
|
||||
return all_globals
|
||||
|
||||
@abstractmethod
|
||||
@@ -955,25 +984,40 @@ class BaseDSL:
|
||||
else:
|
||||
ir._GlobalDebug.set_types(f"diagnostic-{args.diagnostic}")
|
||||
|
||||
def get_location(self, frame=None):
|
||||
"""
|
||||
Get python location information and generate MLIR location
|
||||
"""
|
||||
frame = self.frame if frame is None else frame
|
||||
frame = inspect.currentframe().f_back if frame is None else frame
|
||||
@staticmethod
|
||||
def get_location_from_frame(frame):
|
||||
frameInfo = inspect.getframeinfo(frame)
|
||||
|
||||
file_loc = ir.Location.file(
|
||||
frame.f_code.co_filename,
|
||||
frame.f_lineno,
|
||||
frameInfo.positions.col_offset if hasattr(frameInfo, "positions") else 0,
|
||||
)
|
||||
loc = ir.Location.name(
|
||||
(
|
||||
return DSLLocation(
|
||||
filename=frameInfo.filename,
|
||||
lineno=frameInfo.lineno,
|
||||
col_offset=(
|
||||
frameInfo.positions.col_offset if hasattr(frameInfo, "positions") else 0
|
||||
),
|
||||
function_name=(
|
||||
"".join([c.strip() for c in frameInfo.code_context])
|
||||
if frameInfo.code_context
|
||||
else frameInfo.function
|
||||
),
|
||||
)
|
||||
|
||||
def get_ir_location(self, location: DSLLocation = None):
|
||||
"""
|
||||
Get python location information and generate MLIR location
|
||||
"""
|
||||
if location is None:
|
||||
if self.decorator_location:
|
||||
location = self.decorator_location
|
||||
|
||||
if location is None:
|
||||
return ir.Location.unknown()
|
||||
|
||||
file_loc = ir.Location.file(
|
||||
location.filename,
|
||||
location.lineno,
|
||||
location.col_offset,
|
||||
)
|
||||
loc = ir.Location.name(
|
||||
(location.function_name),
|
||||
childLoc=file_loc,
|
||||
)
|
||||
return loc
|
||||
@@ -1140,10 +1184,10 @@ class BaseDSL:
|
||||
gpu_module_attrs,
|
||||
args,
|
||||
args_spec,
|
||||
frame=None,
|
||||
location=None,
|
||||
):
|
||||
def build_ir_module():
|
||||
loc = self.get_location(frame)
|
||||
loc = self.get_ir_location(location)
|
||||
module = ir.Module.create(loc=loc)
|
||||
unit_attr = ir.UnitAttr.get()
|
||||
module.operation.attributes["gpu.container_module"] = unit_attr
|
||||
@@ -1308,6 +1352,10 @@ class BaseDSL:
|
||||
self.num_kernels = 0
|
||||
# reset the compile options after the compilation is done.
|
||||
self.compile_options = CompileOptions()
|
||||
# reset preprocess session data after the compilation is done.
|
||||
self.preprocess_session_data = None
|
||||
# reset decorator location after the compilation is done.
|
||||
self.decorator_location = None
|
||||
|
||||
def extract_dynamic_args(self, funcBody, args, kwargs, args_spec):
|
||||
"""This function is used to extract the original dynamic arguments for AOT C header generation.
|
||||
@@ -1348,11 +1396,10 @@ class BaseDSL:
|
||||
pipeline,
|
||||
no_cache,
|
||||
compile_only,
|
||||
loc=None,
|
||||
frame=None,
|
||||
location=None,
|
||||
):
|
||||
"""Generate MLIR module and compile iself.T_provider."""
|
||||
with ir.Context(), self.get_location(frame):
|
||||
with ir.Context(), self.get_ir_location(location):
|
||||
try:
|
||||
# Convert input arguments to MLIR arguments
|
||||
exe_args, func_types, adapted_args = self.generate_mlir_function_types(
|
||||
@@ -1374,7 +1421,7 @@ class BaseDSL:
|
||||
gpu_module_attrs,
|
||||
args,
|
||||
args_spec,
|
||||
frame=frame,
|
||||
location=location,
|
||||
)
|
||||
|
||||
# dryrun is used to only generate IR
|
||||
@@ -1437,11 +1484,14 @@ class BaseDSL:
|
||||
return transformed_ast
|
||||
return None
|
||||
|
||||
def get_function_ptr(self, original_function):
|
||||
def get_function_ptr(self, original_function, transformed_ast):
|
||||
file_name = inspect.getsourcefile(original_function)
|
||||
code_object = compile(
|
||||
original_function._transformed_ast, filename=file_name, mode="exec"
|
||||
transformed_ast,
|
||||
filename=file_name,
|
||||
mode="exec",
|
||||
)
|
||||
|
||||
return self.preprocessor.exec(
|
||||
original_function.__name__,
|
||||
original_function,
|
||||
@@ -1523,7 +1573,7 @@ class BaseDSL:
|
||||
|
||||
pipeline = kwargs.pop("pipeline", None)
|
||||
gpu_module_attrs = kwargs.pop("gpu_module_attrs", {})
|
||||
decorator_frame = kwargs.pop("_decorator_frame", None)
|
||||
self.decorator_location = getattr(funcBody, "_decorator_location", None)
|
||||
|
||||
# Disable cache
|
||||
no_cache = kwargs.pop("no_cache", False)
|
||||
@@ -1556,7 +1606,7 @@ class BaseDSL:
|
||||
function_name = self.mangle_name(function_name, canonicalized_args, args_spec)
|
||||
self.compile_options.apply_envar_settings(self.envar, function_name)
|
||||
if not self.compile_options.generate_line_info:
|
||||
decorator_frame = None
|
||||
self.decorator_location = None
|
||||
|
||||
# Generate MLIR Context and start generating IR
|
||||
log().debug(f"Generating MLIR for function '{function_name}'")
|
||||
@@ -1570,7 +1620,7 @@ class BaseDSL:
|
||||
pipeline,
|
||||
no_cache,
|
||||
compile_only,
|
||||
frame=decorator_frame,
|
||||
location=self.decorator_location,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -1679,8 +1729,7 @@ class BaseDSL:
|
||||
"""
|
||||
ret = None
|
||||
|
||||
with ir.Context(), self.get_location():
|
||||
loc = self.get_location()
|
||||
with ir.Context(), self.get_ir_location() as loc:
|
||||
module = ir.Module.create(loc=loc)
|
||||
unit_attr = ir.UnitAttr.get()
|
||||
module.operation.attributes["gpu.container_module"] = unit_attr
|
||||
@@ -1819,7 +1868,7 @@ class BaseDSL:
|
||||
)
|
||||
)
|
||||
|
||||
loc = self.get_location()
|
||||
loc = self.get_ir_location()
|
||||
with self._enter_gpu_module():
|
||||
log().debug("Generating device kernel")
|
||||
if self.device_compilation_only:
|
||||
|
||||
@@ -138,6 +138,7 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
super().__init__()
|
||||
self.module: Optional[ir.Module] = None
|
||||
self.const_str_table: dict[str, ir.Value] = {}
|
||||
self.const_func_ptr_table: dict[str, ir.Value] = {}
|
||||
self.get_element_extra_kwargs: dict[str, Any] = {}
|
||||
|
||||
# create constants
|
||||
@@ -368,6 +369,64 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
self.const_str_table[content] = symbol
|
||||
return symbol
|
||||
|
||||
def get_or_load_global_func_ptr_from_text(
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
function_name: str,
|
||||
) -> ir.Value:
|
||||
"""Get or create a function pointer global in .text section and load it.
|
||||
|
||||
This creates a constant global function pointer in the .text section
|
||||
(for AArch64 ADRP range compatibility) and performs a volatile load
|
||||
to prevent optimization.
|
||||
|
||||
This forces the function pointer to be local to the code, bypassing GOT entry
|
||||
ADRP lookup issues on AArch64 when GOT and .text section are more than 4GB
|
||||
apart which can happen when ASLR is applied.
|
||||
"""
|
||||
# Check if we've already created this global
|
||||
if function_name not in self.const_func_ptr_table:
|
||||
symbol = f"__func_ptr_{function_name}"
|
||||
|
||||
module_body = self.module.body
|
||||
with ir.InsertionPoint(module_body):
|
||||
# 1. Create the global constant
|
||||
# We use 'private' linkage so it doesn't conflict across modules
|
||||
global_ptr = llvm.GlobalOp(
|
||||
self.ptr_type,
|
||||
symbol,
|
||||
ir.Attribute.parse("#llvm.linkage<private>"),
|
||||
# Initialization via block below
|
||||
)
|
||||
|
||||
# 2. Set the necessary attributes for JIT safety and AArch64 range
|
||||
# We use 'constant' to mark it as immutable
|
||||
# We use 'section = ".text"' to force it into the code block
|
||||
global_ptr.attributes["constant"] = ir.UnitAttr.get()
|
||||
global_ptr.attributes["section"] = ir.StringAttr.get(".text")
|
||||
|
||||
# 3. Add a constructor block to the GlobalOp to initialize it
|
||||
# with the address of the target function
|
||||
initializer_block = global_ptr.initializer.blocks.append()
|
||||
with ir.InsertionPoint(initializer_block):
|
||||
# Get the address of the external function
|
||||
func_addr = llvm.AddressOfOp(self.ptr_type, function_name).res
|
||||
# Return the address as the initial value of the global
|
||||
llvm.return_(arg=func_addr)
|
||||
|
||||
self.const_func_ptr_table[function_name] = symbol
|
||||
else:
|
||||
symbol = self.const_func_ptr_table[function_name]
|
||||
|
||||
# Load it with volatile semantics in the current block
|
||||
with ir.InsertionPoint(current_block):
|
||||
symbol_addr = self.address_of(symbol, self.ptr_type)
|
||||
# Perform a volatile load to prevent optimization
|
||||
load_op = llvm.load(self.ptr_type, symbol_addr)
|
||||
# Set volatile attribute to prevent optimization
|
||||
load_op.owner.attributes["volatile_"] = ir.UnitAttr.get()
|
||||
return load_op
|
||||
|
||||
# function
|
||||
def function(
|
||||
self,
|
||||
|
||||
@@ -210,7 +210,7 @@ EnableTVMFFI = _dsl.EnableTVMFFI
|
||||
# attach the TVM FFI ABI interface postprocessor to the DSL
|
||||
from . import _tvm_ffi_args_spec_converter
|
||||
|
||||
_tvm_ffi_args_spec_converter.attach_args_spec_converter()
|
||||
_tvm_ffi_args_spec_converter.attach_args_spec_converter(_dsl.CuTeDSL._get_dsl())
|
||||
|
||||
# Explicitly export all symbols for documentation generation
|
||||
__all__ = [
|
||||
|
||||
@@ -395,8 +395,6 @@ def _tvm_ffi_args_spec_converter(
|
||||
return params, kwargs_wrapper_spec
|
||||
|
||||
|
||||
def attach_args_spec_converter():
|
||||
"""Attach TVM FFI ABI interface postprocessor to the DSL."""
|
||||
from .. import cutlass_dsl as _dsl
|
||||
|
||||
_dsl.CuTeDSL._get_dsl()._tvm_ffi_args_spec_converter = _tvm_ffi_args_spec_converter
|
||||
def attach_args_spec_converter(dsl):
|
||||
"""Attach TVM FFI ABI interface postprocessor to the DSL instance."""
|
||||
dsl._tvm_ffi_args_spec_converter = _tvm_ffi_args_spec_converter
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
# is strictly prohibited.
|
||||
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op
|
||||
from cutlass.cutlass_dsl import BaseDSL, T, dsl_user_op
|
||||
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir.dialects import nvvm, scf
|
||||
@@ -69,7 +69,7 @@ def elect_one(*, loc=None, ip=None) -> IfOpRegion:
|
||||
# Only one thread in the warp executes the code in this context
|
||||
pass
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
is_thread_leader = nvvm.elect_sync(T.bool())
|
||||
if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip)
|
||||
return IfOpRegion(if_op.then_block, loc=loc, ip=ip)
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op
|
||||
from cutlass.cutlass_dsl import BaseDSL, T, if_generate, dsl_user_op
|
||||
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
|
||||
@@ -44,7 +44,7 @@ def mbarrier_init_fence(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
A fence operation that applies to the mbarrier initializations.
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
nvvm.fence_mbarrier_init(loc=loc, ip=ip)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ def mbarrier_arrive_and_expect_tx(
|
||||
the mbarrier is converted to a remote address in the peer CTA's
|
||||
SMEM.
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
if peer_cta_rank_in_cluster is not None:
|
||||
@@ -103,7 +103,7 @@ def mbarrier_expect_tx(
|
||||
the mbarrier is converted to a remote address in the peer CTA's
|
||||
SMEM.
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
if peer_cta_rank_in_cluster is not None:
|
||||
@@ -138,7 +138,7 @@ def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None:
|
||||
:param phase: The phase to wait for (either 0 or 1)
|
||||
:type phase: Int
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
timeout_ns = 10000000
|
||||
# This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX
|
||||
@@ -164,7 +164,7 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo
|
||||
:return: A boolean value indicating whether the wait operation was successful
|
||||
:rtype: Boolean
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
return Boolean(
|
||||
nvvm.mbarrier_wait_parity(
|
||||
@@ -193,7 +193,7 @@ def mbarrier_conditional_try_wait(
|
||||
:return: A boolean value indicating whether the wait operation was successful
|
||||
:rtype: Boolean
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
return if_generate(
|
||||
cond,
|
||||
lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip),
|
||||
@@ -225,7 +225,7 @@ def mbarrier_arrive(
|
||||
"""
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
if peer_cta_rank_in_cluster is not None:
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
mbar_llvm_ptr = nvvm.mapa_shared_cluster(
|
||||
mbar_llvm_ptr.type,
|
||||
@@ -259,7 +259,7 @@ def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> N
|
||||
:param mbar_ptr: A pointer to the mbarrier in SMEM
|
||||
:type mbar_ptr: Pointer
|
||||
"""
|
||||
CuTeDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90)
|
||||
|
||||
mbar_llvm_ptr = mbar_ptr.llvm_ptr
|
||||
nvvm.cp_async_mbarrier_arrive_shared(
|
||||
|
||||
@@ -12,7 +12,8 @@
|
||||
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.base_dsl.common import DSLRuntimeError
|
||||
from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op
|
||||
from cutlass.cutlass_dsl import BaseDSL, dsl_user_op
|
||||
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import builtin, arith, llvm, vector
|
||||
|
||||
@@ -53,7 +54,7 @@ def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None):
|
||||
:return: The output 1D vector of bfloat16 with the same length as the input vector.
|
||||
:rtype: 1D vector of bfloat16
|
||||
"""
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch in cvt_i8_bf16_intrinsic.supported_archs:
|
||||
raise DSLRuntimeError(f"cvt_i8_bf16_intrinsic is not supported on {arch}")
|
||||
src_pos = 0
|
||||
@@ -130,7 +131,7 @@ def cvt_i4_bf16_intrinsic(vec_i4, length, *, loc=None, ip=None):
|
||||
:return: The output 1D vector of bfloat16 with the same length as the input vector.
|
||||
:rtype: 1D vector of bfloat16
|
||||
"""
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch in cvt_i4_bf16_intrinsic.supported_archs:
|
||||
raise DSLRuntimeError(f"cvt_i4_bf16_intrinsic is not supported on {arch}")
|
||||
src_pos = 0
|
||||
|
||||
@@ -1305,6 +1305,46 @@ def exp_packed_f32x2(
|
||||
return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def griddepcontrol_wait(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
This instruction is used to wait for the previous kernel's grid ending
|
||||
(all blocks of the previous kernel have finished and memflushed), i.e.,
|
||||
the instruction after this instruction will not be issued until the previous
|
||||
grid has finished.
|
||||
"""
|
||||
llvm.inline_asm(
|
||||
res=None,
|
||||
operands_=[],
|
||||
asm_string="griddepcontrol.wait;",
|
||||
constraints="",
|
||||
has_side_effects=True,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def griddepcontrol_launch_dependents(*, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Issuing the launch_dependents instruction hints a dependent kernel to launch earlier.
|
||||
launch_dependents doesn't impact the functionality but the performance:
|
||||
Launching a dependent kernel too early can compete with current kernels,
|
||||
while launching too late can lead to a long latency.
|
||||
"""
|
||||
llvm.inline_asm(
|
||||
res=None,
|
||||
operands_=[],
|
||||
asm_string="griddepcontrol.launch_dependents;",
|
||||
constraints="",
|
||||
has_side_effects=True,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cvt_f4e2m1_f16(src, *, loc=None, ip=None):
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Optional, Type
|
||||
|
||||
from cutlass import cute
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL
|
||||
from cutlass.cutlass_dsl import BaseDSL
|
||||
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
@@ -146,7 +146,7 @@ class CopyBulkTensorTileG2SOp(TmaCopyOp):
|
||||
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
|
||||
)
|
||||
# Arch verification
|
||||
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -263,7 +263,7 @@ class CopyBulkTensorTileG2SMulticastOp(TmaCopyOp):
|
||||
self, "expects the 'cta_group' parameter to be a CtaGroup instance"
|
||||
)
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -386,7 +386,7 @@ class CopyBulkTensorTileS2GOp(TmaCopyOp):
|
||||
|
||||
def __post_init__(self):
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -561,7 +561,7 @@ class CopyBulkG2SOp(CopyOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -646,7 +646,7 @@ class CopyBulkG2SMulticastOp(CopyOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -740,7 +740,7 @@ class CopyBulkS2GOp(CopyOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch: Arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch: Arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch >= Arch.sm_90:
|
||||
raise OpError(
|
||||
self,
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Type
|
||||
|
||||
from cutlass import cute
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL
|
||||
from cutlass.cutlass_dsl import BaseDSL
|
||||
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
from cutlass._mlir import ir
|
||||
@@ -113,7 +113,7 @@ class _LdBase(CopyOp):
|
||||
:raises OpError: If pack parameter is not a Pack instance
|
||||
"""
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -416,7 +416,7 @@ class _StBase(CopyOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -625,7 +625,7 @@ class _S2TCopyBase(CopyOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Arch verification
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch.is_family_of(Arch.sm_100f):
|
||||
raise OpError(
|
||||
self,
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Type, Any
|
||||
|
||||
from cutlass import cute
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL, T
|
||||
from cutlass.cutlass_dsl import BaseDSL, T
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
@@ -162,7 +162,7 @@ class MmaOp(Tcgen05MmaOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -314,7 +314,7 @@ class BlockScaledMmaOp(Tcgen05MmaOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -471,7 +471,7 @@ class SparseMmaOp(Tcgen05MmaOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if arch not in self.admissible_archs:
|
||||
raise OpError(
|
||||
self,
|
||||
|
||||
@@ -15,7 +15,7 @@ from typing import Type, Any
|
||||
|
||||
from cutlass import cute
|
||||
from cutlass.base_dsl.arch import Arch
|
||||
from cutlass.cutlass_dsl import CuTeDSL, T
|
||||
from cutlass.cutlass_dsl import BaseDSL, T
|
||||
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
|
||||
@@ -130,7 +130,7 @@ class MmaOp(WarpGroupMmaOp):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Verify arch
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if not arch == Arch.sm_90a:
|
||||
raise OpError(
|
||||
self,
|
||||
|
||||
@@ -925,7 +925,17 @@ def load_module(file_path: str, *, enable_tvm_ffi: bool = True):
|
||||
if enable_tvm_ffi:
|
||||
import tvm_ffi
|
||||
|
||||
return tvm_ffi.load_module(file_path)
|
||||
try:
|
||||
# keep_module_alive=False means the module will be unloaded
|
||||
# after the returned module goes out of scope, this is useful
|
||||
# for frequent loading and unloading of modules. The only requirement
|
||||
# is that the module do not return object that have deleter in the module
|
||||
# and the returned object lives longer than the module.
|
||||
# DSL functions to not have such issue so it is desirable to set this to False.
|
||||
return tvm_ffi.load_module(file_path, keep_module_alive=False)
|
||||
except TypeError:
|
||||
# compatible with tvm-ffi < 0.1.6
|
||||
return tvm_ffi.load_module(file_path)
|
||||
else:
|
||||
raise DSLRuntimeError(
|
||||
"Unimplemented, please load the module with enable_tvm_ffi=True."
|
||||
|
||||
@@ -20,7 +20,7 @@ from cutlass.cutlass_dsl import (
|
||||
T,
|
||||
cutlass_arith,
|
||||
_binary_op_type_promote,
|
||||
CuTeDSL,
|
||||
BaseDSL,
|
||||
)
|
||||
from cutlass._mlir import ir
|
||||
import cutlass._mlir.dialects.cute as _cute_ir
|
||||
@@ -1776,7 +1776,7 @@ class TensorSSA(cutlass_arith.ArithValue):
|
||||
fast_cvt_func = cvt_i8_bf16_intrinsic
|
||||
elif src_dtype == Int4 and dtype == BFloat16:
|
||||
fast_cvt_func = cvt_i4_bf16_intrinsic
|
||||
arch = CuTeDSL._get_dsl().get_arch_enum()
|
||||
arch = BaseDSL._get_dsl().get_arch_enum()
|
||||
if fast_cvt_func is not None and arch in fast_cvt_func.supported_archs:
|
||||
res_vect = fast_cvt_func(src, size(self.shape), loc=loc, ip=ip)
|
||||
else:
|
||||
|
||||
@@ -407,7 +407,7 @@ def benchmark(
|
||||
To use CUDA graphs, the callable must be a compiled @cute.jit annotated function.
|
||||
When using CUDA graphs, the kernel must be launched in a non-default stream.
|
||||
|
||||
:param callable: The function to benchmark
|
||||
:param callable: The function to benchmark. For jit function, it must be compiled functions.
|
||||
:type callable: Callable
|
||||
:param warmup_iterations: Number of warmup iterations, defaults to 10
|
||||
:type warmup_iterations: int, optional
|
||||
@@ -475,15 +475,6 @@ def benchmark(
|
||||
elapsed_time = float("nan")
|
||||
|
||||
if use_cuda_graphs:
|
||||
# Check if the callable is a JitCompiledFunction or JitExecutor
|
||||
# These are functions that can be called to launch kernels
|
||||
compiled_types = (
|
||||
cutlass.base_dsl.jit_executor.JitCompiledFunction,
|
||||
cutlass.base_dsl.jit_executor.JitExecutor,
|
||||
)
|
||||
if not isinstance(callable, compiled_types):
|
||||
raise TypeError("Function must be precompiled to be used with CUDA Graphs")
|
||||
|
||||
# Check if the stream is a non-default stream
|
||||
if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT):
|
||||
raise ValueError(
|
||||
|
||||
@@ -247,7 +247,10 @@ class CutlassBaseDSL(BaseDSL):
|
||||
return False
|
||||
|
||||
def _build_gpu_module(self, attrs, loc=None):
|
||||
log().info(f"self : {self}")
|
||||
log().info(f"Building GPU module for {self.name}")
|
||||
self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels"), loc=loc)
|
||||
log().info(f"GPU module: {self.gpu_module}")
|
||||
with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])):
|
||||
pass
|
||||
|
||||
@@ -275,6 +278,9 @@ class CutlassBaseDSL(BaseDSL):
|
||||
return pipeline
|
||||
|
||||
def _enter_gpu_module(self):
|
||||
log().info(f"self: {self}")
|
||||
log().info(f"Entering GPU module for {self.name}")
|
||||
log().info(f"GPU module: {self.gpu_module}")
|
||||
return ir.InsertionPoint(self.gpu_module.bodyRegion.blocks[0])
|
||||
|
||||
def _generate_kernel_attrs(self, config: BaseDSL.LaunchConfig) -> dict:
|
||||
|
||||
@@ -126,16 +126,22 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
)
|
||||
context.module.body.append(parsed_op)
|
||||
|
||||
|
||||
with ir.InsertionPoint(current_block):
|
||||
cuda_global_state_ptr = self.address_of(
|
||||
self.cuda_global_state_symbol, self.ptr_type
|
||||
)
|
||||
cuda_init_ptr = self.address_of("cuda_init", self.ptr_type)
|
||||
cuda_load_to_device_ptr = self.address_of("cuda_load_to_device", self.ptr_type)
|
||||
set_error_ptr = self.address_of(
|
||||
"TVMFFIErrorSetRaisedFromCStr", self.ptr_type
|
||||
)
|
||||
|
||||
cuda_init_ptr = context.builder.get_or_load_global_func_ptr_from_text(
|
||||
current_block, "cuda_init"
|
||||
)
|
||||
cuda_load_to_device_ptr = context.builder.get_or_load_global_func_ptr_from_text(
|
||||
current_block, "cuda_load_to_device"
|
||||
)
|
||||
set_error_ptr = context.builder.get_or_load_global_func_ptr_from_text(
|
||||
current_block, "TVMFFIErrorSetRaisedFromCStr"
|
||||
)
|
||||
|
||||
with ir.InsertionPoint(current_block):
|
||||
# Call the callback function with the loaded ptr value
|
||||
init_result = llvm.call(
|
||||
result=self.i32_type, # function returns i32
|
||||
@@ -495,6 +501,13 @@ class TVMFFIJitCompiledFunctionBase(CudaDialectJitCompiledFunction):
|
||||
"""Create the tvm_ffi.Function from the current execution engine.
|
||||
"""
|
||||
if self.engine is not None:
|
||||
# trigger eager compile of init callbacks
|
||||
cuda_init = self.engine.raw_lookup("cuda_init")
|
||||
cuda_load_to_device = self.engine.raw_lookup("cuda_load_to_device")
|
||||
if cuda_init is None:
|
||||
raise DSLRuntimeError("cuda_init not found")
|
||||
if cuda_load_to_device is None:
|
||||
raise DSLRuntimeError("cuda_load_to_device not found")
|
||||
tvm_ffi_function_ptr = self.engine.raw_lookup(
|
||||
"__tvm_ffi_" + self.function_name
|
||||
)
|
||||
|
||||
@@ -261,7 +261,7 @@ def make_smem_layout_a(
|
||||
a_smem_layout_staged = cute.tile_to_shape(
|
||||
a_smem_layout_atom,
|
||||
cute.append(a_smem_shape, num_stages),
|
||||
order=(0, 1, 2) if is_k_major else (0, 1, 2),
|
||||
order=(0, 1, 2) if is_k_major else (1, 0, 2),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl==4.3.3
|
||||
nvidia-cutlass-dsl==4.3.4
|
||||
|
||||
Reference in New Issue
Block a user