mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 14:28:59 +00:00
v4.4 tag release update. (#3032)
This commit is contained in:
@@ -179,12 +179,12 @@ class Executor:
|
||||
def ifexp_execute(
|
||||
self,
|
||||
pred,
|
||||
generator_targets: tuple,
|
||||
block_args: tuple,
|
||||
then_block: Callable,
|
||||
else_block: Callable,
|
||||
):
|
||||
assert self._ifexp_dynamic, "Functions must be set before execution."
|
||||
return self._ifexp_dynamic(pred, generator_targets, then_block, else_block)
|
||||
return self._ifexp_dynamic(pred, block_args, then_block, else_block)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -309,16 +309,14 @@ def if_executor(
|
||||
def ifExp_executor(
|
||||
*,
|
||||
pred,
|
||||
generator_targets: tuple,
|
||||
block_args: tuple,
|
||||
then_block: Callable,
|
||||
else_block: Callable,
|
||||
):
|
||||
if not executor._is_dynamic_expression(pred):
|
||||
return (
|
||||
then_block(*generator_targets) if pred else else_block(*generator_targets)
|
||||
)
|
||||
return then_block(*block_args) if pred else else_block(*block_args)
|
||||
else:
|
||||
return executor.ifexp_execute(pred, generator_targets, then_block, else_block)
|
||||
return executor.ifexp_execute(pred, block_args, then_block, else_block)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -45,6 +45,7 @@ from typing import List, Set, Dict, Any, Callable, Optional
|
||||
from types import ModuleType
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
|
||||
from .common import *
|
||||
from .utils.logger import log
|
||||
@@ -302,6 +303,7 @@ class SessionData:
|
||||
import_top_module: bool = False
|
||||
region_stack: list[Region] = field(default_factory=list)
|
||||
generator_targets: list[str] = field(default_factory=list)
|
||||
lambda_args: list[str] = field(default_factory=list)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_current_class_name(self, class_name: str):
|
||||
@@ -2063,6 +2065,19 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
|
||||
return self._visit_Comprehension(node, key_value_visitor)
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
current_lambda_args = len(self.session_data.lambda_args)
|
||||
for arg in node.args.args:
|
||||
self.session_data.lambda_args.append(arg.arg)
|
||||
|
||||
node.body = self.visit(node.body)
|
||||
|
||||
self.session_data.lambda_args = self.session_data.lambda_args[
|
||||
:current_lambda_args
|
||||
]
|
||||
|
||||
return node
|
||||
|
||||
def visit_ListComp(self, node):
|
||||
return self._visit_Comprehension(
|
||||
node, lambda n: setattr(n, "elt", self.visit(n.elt))
|
||||
@@ -2113,7 +2128,10 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
ast.arg(arg=target, annotation=None)
|
||||
for target in self.session_data.generator_targets
|
||||
for target in chain(
|
||||
self.session_data.generator_targets,
|
||||
self.session_data.lambda_args,
|
||||
)
|
||||
],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
@@ -2129,7 +2147,10 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
posonlyargs=[],
|
||||
args=[
|
||||
ast.arg(arg=target, annotation=None)
|
||||
for target in self.session_data.generator_targets
|
||||
for target in chain(
|
||||
self.session_data.generator_targets,
|
||||
self.session_data.lambda_args,
|
||||
)
|
||||
],
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
@@ -2151,11 +2172,14 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
keywords=[
|
||||
ast.keyword(arg="pred", value=self.visit(node.test)),
|
||||
ast.keyword(
|
||||
arg="generator_targets",
|
||||
arg="block_args",
|
||||
value=ast.Tuple(
|
||||
elts=[
|
||||
ast.Name(id=name, ctx=ast.Load())
|
||||
for name in self.session_data.generator_targets
|
||||
for name in chain(
|
||||
self.session_data.generator_targets,
|
||||
self.session_data.lambda_args,
|
||||
)
|
||||
],
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
|
||||
@@ -50,7 +50,11 @@ from .jit_executor import JitCompiledFunction, JitFunctionArtifacts
|
||||
from .utils.timer import timer
|
||||
from .utils.logger import log
|
||||
from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe
|
||||
from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry
|
||||
from .runtime.jit_arg_adapters import (
|
||||
is_argument_constexpr,
|
||||
is_arg_spec_constexpr,
|
||||
JitArgAdapterRegistry,
|
||||
)
|
||||
|
||||
from .ast_preprocessor import DSLPreprocessor
|
||||
from .common import *
|
||||
@@ -1310,8 +1314,7 @@ class BaseDSL(metaclass=DSLSingletonMeta):
|
||||
dynamic_args = []
|
||||
dynamic_kwargs = OrderedDict()
|
||||
for i, arg in enumerate(args):
|
||||
if not is_argument_constexpr(
|
||||
arg,
|
||||
if not is_arg_spec_constexpr(
|
||||
args_spec.annotations.get(args_spec.args[i], None),
|
||||
args_spec.args[i],
|
||||
i,
|
||||
@@ -1319,7 +1322,7 @@ class BaseDSL(metaclass=DSLSingletonMeta):
|
||||
):
|
||||
dynamic_args.append(arg)
|
||||
for i, (k, v) in enumerate(kwargs.items()):
|
||||
if not is_argument_constexpr(v, args_spec.kwonlyargs[i], k, i, funcBody):
|
||||
if not is_arg_spec_constexpr(args_spec.kwonlyargs[i], k, i, funcBody):
|
||||
dynamic_kwargs[k] = v
|
||||
return dynamic_args, dynamic_kwargs
|
||||
|
||||
|
||||
@@ -371,63 +371,6 @@ class MLIRBuilder(MLIRTypeBuilder):
|
||||
self.const_str_table[content] = symbol
|
||||
return symbol
|
||||
|
||||
def get_or_load_global_func_ptr_from_text(
|
||||
self,
|
||||
current_block: ir.Block,
|
||||
function_name: str,
|
||||
) -> ir.Value:
|
||||
"""Get or create a function pointer global in .text section and load it.
|
||||
|
||||
This creates a constant global function pointer in the .text section
|
||||
(for AArch64 ADRP range compatibility) and performs a volatile load
|
||||
to prevent optimization.
|
||||
|
||||
This forces the function pointer to be local to the code, bypassing GOT entry
|
||||
ADRP lookup issues on AArch64 when GOT and .text section are more than 4GB
|
||||
apart which can happen when ASLR is applied.
|
||||
"""
|
||||
# Check if we've already created this global
|
||||
if function_name not in self.const_func_ptr_table:
|
||||
symbol = f"__func_ptr_{function_name}"
|
||||
|
||||
module_body = self.module.body
|
||||
with ir.InsertionPoint(module_body):
|
||||
# 1. Create the global constant
|
||||
# We use 'private' linkage so it doesn't conflict across modules
|
||||
global_ptr = llvm.GlobalOp(
|
||||
self.ptr_type,
|
||||
symbol,
|
||||
ir.Attribute.parse("#llvm.linkage<private>"),
|
||||
# Initialization via block below
|
||||
)
|
||||
|
||||
# 2. Set the necessary attributes for JIT safety and AArch64 range
|
||||
# We use 'constant' to mark it as immutable
|
||||
# We use 'section = ".text"' to force it into the code block
|
||||
global_ptr.attributes["constant"] = ir.UnitAttr.get()
|
||||
global_ptr.attributes["section"] = ir.StringAttr.get(".text")
|
||||
|
||||
# 3. Add a constructor block to the GlobalOp to initialize it
|
||||
# with the address of the target function
|
||||
initializer_block = global_ptr.initializer.blocks.append()
|
||||
with ir.InsertionPoint(initializer_block):
|
||||
# Get the address of the external function
|
||||
func_addr = llvm.AddressOfOp(self.ptr_type, function_name).res
|
||||
# Return the address as the initial value of the global
|
||||
llvm.return_(arg=func_addr)
|
||||
|
||||
self.const_func_ptr_table[function_name] = symbol
|
||||
else:
|
||||
symbol = self.const_func_ptr_table[function_name]
|
||||
|
||||
# Load it with volatile semantics in the current block
|
||||
with ir.InsertionPoint(current_block):
|
||||
symbol_addr = self.address_of(symbol, self.ptr_type)
|
||||
# Perform a volatile load to prevent optimization
|
||||
load_op = llvm.load(self.ptr_type, symbol_addr)
|
||||
# Set volatile attribute to prevent optimization
|
||||
load_op.owner.attributes["volatile_"] = ir.UnitAttr.get()
|
||||
return load_op
|
||||
|
||||
# function
|
||||
def function(
|
||||
|
||||
@@ -29,8 +29,14 @@ from .core import (
|
||||
append_ones,
|
||||
group_modes,
|
||||
)
|
||||
from .atom import MmaAtom, CopyAtom, make_atom
|
||||
|
||||
from .atom import (
|
||||
MmaAtom,
|
||||
CopyAtom,
|
||||
make_atom,
|
||||
_normalize_variadic_tensor_operand,
|
||||
copy_atom_call,
|
||||
)
|
||||
from .nvgpu.common import CacheEvictionPriority
|
||||
|
||||
def _normalize_gemm_operand_list(
|
||||
x: Union["Tensor", List["Tensor"], Tuple["Tensor", ...]], name: str
|
||||
@@ -156,7 +162,7 @@ def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None:
|
||||
src.element_type.mlir_type, src.element_type.width
|
||||
)
|
||||
simt_copy = make_atom(simt_copy_ty, loc=loc, ip=ip)
|
||||
return _cute_ir.copy(simt_copy, src.value, dst.value, loc=loc, ip=ip)
|
||||
return _cute_ir.copy(simt_copy, [src.value], [dst.value], loc=loc, ip=ip)
|
||||
|
||||
s = size(dst, loc=loc, ip=ip)
|
||||
# Always generate an scf.for Op when one of the tensors is dynamic
|
||||
@@ -210,7 +216,14 @@ def _basic_copy_if_static(
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def autovec_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None:
|
||||
def autovec_copy(
|
||||
src: Tensor,
|
||||
dst: Tensor,
|
||||
*,
|
||||
l1c_evict_priority: CacheEvictionPriority = CacheEvictionPriority.EVICT_NORMAL,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
"""
|
||||
Auto-vectorization SIMT copy policy.
|
||||
|
||||
@@ -263,11 +276,15 @@ def autovec_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None:
|
||||
|
||||
# Dispatch to copy with atom
|
||||
simt_type = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get(
|
||||
src.element_type.mlir_type, num_bits_per_copy
|
||||
src.element_type.mlir_type,
|
||||
num_bits_per_copy,
|
||||
0,
|
||||
0,
|
||||
l1c_evict_priority._to_ir(),
|
||||
)
|
||||
simt_copy = make_atom(simt_type, loc=loc, ip=ip)
|
||||
return _cute_ir.copy(
|
||||
simt_copy, tiled_src.value, tiled_dst.value, loc=loc, ip=ip
|
||||
simt_copy, [tiled_src.value], [tiled_dst.value], loc=loc, ip=ip
|
||||
)
|
||||
|
||||
# Failed to vectorize, use a basic copy
|
||||
@@ -331,8 +348,8 @@ def _parse_auto_multicast_args(
|
||||
@dsl_user_op
|
||||
def copy(
|
||||
atom: CopyAtom,
|
||||
src: Tensor,
|
||||
dst: Tensor,
|
||||
src: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
|
||||
dst: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
|
||||
*,
|
||||
pred: Optional[Tensor] = None,
|
||||
loc=None,
|
||||
@@ -343,10 +360,10 @@ def copy(
|
||||
|
||||
:param atom: Copy atom specifying the transfer operation
|
||||
:type atom: CopyAtom
|
||||
:param src: Source tensor with layout profile ``(V, Rest...)``
|
||||
:type src: Tensor
|
||||
:param dst: Destination tensor with layout profile ``(V, Rest...)``
|
||||
:type dst: Tensor
|
||||
:param src: Source tensor or list of tensors with layout profile ``(V, Rest...)``
|
||||
:type src: Union[Tensor, List[Tensor], Tuple[Tensor, ...]]
|
||||
:param dst: Destination tensor or list of tensors with layout profile ``(V, Rest...)``
|
||||
:type dst: Union[Tensor, List[Tensor], Tuple[Tensor, ...]]
|
||||
:param pred: Optional predication tensor for conditional transfers, defaults to None
|
||||
:type pred: Optional[Tensor], optional
|
||||
:param loc: Source location information, defaults to None
|
||||
@@ -374,6 +391,12 @@ def copy(
|
||||
Source and destination tensors must be partitioned in accordance with the Copy Atom specifications.
|
||||
Post-partitioning, both tensors will exhibit a ``(V, Rest...)`` layout profile.
|
||||
|
||||
The operands `src` and `dst` are variadic, each containing a variable number of tensors:
|
||||
|
||||
- For regular copy, `src` and `dst` contain single source and destination tensors respectively.
|
||||
- For copy with auxiliary operands, `src` and `dst` contain the primary tensors followed by
|
||||
their respective auxiliary tensors.
|
||||
|
||||
**Precondition:** The size of mode 1 must be equal for both source and destination tensors:
|
||||
``size(src, mode=[1]) == size(dst, mode=[1])``
|
||||
|
||||
@@ -399,41 +422,54 @@ def copy(
|
||||
for future releases.
|
||||
|
||||
"""
|
||||
if isinstance(src.type, _cute_ir.MemRefType) and isinstance(
|
||||
dst.type, _cute_ir.MemRefType
|
||||
# Normalize src/dst to lists for variadic IR operands
|
||||
src_list = _normalize_variadic_tensor_operand(src, "src")
|
||||
dst_list = _normalize_variadic_tensor_operand(dst, "dst")
|
||||
|
||||
# Validate primary tensors (first element)
|
||||
src_primary = src_list[0]
|
||||
dst_primary = dst_list[0]
|
||||
|
||||
if isinstance(src_primary.type, _cute_ir.MemRefType) and isinstance(
|
||||
dst_primary.type, _cute_ir.MemRefType
|
||||
):
|
||||
if src.element_type.width != dst.element_type.width:
|
||||
if src_primary.element_type.width != dst_primary.element_type.width:
|
||||
raise TypeError(
|
||||
"`copy` currently only supports equal source and destination "
|
||||
"element type bit width"
|
||||
)
|
||||
|
||||
if rank(src) != rank(dst):
|
||||
if rank(src_primary) != rank(dst_primary):
|
||||
raise ValueError(
|
||||
"Expected source and destination tensors to have the same rank, "
|
||||
f"but got {rank(src)} and {rank(dst)}"
|
||||
f"but got {rank(src_primary)} and {rank(dst_primary)}"
|
||||
)
|
||||
|
||||
# Canonicalize to at least rank-2 tensors
|
||||
src = group_modes(append_ones(src, up_to_rank=2), 1)
|
||||
dst = group_modes(append_ones(dst, up_to_rank=2), 1)
|
||||
# Canonicalize all tensors to at least rank-2
|
||||
src_list = [group_modes(append_ones(t, up_to_rank=2), 1) for t in src_list]
|
||||
dst_list = [group_modes(append_ones(t, up_to_rank=2), 1) for t in dst_list]
|
||||
if pred is not None:
|
||||
pred = group_modes(append_ones(pred, up_to_rank=2), 1)
|
||||
|
||||
if is_static(src.shape[1]) and is_static(dst.shape[1]):
|
||||
if size(src, mode=[1]) != size(dst, mode=[1]):
|
||||
# Recompute primary references after canonicalization
|
||||
src_primary = src_list[0]
|
||||
dst_primary = dst_list[0]
|
||||
|
||||
if is_static(src_primary.shape[1]) and is_static(dst_primary.shape[1]):
|
||||
if size(src_primary, mode=[1]) != size(dst_primary, mode=[1]):
|
||||
raise ValueError(
|
||||
"Expected source and destination tensors to have the same size in mode-1, "
|
||||
f"but got {size(src, mode=[1])} and {size(dst, mode=[1])}"
|
||||
f"but got {size(src_primary, mode=[1])} and {size(dst_primary, mode=[1])}"
|
||||
)
|
||||
|
||||
multicast_attr_pairs = _parse_auto_multicast_args(kwargs)
|
||||
|
||||
value = atom._unpack(loc=loc, ip=ip, **kwargs)
|
||||
if isinstance(pred, Tensor):
|
||||
pred = pred.value
|
||||
pred_value = pred.value if isinstance(pred, Tensor) else pred
|
||||
|
||||
op = _cute_ir.copy(value, src.value, dst.value, pred=pred, loc=loc, ip=ip)
|
||||
src_vals = [t.value for t in src_list]
|
||||
dst_vals = [t.value for t in dst_list]
|
||||
op = _cute_ir.copy(value, src_vals, dst_vals, pred=pred_value, loc=loc, ip=ip)
|
||||
|
||||
for name, attr in multicast_attr_pairs:
|
||||
op.attributes[name] = attr
|
||||
|
||||
@@ -13,7 +13,7 @@ from functools import partial
|
||||
from typing import Any, Optional, Tuple, Union, Callable, Literal
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from cutlass.cutlass_dsl import T, dsl_user_op
|
||||
from cutlass.cutlass_dsl import T, dsl_user_op, target_version
|
||||
|
||||
import cutlass.cutlass_dsl as cutlass_dsl
|
||||
|
||||
@@ -2557,3 +2557,42 @@ def cvt_f4e2m1x8_to_f16x8(src_vec8, *, loc=None, ip=None):
|
||||
vec_f16x8_type = ir.VectorType.get([8], Float16.mlir_type, loc=loc)
|
||||
vec_f16x8 = llvm.bitcast(vec_f16x8_type, vec_f32x4, loc=loc, ip=ip)
|
||||
return vec_f16x8
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def mapa(ptr, cta_rank_in_cluster=0, *, loc=None, ip=None):
|
||||
"""
|
||||
Map a pointer to distributed shared memory across cluster.
|
||||
|
||||
Portable wrapper that uses the appropriate NVVM API based on CUDA version:
|
||||
- CUDA 13.1+: Uses nvvm.mapa with dsmem address space
|
||||
- CUDA 12.9: Uses nvvm.mapa_shared_cluster
|
||||
|
||||
Args:
|
||||
ptr: Pointer to shared memory (llvm_ptr attribute will be used)
|
||||
cta_rank_in_cluster: CTA rank within the cluster (default 0)
|
||||
|
||||
Returns:
|
||||
Mapped LLVM pointer to shared memory
|
||||
"""
|
||||
if target_version(min_version="13.1"):
|
||||
dsmem_ptr_ty = llvm.PointerType.get(7) # dsmem
|
||||
smem_ptr_ty = llvm.PointerType.get(3) # smem
|
||||
|
||||
llvm_ptr = nvvm.mapa(
|
||||
dsmem_ptr_ty,
|
||||
ptr.llvm_ptr,
|
||||
Int32(cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return llvm.addrspacecast(smem_ptr_ty, llvm_ptr, loc=loc, ip=ip)
|
||||
else:
|
||||
llvm_ptr = ptr.llvm_ptr
|
||||
return nvvm.mapa_shared_cluster(
|
||||
llvm_ptr.type,
|
||||
llvm_ptr,
|
||||
Int32(cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
# is strictly prohibited.
|
||||
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from typing import Type, Union, Optional, Any, overload
|
||||
from typing import Type, Union, Optional, Any, List, Tuple, overload
|
||||
|
||||
from .typing import Shape, Layout, Tile, Tensor, Numeric, Int32
|
||||
from .core import (
|
||||
@@ -1113,25 +1113,66 @@ def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None):
|
||||
return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip)
|
||||
|
||||
|
||||
def _normalize_variadic_tensor_operand(
|
||||
x: Union["Tensor", List["Tensor"], Tuple["Tensor", ...]], name: str
|
||||
) -> List["Tensor"]:
|
||||
"""Normalize a Tensor or sequence of Tensors to a list of Tensors.
|
||||
|
||||
Helper function for operations with variadic operands.
|
||||
"""
|
||||
if isinstance(x, Tensor):
|
||||
return [x]
|
||||
if isinstance(x, (list, tuple)):
|
||||
if len(x) == 0:
|
||||
raise ValueError(f"`{name}` must contain at least one Tensor")
|
||||
if not all(isinstance(t, Tensor) for t in x):
|
||||
raise TypeError(f"All elements of `{name}` must be Tensor")
|
||||
return list(x) # type: ignore
|
||||
raise TypeError(f"`{name}` must be a Tensor or a sequence of Tensors")
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def copy_atom_call(
|
||||
atom: CopyAtom,
|
||||
src: Tensor,
|
||||
dst: Tensor,
|
||||
src: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
|
||||
dst: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
|
||||
*,
|
||||
pred: Optional[Tensor] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Executes a single copy atom operation between two tensors.
|
||||
"""
|
||||
Execute a single copy atom operation.
|
||||
|
||||
The copy_atom_call operation executes a copy atom with the given operands.
|
||||
Source and destination tensors have layout profile ``(V)``.
|
||||
|
||||
The ``V-mode`` represents either:
|
||||
|
||||
- A singular mode directly consumable by the provided Copy Atom
|
||||
- A composite mode requiring recursive decomposition, structured as ``(V, Rest...)``,
|
||||
|
||||
For src/dst layout like ``(V, Rest...)``, the layout profile of ``pred`` must match ``(Rest...)``.
|
||||
|
||||
- Certain Atoms may require additional operation-specific keyword arguments.
|
||||
- Current implementation limits ``V-mode`` rank to 2 or less. Support for higher ranks is planned
|
||||
for future releases.
|
||||
|
||||
Both ``src`` and ``dst`` operands are variadic, containing a variable number of tensors:
|
||||
|
||||
- For regular copy, ``src`` and ``dst`` each contain a single tensor.
|
||||
- For copy with auxiliary operands, they contain the main tensor followed by
|
||||
auxiliary tensors. For example:
|
||||
|
||||
:param atom: Copy atom specifying the transfer operation
|
||||
:type atom: CopyAtom
|
||||
:param src: Source tensor with layout profile ``(V)``
|
||||
:type src: Tensor
|
||||
:param dst: Destination tensor with layout profile ``(V)``
|
||||
:type dst: Tensor
|
||||
:param src: Source tensor(s) with layout profile ``(V)``. Can be a single Tensor
|
||||
or a list/tuple of Tensors for operations with auxiliary source operands.
|
||||
:type src: Union[Tensor, List[Tensor], Tuple[Tensor, ...]]
|
||||
:param dst: Destination tensor(s) with layout profile ``(V)``. Can be a single Tensor
|
||||
or a list/tuple of Tensors for operations with auxiliary destination operands.
|
||||
:type dst: Union[Tensor, List[Tensor], Tuple[Tensor, ...]]
|
||||
:param pred: Optional predication tensor for conditional transfers, defaults to None
|
||||
:type pred: Optional[Tensor], optional
|
||||
:param loc: Source location information, defaults to None
|
||||
@@ -1144,54 +1185,43 @@ def copy_atom_call(
|
||||
:return: None
|
||||
:rtype: None
|
||||
|
||||
The copy_atom_call operation executes a single copy atom with the given operands.
|
||||
Source and destination tensors with layout profile like ``(V)``.
|
||||
|
||||
The ``V-mode`` represents either:
|
||||
|
||||
- A singular mode directly consumable by the provided Copy Atom
|
||||
- A composite mode requiring recursive decomposition, structured as ``(V, Rest...)``,
|
||||
|
||||
For src/dst layout like ``(V, Rest...)``, the layout profile of ``pred`` must match ``(Rest...)``.
|
||||
|
||||
**Examples**:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Basic copy atom operation
|
||||
# Regular copy atom operation
|
||||
cute.copy_atom_call(copy_atom, src, dst)
|
||||
|
||||
# Predicated copy atom operation
|
||||
cute.copy_atom_call(copy_atom, src, dst, pred=pred)
|
||||
|
||||
.. note::
|
||||
|
||||
- Certain Atoms may require additional operation-specific keyword arguments.
|
||||
- Current implementation limits ``V-mode`` rank to 2 or less. Support for higher ranks is planned
|
||||
for future releases.
|
||||
|
||||
"""
|
||||
if isinstance(src.type, _cute_ir.MemRefType) and isinstance(
|
||||
dst.type, _cute_ir.MemRefType
|
||||
# Normalize src/dst to lists for variadic IR operands, while keeping old API working.
|
||||
src_list = _normalize_variadic_tensor_operand(src, "src")
|
||||
dst_list = _normalize_variadic_tensor_operand(dst, "dst")
|
||||
|
||||
# Validate first src/dst for element type width check
|
||||
if isinstance(src_list[0].type, _cute_ir.MemRefType) and isinstance(
|
||||
dst_list[0].type, _cute_ir.MemRefType
|
||||
):
|
||||
if src.element_type.width != dst.element_type.width:
|
||||
if src_list[0].element_type.width != dst_list[0].element_type.width:
|
||||
raise TypeError(
|
||||
"`copy_atom_call` currently only supports equal source and destination "
|
||||
"element type bit width"
|
||||
)
|
||||
|
||||
if rank(src, mode=[0]) > 2 or rank(dst, mode=[0]) > 2:
|
||||
if rank(src_list[0], mode=[0]) > 2 or rank(dst_list[0], mode=[0]) > 2:
|
||||
raise NotImplementedError(
|
||||
"V-mode (mode-0) with rank > 2 is not supported yet, "
|
||||
f"but got rank(src, mode=[0]) = {rank(src, mode=[0])} and rank(dst, mode=[0]) = {rank(dst, mode=[0])}"
|
||||
f"but got rank(src, mode=[0]) = {rank(src_list[0], mode=[0])} and rank(dst, mode=[0]) = {rank(dst_list[0], mode=[0])}"
|
||||
)
|
||||
|
||||
value = atom._unpack(loc=loc, ip=ip, **kwargs)
|
||||
if isinstance(pred, Tensor):
|
||||
pred = pred.value
|
||||
return _cute_ir.copy_atom_call(
|
||||
value, src.value, dst.value, pred=pred, loc=loc, ip=ip
|
||||
)
|
||||
src_vals = [t.value for t in src_list]
|
||||
dst_vals = [t.value for t in dst_list]
|
||||
return _cute_ir.copy_atom_call(value, src_vals, dst_vals, pred=pred, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
|
||||
@@ -51,4 +51,7 @@
|
||||
|
||||
- `get_cta_v_map_ab` — Compute CTA-V map for A/B operands
|
||||
- `get_cta_v_map_c` — Compute CTA-V map for C operand
|
||||
- `make_tmem_layout_acc` — Derive TMEM accumulator buffer layout from a tiled MMA
|
||||
- `make_tmem_layout_a` — Derive TMEM A-operand buffer layout from a tiled MMA
|
||||
- `make_t2r_rmem_layout` — Derive per-thread RMEM buffer layout for the T2R epilogue copy
|
||||
|
||||
|
||||
@@ -70,10 +70,93 @@ def get_cta_v_map_c(
|
||||
|
||||
:param gmem_tensor: Global-memory tensor being stored/loaded by TMA.
|
||||
:type gmem_tensor: cute.Tensor
|
||||
:param epi_tile: Epilogue tile layout describing the CTA’s output tile shape.
|
||||
:param epi_tile: Epilogue tile layout describing the CTA's output tile shape.
|
||||
:type epi_tile: cute.Layout
|
||||
:returns: A layout suitable to pass as `cta_v_map=...` to `tma_store` / `tma_load`.
|
||||
:rtype: cute.Layout
|
||||
"""
|
||||
ident = cute.core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
|
||||
return cute.core.composition(ident, epi_tile, loc=loc, ip=ip)
|
||||
|
||||
|
||||
def make_tmem_layout_acc(
|
||||
tiled_mma,
|
||||
mnk_tiler,
|
||||
acc_stage,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Return TMEM accumulator buffer layout for a tiled MMA.
|
||||
|
||||
This is a small helper around ``tiled_mma.make_fragment_C(...).layout`` to
|
||||
keep example code fragment-free at the call site.
|
||||
|
||||
:param tiled_mma: The MMA tiler (``cute.TiledMma``).
|
||||
:type tiled_mma: cute.TiledMma
|
||||
:param mnk_tiler: Full MNK tiler; only the MN components are used for C.
|
||||
:type mnk_tiler: tuple
|
||||
:param acc_stage: Accumulator pipeline stages.
|
||||
:param loc: Optional location for DSL ops.
|
||||
:param ip: Optional insertion point for DSL ops.
|
||||
:return: Layout for the accumulator TMEM buffer.
|
||||
:rtype: cute.Layout
|
||||
"""
|
||||
acc_shape = tiled_mma.partition_shape_C(mnk_tiler[:2], loc=loc, ip=ip)
|
||||
acc_shape_staged = cute.append(acc_shape, acc_stage, loc=loc, ip=ip)
|
||||
return tiled_mma.make_fragment_C(acc_shape_staged, loc=loc, ip=ip).layout
|
||||
|
||||
|
||||
def make_tmem_layout_a(
|
||||
tiled_mma,
|
||||
mk_tiler,
|
||||
stage,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Return TMEM A operand buffer layout for a tiled MMA.
|
||||
|
||||
:param tiled_mma: The MMA tiler (``cute.TiledMma``).
|
||||
:type tiled_mma: cute.TiledMma
|
||||
:param mk_tiler: MK tiler used to shape the A operand.
|
||||
:type mk_tiler: tuple
|
||||
:param stage: Pipeline stages for the A operand buffer.
|
||||
:param loc: Optional location for DSL ops.
|
||||
:param ip: Optional insertion point for DSL ops.
|
||||
:return: Layout for the A operand TMEM buffer.
|
||||
:rtype: cute.Layout
|
||||
"""
|
||||
a_shape = tiled_mma.partition_shape_A(mk_tiler, loc=loc, ip=ip)
|
||||
a_shape_staged = cute.append(a_shape, stage, loc=loc, ip=ip)
|
||||
return tiled_mma.make_fragment_A(a_shape_staged, loc=loc, ip=ip).layout
|
||||
|
||||
|
||||
def make_t2r_rmem_layout(
|
||||
tiled_copy_t2r,
|
||||
gC_mnl_epi,
|
||||
tidx,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Return RMEM buffer layout for the T2R epilogue destination.
|
||||
|
||||
Computes the per-thread RMEM buffer layout produced by a TMEM->RMEM copy
|
||||
for a single epilogue iteration.
|
||||
|
||||
:param tiled_copy_t2r: The TMEM->RMEM tiled copy op (``cute.TiledCopy``).
|
||||
:type tiled_copy_t2r: cute.TiledCopy
|
||||
:param gC_mnl_epi: Global C tensor partitioned by epilogue tile.
|
||||
:type gC_mnl_epi: cute.Tensor
|
||||
:param tidx: Thread index for the copy slice.
|
||||
:param loc: Optional location for DSL ops.
|
||||
:param ip: Optional insertion point for DSL ops.
|
||||
:return: Layout for the RMEM buffer.
|
||||
:rtype: cute.Layout
|
||||
"""
|
||||
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
|
||||
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi, loc=loc, ip=ip)
|
||||
return cute.make_fragment_like(
|
||||
tTR_gC[(None, None, None, 0, 0)].layout, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ __all__ = [
|
||||
# copy.py
|
||||
#
|
||||
"Repetition",
|
||||
"TmemLoadRedOp",
|
||||
"Pack",
|
||||
"Unpack",
|
||||
"Ld16x64bOp",
|
||||
|
||||
@@ -26,6 +26,22 @@ from ...typing import Numeric
|
||||
from .mma import CtaGroup
|
||||
|
||||
|
||||
class TmemLoadRedOp(enum.Enum):
|
||||
"""
|
||||
An enumeration for the possible reduce operations for TMEM load operations.
|
||||
"""
|
||||
|
||||
MAX = _cute_nvgpu_ir.TmemLoadRedOp.max
|
||||
MAXABS = _cute_nvgpu_ir.TmemLoadRedOp.maxabs
|
||||
MIN = _cute_nvgpu_ir.TmemLoadRedOp.min
|
||||
MINABS = _cute_nvgpu_ir.TmemLoadRedOp.minabs
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}.{self.name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}.{self.name}>"
|
||||
|
||||
class Repetition(enum.Enum):
|
||||
"""
|
||||
An enumeration for the number of repetitions of a given TMEM copy within the instruction.
|
||||
@@ -390,6 +406,97 @@ class Ld32x32bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LdRed16x32bx2Op(_LdBase):
|
||||
"""
|
||||
16x32bx2 TMEM load Reduce Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
|
||||
This Operation corresponds to the ``.red`` and ``.16x32bx2`` qualifiers.
|
||||
"""
|
||||
|
||||
redOp: TmemLoadRedOp = TmemLoadRedOp.MAX
|
||||
nan: bool = False
|
||||
half_split_off: int = 0
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "LdRed16x32bx2Trait":
|
||||
"""
|
||||
Create a trait object for the 16x32bx2 TMEM load Reduce operation.
|
||||
|
||||
:param copy_internal_type: The data type for the copy operation
|
||||
:type copy_internal_type: Type[Numeric]
|
||||
:param loc: MLIR location information for debugging, defaults to None
|
||||
:type loc: optional
|
||||
:param ip: MLIR insertion point for code generation, defaults to None
|
||||
:type ip: optional
|
||||
:param kwargs: Additional keyword arguments
|
||||
:type kwargs: dict
|
||||
:return: A trait object for this load operation
|
||||
:rtype: LdRed16x32bx2Trait
|
||||
"""
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM10xTmemLoadRedType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
16,
|
||||
32,
|
||||
self.repeat.value,
|
||||
self.redOp.value,
|
||||
ir.UnitAttr.get() if self.nan else None,
|
||||
ir.IntegerAttr.get(ir.IntegerType.get_signless(32), self.half_split_off),
|
||||
)
|
||||
return LdRed16x32bx2Trait(make_atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class LdRed16x32bx2Trait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LdRed32x32bOp(_LdBase):
|
||||
"""
|
||||
32x32b TMEM load Reduce Operation.
|
||||
|
||||
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
|
||||
This Operation corresponds to the ``red`` and ``.32x32`` qualifiers.
|
||||
"""
|
||||
|
||||
redOp: TmemLoadRedOp = TmemLoadRedOp.MAX
|
||||
nan: bool = False
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "LdRed32x32bTrait":
|
||||
"""
|
||||
Create a trait object for the 32x32b TMEM load Reduce operation.
|
||||
|
||||
:param copy_internal_type: The data type for the copy operation
|
||||
:type copy_internal_type: Type[Numeric]
|
||||
:param loc: MLIR location information for debugging, defaults to None
|
||||
:type loc: optional
|
||||
:param ip: MLIR insertion point for code generation, defaults to None
|
||||
:type ip: optional
|
||||
:param kwargs: Additional keyword arguments
|
||||
:type kwargs: dict
|
||||
:return: A trait object for this load operation
|
||||
:rtype: LdRed32x32bTrait
|
||||
"""
|
||||
ty = _cute_nvgpu_ir.CopyAtomSM10xTmemLoadRedType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
32,
|
||||
32,
|
||||
self.repeat.value,
|
||||
self.redOp.value,
|
||||
ir.UnitAttr.get() if self.nan else None,
|
||||
None,
|
||||
)
|
||||
return LdRed32x32bTrait(make_atom(ty, loc=loc, ip=ip))
|
||||
|
||||
|
||||
class LdRed32x32bTrait(Trait):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _StBase(CopyOp):
|
||||
"""
|
||||
|
||||
@@ -103,15 +103,20 @@ class LdMatrix8x16x8bOp(BaseOp):
|
||||
self,
|
||||
"expects the 'num_matrices' Op parameter to be one of [1,2,4]",
|
||||
)
|
||||
if self.unpack_bits not in [4, 6]:
|
||||
raise OpError(self, "Op unpack bits must be 4 or 6")
|
||||
if self.unpack_bits not in [None, 4, 6]:
|
||||
raise OpError(self, "Op unpack bits must be 4 or 6 or None")
|
||||
|
||||
def _make_trait(
|
||||
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
|
||||
) -> "LdMatrix8x16x8bTrait":
|
||||
mode = _pack_shape((8, 16), loc=loc, ip=ip)
|
||||
sz_pattern = _cute_nvgpu_ir.LdsmSzPattern.u4x16p64to8
|
||||
if self.unpack_bits == 6:
|
||||
# LdMatrix8x16x8b without unpacking doesn't exist
|
||||
# but is equivalent to LdMatrix8x8x16b
|
||||
mode_n = 8 if self.unpack_bits is None else 16
|
||||
mode = _pack_shape((8, mode_n), loc=loc, ip=ip)
|
||||
sz_pattern = _cute_nvgpu_ir.LdsmSzPattern.u16
|
||||
if self.unpack_bits == 4:
|
||||
sz_pattern = _cute_nvgpu_ir.LdsmSzPattern.u4x16p64to8
|
||||
elif self.unpack_bits == 6:
|
||||
sz_pattern = _cute_nvgpu_ir.LdsmSzPattern.u6x16p32to8
|
||||
ty = _cute_nvgpu_ir.CopyAtomLdsmType.get(
|
||||
copy_internal_type.mlir_type,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
@@ -141,6 +141,11 @@ class _Tensor(Tensor):
|
||||
# If tensor is already a DLPack object, use it directly
|
||||
if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"):
|
||||
self._dlpack_data = tensor.__dlpack_device__()
|
||||
elif enable_tvm_ffi:
|
||||
import tvm_ffi
|
||||
|
||||
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor)
|
||||
self._dlpack_data = self._tvm_ffi_tensor.__dlpack__()
|
||||
else:
|
||||
try:
|
||||
# we expect no stream sync. Because torch has different default behavior
|
||||
@@ -149,11 +154,6 @@ class _Tensor(Tensor):
|
||||
self._dlpack_data = tensor.__dlpack__(stream=-1)
|
||||
except Exception:
|
||||
self._dlpack_data = tensor.__dlpack__()
|
||||
if enable_tvm_ffi:
|
||||
import tvm_ffi
|
||||
|
||||
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor)
|
||||
self._dlpack_data = self._tvm_ffi_tensor.__dlpack__()
|
||||
|
||||
self._dltensor_wrapper = None
|
||||
self._assumed_align = assumed_align
|
||||
|
||||
@@ -647,7 +647,7 @@ def _while_execute_dynamic(
|
||||
|
||||
def _ifexp_execute_dynamic(
|
||||
pred: "ir.Value",
|
||||
generator_targets: tuple,
|
||||
block_args: tuple,
|
||||
then_block: Callable,
|
||||
else_block: Callable,
|
||||
):
|
||||
@@ -663,8 +663,8 @@ def _ifexp_execute_dynamic(
|
||||
----------
|
||||
pred : ir.Value
|
||||
The predicate value (a boolean IR value) that determines which branch is executed.
|
||||
generator_targets : tuple
|
||||
The generator targets that are passed to the then and else blocks.
|
||||
block_args : tuple
|
||||
The block arguments that are passed to the then and else blocks.
|
||||
then_block : Callable
|
||||
A Python function that executes the 'then' branch and returns the result(s). This will be
|
||||
executed if `pred` evaluates to True.
|
||||
@@ -698,13 +698,13 @@ def _ifexp_execute_dynamic(
|
||||
with ir.InsertionPoint(execution_region.region.blocks[0]):
|
||||
# Call the then block and unpack its results to IR values and tree structure
|
||||
then_results = ScfGenerator._normalize_region_result_to_list(
|
||||
then_block(*generator_targets)
|
||||
then_block(*block_args)
|
||||
)
|
||||
ir_values, then_tree = cutlass_dsl.unpack_to_irvalue(then_results, "ifexp", 0)
|
||||
|
||||
# Call the else block and unpack its results to IR values and tree structure
|
||||
else_results = ScfGenerator._normalize_region_result_to_list(
|
||||
else_block(*generator_targets)
|
||||
else_block(*block_args)
|
||||
)
|
||||
_, else_tree = cutlass_dsl.unpack_to_irvalue(else_results, "ifexp", 0)
|
||||
|
||||
@@ -739,11 +739,11 @@ def _ifexp_execute_dynamic(
|
||||
# SCF region builder for then block
|
||||
def then_builder(*args):
|
||||
# Just call the then_block as no arguments are passed to it
|
||||
return then_block(*generator_targets)
|
||||
return then_block(*block_args)
|
||||
|
||||
# SCF region builder for else block
|
||||
def else_builder(*args):
|
||||
return else_block(*generator_targets)
|
||||
return else_block(*block_args)
|
||||
|
||||
# Prepare the list of region builders for the SCF IfOp: first for "then", then for "else"
|
||||
region_builders = [then_builder, else_builder]
|
||||
|
||||
@@ -129,16 +129,13 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
|
||||
cuda_global_state_ptr = self.address_of(
|
||||
self.cuda_global_state_symbol, self.ptr_type
|
||||
)
|
||||
|
||||
cuda_init_ptr = context.builder.get_or_load_global_func_ptr_from_text(
|
||||
current_block, "cuda_init"
|
||||
)
|
||||
cuda_load_to_device_ptr = context.builder.get_or_load_global_func_ptr_from_text(
|
||||
current_block, "cuda_load_to_device"
|
||||
)
|
||||
set_error_ptr = context.builder.get_or_load_global_func_ptr_from_text(
|
||||
current_block, "TVMFFIErrorSetRaisedFromCStr"
|
||||
)
|
||||
cuda_init_ptr = self.address_of("cuda_init", self.ptr_type)
|
||||
cuda_load_to_device_ptr = self.address_of(
|
||||
"cuda_load_to_device", self.ptr_type
|
||||
)
|
||||
set_error_ptr = self.address_of(
|
||||
"TVMFFIErrorSetRaisedFromCStr", self.ptr_type
|
||||
)
|
||||
|
||||
with ir.InsertionPoint(current_block):
|
||||
# Call the callback function with the loaded ptr value
|
||||
|
||||
@@ -129,7 +129,7 @@ class ClcDynamicPersistentTileScheduler:
|
||||
:param num_tiles_executed: Counter for executed tiles.
|
||||
:type num_tiles_executed: Int32
|
||||
:param clc_response_ptr: Pointer of the clc rsponse.
|
||||
:type clc_response_ptr: Tuple[Integer, Integer, Integer, Integer]
|
||||
:type clc_response_ptr: cute.Pointer
|
||||
:param block_idx: The block index.
|
||||
:type block_idx: Tuple[Integer, Integer, Integer]
|
||||
"""
|
||||
@@ -238,7 +238,7 @@ class ClcDynamicPersistentTileScheduler:
|
||||
|
||||
@dsl_user_op
|
||||
def work_tile_info_from_clc_response(
|
||||
self, result_addr: Int32, *, loc=None, ip=None
|
||||
self, result_addr: cute.Pointer, *, loc=None, ip=None
|
||||
) -> WorkTileInfo:
|
||||
"""
|
||||
Simulates parsing CLC response data in Python.
|
||||
|
||||
@@ -293,6 +293,7 @@ def epilogue(
|
||||
acc_pipeline: pipeline.PipelineAsync,
|
||||
tCcC_base: cute.Tensor = None,
|
||||
mC_mnl: cute.Tensor = None,
|
||||
overlapping_accum: Constexpr = False,
|
||||
) -> pipeline.PipelineState:
|
||||
"""
|
||||
Epilogue function that stores accumulator results directly to global memory.
|
||||
@@ -310,12 +311,18 @@ def epilogue(
|
||||
:type epi_tile: cute.Tile
|
||||
:param epilogue_op: Optional elementwise operation to apply
|
||||
:type epilogue_op: Constexpr
|
||||
:param alignment_bytes: Alignment bytes for global memory store
|
||||
:type alignment_bytes: int
|
||||
:param mma_tile_coord_mnl: MMA tile coordinates (M, N, L)
|
||||
:type mma_tile_coord_mnl: Tuple[Int32, Int32, Int32]
|
||||
:param acc_consumer_state: Accumulator consumer pipeline state
|
||||
:type acc_consumer_state: pipeline.PipelineState
|
||||
:param acc_pipeline: Accumulator pipeline for async operations
|
||||
:type acc_pipeline: pipeline.PipelineAsync
|
||||
:param tCcC_base: Identity/coordinate tensor C
|
||||
:type tCcC_base: cute.Tensor
|
||||
:param mC_mnl: Global memory tensor C (full tensor for predicate computation)
|
||||
:type mC_mnl: cute.Tensor
|
||||
:param overlapping_accum: Whether to use overlapping accumulator
|
||||
:type overlapping_accum: Constexpr
|
||||
"""
|
||||
|
||||
# Layout transformation for tCgC_base
|
||||
@@ -399,9 +406,16 @@ def epilogue(
|
||||
]
|
||||
tTR_cC = cute.group_modes(tTR_cC, 3, cute.rank(tTR_cC))
|
||||
|
||||
# Get accumulator stage index
|
||||
if const_expr(overlapping_accum):
|
||||
acc_stage_index = acc_consumer_state.phase
|
||||
reverse_subtile = acc_stage_index == 0
|
||||
else:
|
||||
acc_stage_index = acc_consumer_state.index
|
||||
|
||||
# Set tensor memory buffer for current tile
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
|
||||
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)]
|
||||
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
|
||||
|
||||
#
|
||||
# Wait for accumulator buffer full
|
||||
@@ -415,21 +429,38 @@ def epilogue(
|
||||
#
|
||||
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
|
||||
for subtile_idx in range(subtile_cnt):
|
||||
# Compute the actual subtile index
|
||||
real_subtile_idx = subtile_idx
|
||||
if const_expr(overlapping_accum):
|
||||
if reverse_subtile:
|
||||
real_subtile_idx = subtile_cnt - 1 - subtile_idx
|
||||
#
|
||||
# Get the destination and coordinate slices for this subtile
|
||||
#
|
||||
tTR_gC_subtile = tTR_gC[(None, None, None, subtile_idx)]
|
||||
tTR_gC_subtile = tTR_gC[(None, None, None, real_subtile_idx)]
|
||||
#
|
||||
# Load accumulator from tensor memory buffer to register
|
||||
#
|
||||
tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
|
||||
tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)]
|
||||
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
|
||||
|
||||
#
|
||||
# Async arrive accumulator buffer empty
|
||||
# Release early for perf
|
||||
if subtile_idx == subtile_cnt - 1:
|
||||
with cute.arch.elect_one():
|
||||
acc_pipeline.consumer_release(acc_consumer_state)
|
||||
acc_consumer_state.advance()
|
||||
#
|
||||
if const_expr(overlapping_accum):
|
||||
# Early release when overlapping: release after processing the
|
||||
# overlapping region (SF columns) so they can be reused
|
||||
if subtile_idx == gemm_kernel.iter_acc_early_release_in_epilogue:
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
with cute.arch.elect_one():
|
||||
acc_pipeline.consumer_release(acc_consumer_state)
|
||||
acc_consumer_state.advance()
|
||||
else:
|
||||
# Release early for perf at the last subtile
|
||||
if subtile_idx == subtile_cnt - 1:
|
||||
with cute.arch.elect_one():
|
||||
acc_pipeline.consumer_release(acc_consumer_state)
|
||||
acc_consumer_state.advance()
|
||||
|
||||
#
|
||||
# Convert to C type
|
||||
@@ -440,7 +471,7 @@ def epilogue(
|
||||
|
||||
if const_expr(use_predication):
|
||||
# compute predicate
|
||||
tTR_cC_subtile = tTR_cC[(None, None, None, subtile_idx)]
|
||||
tTR_cC_subtile = tTR_cC[(None, None, None, real_subtile_idx)]
|
||||
pred_C_shape = (1, *tTR_cC_subtile.shape[1:])
|
||||
pred_C = cute.make_rmem_tensor(pred_C_shape, Boolean)
|
||||
for m_idx in range(tTR_cC_subtile.shape[1]):
|
||||
|
||||
@@ -41,7 +41,23 @@ class HardwareInfo:
|
||||
self.driver_version = self._checkCudaErrors(driver.cuDriverGetVersion())
|
||||
|
||||
# Getting the max active clusters for a given cluster size
|
||||
def get_max_active_clusters(self, cluster_size: int) -> int:
|
||||
def get_max_active_clusters(
|
||||
self, cluster_size: int, stream: driver.CUstream = None
|
||||
) -> int:
|
||||
"""
|
||||
Get the maximum number of active clusters for a given cluster size.
|
||||
|
||||
When a stream from a green context is provided, the occupancy calculation
|
||||
will reflect the reduced SM partition of the green context.
|
||||
|
||||
:param cluster_size: Number of blocks per cluster (must be between 1 and 32)
|
||||
:type cluster_size: int
|
||||
:param stream: Optional CUDA stream handle. If provided (especially from a green context),
|
||||
the occupancy calculation reflects the stream's SM partition.
|
||||
:type stream: driver.CUstream, optional
|
||||
:return: Maximum number of active clusters
|
||||
:rtype: int
|
||||
"""
|
||||
if self._cuda_driver_version_lt(11, 8):
|
||||
raise RuntimeError(
|
||||
"CUDA Driver version < 11.8, cannot get _max_active_clusters"
|
||||
@@ -94,6 +110,13 @@ class HardwareInfo:
|
||||
launch_config.blockDimY = 1
|
||||
launch_config.blockDimZ = 1
|
||||
launch_config.sharedMemBytes = max_dynamic_shared_memory
|
||||
|
||||
# IMPORTANT: Set the stream for green context support
|
||||
# When hStream is set, cuOccupancyMaxActiveClusters will use the context
|
||||
# associated with that stream, which includes the green context's SM partition
|
||||
if stream is not None:
|
||||
launch_config.hStream = stream
|
||||
|
||||
launch_config.numAttrs = 1
|
||||
# max possible cluster size is 32
|
||||
cluster_dims_attr = driver.CUlaunchAttribute()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
|
||||
3
python/CuTeDSL/requirements-cu13.txt
Normal file
3
python/CuTeDSL/requirements-cu13.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
# Use `pip install -r requirements-cu13.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl[cu13]==4.4.0
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl==4.4.0.dev1
|
||||
nvidia-cutlass-dsl==4.4.0
|
||||
|
||||
85
python/CuTeDSL/setup.sh
Executable file
85
python/CuTeDSL/setup.sh
Executable file
@@ -0,0 +1,85 @@
|
||||
#!/bin/bash
|
||||
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
set -e
|
||||
|
||||
# Get the directory where this script is located
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
|
||||
# Default to requirements.txt
|
||||
REQUIREMENTS_FILE="requirements.txt"
|
||||
|
||||
# Parse command line arguments
|
||||
if [ $# -gt 0 ]; then
|
||||
case "$1" in
|
||||
--cu12)
|
||||
REQUIREMENTS_FILE="requirements.txt"
|
||||
echo "Installing CUDA 12 requirements..."
|
||||
;;
|
||||
--cu13)
|
||||
REQUIREMENTS_FILE="requirements-cu13.txt"
|
||||
echo "Installing CUDA 13 requirements..."
|
||||
;;
|
||||
--help|-h)
|
||||
echo "Usage: $0 [--cu12|--cu13]"
|
||||
echo " --cu12 Install requirements for CUDA 12 (default)"
|
||||
echo " --cu13 Install requirements for CUDA 13"
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Error: Unknown argument '$1'"
|
||||
echo "Usage: $0 [--cu12|--cu13]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
else
|
||||
echo "Installing default requirements (CUDA 12)..."
|
||||
fi
|
||||
|
||||
# Check if requirements file exists
|
||||
REQUIREMENTS_PATH="${SCRIPT_DIR}/${REQUIREMENTS_FILE}"
|
||||
if [ ! -f "$REQUIREMENTS_PATH" ]; then
|
||||
echo "Error: Requirements file not found: $REQUIREMENTS_PATH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Uninstall previous version of the CUTLASS DSL
|
||||
echo "Trying to uninstall previous version of the CUTLASS DSL..."
|
||||
pip uninstall nvidia-cutlass-dsl nvidia-cutlass-dsl-libs-base nvidia-cutlass-dsl-libs-cu13 -y
|
||||
|
||||
# Install requirements
|
||||
echo "Installing from: $REQUIREMENTS_FILE"
|
||||
pip install -r "$REQUIREMENTS_PATH"
|
||||
|
||||
echo "Installation complete!"
|
||||
@@ -54,6 +54,7 @@ class MatmulHeuristics:
|
||||
|
||||
def __init__(self, gpu = None):
|
||||
import nvMatmulHeuristics
|
||||
import inspect
|
||||
self.mmh_lib = nvMatmulHeuristics
|
||||
self.gpu = gpu
|
||||
|
||||
@@ -62,13 +63,63 @@ class MatmulHeuristics:
|
||||
else:
|
||||
nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx
|
||||
|
||||
self.lh = nvmmhInterfaceEx(
|
||||
# nvidia-matmul-heuristics 0.1.0.28 changed the API:
|
||||
# - Constructor: removed 'load_discovery_implicitly' and 'gpu' params
|
||||
# - GPU: now set via createHardwareDescriptor() + setHardwarePredefinedGpu()
|
||||
# - setBackendValueProperty renamed to setBackendPropertyValue (simpler signature)
|
||||
# - getEx: added hardware_descriptor parameter
|
||||
init_params = set(inspect.signature(self.mmh_lib.NvMatmulHeuristicsInterfaceEx.__init__).parameters.keys())
|
||||
self._legacy_api = 'load_discovery_implicitly' in init_params
|
||||
|
||||
init_kwargs = dict(
|
||||
backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"],
|
||||
flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING,
|
||||
load_discovery_implicitly=True,
|
||||
gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
|
||||
)
|
||||
|
||||
if self._legacy_api:
|
||||
# <= 0.1.0.27
|
||||
init_kwargs['gpu'] = self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
|
||||
init_kwargs['load_discovery_implicitly'] = True
|
||||
|
||||
self.lh = nvmmhInterfaceEx(**init_kwargs)
|
||||
|
||||
# >= 0.1.0.28: gpu is set via hardware descriptor after construction,
|
||||
# and passed to getEx() calls
|
||||
self.hw_desc = None
|
||||
if not self._legacy_api and self.gpu:
|
||||
self.hw_desc = self.lh.createHardwareDescriptor()
|
||||
if self.hw_desc is None:
|
||||
raise RuntimeError("Failed to create hardware descriptor for GPU: " + self.gpu)
|
||||
self.lh.setHardwarePredefinedGpu(self.hw_desc, self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu])
|
||||
|
||||
self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"])
|
||||
|
||||
if not self._legacy_api:
|
||||
lh = self.lh
|
||||
original_del = type(lh).__del__
|
||||
|
||||
def _safe_del(self_lh):
|
||||
try:
|
||||
original_del(self_lh)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
type(lh).__del__ = _safe_del
|
||||
|
||||
def __del__(self):
|
||||
"""Clean up resources in correct order before the library's __del__ runs."""
|
||||
try:
|
||||
if hasattr(self, 'backend') and self.backend:
|
||||
self.lh.destroyBackend(self.backend)
|
||||
self.backend = None
|
||||
if hasattr(self, 'hw_desc') and self.hw_desc:
|
||||
self.lh.destroyHardwareDescriptor(self.hw_desc)
|
||||
self.hw_desc = None
|
||||
# Null out the handle so the library's __del__ skips nvMatmulHeuristicsDestroy
|
||||
if hasattr(self, 'lh') and self.lh and hasattr(self.lh, 'handle'):
|
||||
self.lh.handle = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _layout_from_cutlass(self, layouts):
|
||||
assert(len(layouts)==3)
|
||||
@@ -98,41 +149,45 @@ class MatmulHeuristics:
|
||||
else:
|
||||
return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
|
||||
|
||||
def _set_backend_property(self, property, value):
|
||||
"""Compat wrapper: setBackendValueProperty (<=0.1.0.27) vs setBackendPropertyValue (>=0.1.0.28)"""
|
||||
if self._legacy_api:
|
||||
c_val = ctypes.c_int(value)
|
||||
self.lh.setBackendValueProperty(
|
||||
self.backend, property,
|
||||
ctypes.byref(c_val), ctypes.sizeof(c_val)
|
||||
)
|
||||
else:
|
||||
self.lh.setBackendPropertyValue(self.backend, property, value)
|
||||
|
||||
def set_cta_div_n(self, div_n):
|
||||
cta_n_div_requirement = ctypes.c_int(div_n)
|
||||
self.lh.setBackendValueProperty(
|
||||
self.backend,
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT,
|
||||
ctypes.byref(cta_n_div_requirement),
|
||||
ctypes.sizeof(cta_n_div_requirement)
|
||||
)
|
||||
self._set_backend_property(
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT, div_n)
|
||||
|
||||
def set_cta_div_m(self, div_m):
|
||||
cta_m_div_requirement = ctypes.c_int(div_m)
|
||||
self.lh.setBackendValueProperty(
|
||||
self.backend,
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT,
|
||||
ctypes.byref(cta_m_div_requirement),
|
||||
ctypes.sizeof(cta_m_div_requirement)
|
||||
)
|
||||
self._set_backend_property(
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT, div_m)
|
||||
|
||||
def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1):
|
||||
if use_fast_acc:
|
||||
disable_fast_acc_for_fp8 = ctypes.c_int(0)
|
||||
else:
|
||||
disable_fast_acc_for_fp8 = ctypes.c_int(1)
|
||||
self.lh.setBackendValueProperty(
|
||||
self.backend,
|
||||
self._set_backend_property(
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8,
|
||||
ctypes.byref(disable_fast_acc_for_fp8),
|
||||
ctypes.sizeof(disable_fast_acc_for_fp8)
|
||||
0 if use_fast_acc else 1
|
||||
)
|
||||
|
||||
precision = self._precision_from_cutlass_dtypes(dtypes)
|
||||
layout = self._layout_from_cutlass(layouts)
|
||||
|
||||
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
|
||||
configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision)
|
||||
if self._legacy_api:
|
||||
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
|
||||
else:
|
||||
# >= 0.1.0.28: takes (m,n,k) as a tuple
|
||||
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem((m, n, k), layout, batch_count)
|
||||
|
||||
getEx_kwargs = dict(precision=precision)
|
||||
if not self._legacy_api:
|
||||
# >= 0.1.0.28: pass hardware descriptor to getEx
|
||||
getEx_kwargs['hardware_descriptor'] = self.hw_desc
|
||||
configs = self.lh.getEx(matmul_problem, count, self.backend, **getEx_kwargs)
|
||||
|
||||
ret = []
|
||||
for c in configs:
|
||||
|
||||
Reference in New Issue
Block a user