v4.4 tag release update. (#3032)

This commit is contained in:
Junkai-Wu
2026-02-14 12:27:58 +08:00
committed by GitHub
parent 01687cfba1
commit d4bbf728ca
140 changed files with 41624 additions and 3691 deletions

View File

@@ -179,12 +179,12 @@ class Executor:
def ifexp_execute(
self,
pred,
generator_targets: tuple,
block_args: tuple,
then_block: Callable,
else_block: Callable,
):
assert self._ifexp_dynamic, "Functions must be set before execution."
return self._ifexp_dynamic(pred, generator_targets, then_block, else_block)
return self._ifexp_dynamic(pred, block_args, then_block, else_block)
# =============================================================================
@@ -309,16 +309,14 @@ def if_executor(
def ifExp_executor(
*,
pred,
generator_targets: tuple,
block_args: tuple,
then_block: Callable,
else_block: Callable,
):
if not executor._is_dynamic_expression(pred):
return (
then_block(*generator_targets) if pred else else_block(*generator_targets)
)
return then_block(*block_args) if pred else else_block(*block_args)
else:
return executor.ifexp_execute(pred, generator_targets, then_block, else_block)
return executor.ifexp_execute(pred, block_args, then_block, else_block)
# =============================================================================

View File

@@ -45,6 +45,7 @@ from typing import List, Set, Dict, Any, Callable, Optional
from types import ModuleType
from collections import OrderedDict
from copy import deepcopy
from itertools import chain
from .common import *
from .utils.logger import log
@@ -302,6 +303,7 @@ class SessionData:
import_top_module: bool = False
region_stack: list[Region] = field(default_factory=list)
generator_targets: list[str] = field(default_factory=list)
lambda_args: list[str] = field(default_factory=list)
@contextlib.contextmanager
def set_current_class_name(self, class_name: str):
@@ -2063,6 +2065,19 @@ class DSLPreprocessor(ast.NodeTransformer):
return self._visit_Comprehension(node, key_value_visitor)
def visit_Lambda(self, node):
current_lambda_args = len(self.session_data.lambda_args)
for arg in node.args.args:
self.session_data.lambda_args.append(arg.arg)
node.body = self.visit(node.body)
self.session_data.lambda_args = self.session_data.lambda_args[
:current_lambda_args
]
return node
def visit_ListComp(self, node):
return self._visit_Comprehension(
node, lambda n: setattr(n, "elt", self.visit(n.elt))
@@ -2113,7 +2128,10 @@ class DSLPreprocessor(ast.NodeTransformer):
posonlyargs=[],
args=[
ast.arg(arg=target, annotation=None)
for target in self.session_data.generator_targets
for target in chain(
self.session_data.generator_targets,
self.session_data.lambda_args,
)
],
kwonlyargs=[],
kw_defaults=[],
@@ -2129,7 +2147,10 @@ class DSLPreprocessor(ast.NodeTransformer):
posonlyargs=[],
args=[
ast.arg(arg=target, annotation=None)
for target in self.session_data.generator_targets
for target in chain(
self.session_data.generator_targets,
self.session_data.lambda_args,
)
],
kwonlyargs=[],
kw_defaults=[],
@@ -2151,11 +2172,14 @@ class DSLPreprocessor(ast.NodeTransformer):
keywords=[
ast.keyword(arg="pred", value=self.visit(node.test)),
ast.keyword(
arg="generator_targets",
arg="block_args",
value=ast.Tuple(
elts=[
ast.Name(id=name, ctx=ast.Load())
for name in self.session_data.generator_targets
for name in chain(
self.session_data.generator_targets,
self.session_data.lambda_args,
)
],
ctx=ast.Load(),
),

View File

@@ -50,7 +50,11 @@ from .jit_executor import JitCompiledFunction, JitFunctionArtifacts
from .utils.timer import timer
from .utils.logger import log
from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe
from .runtime.jit_arg_adapters import is_argument_constexpr, JitArgAdapterRegistry
from .runtime.jit_arg_adapters import (
is_argument_constexpr,
is_arg_spec_constexpr,
JitArgAdapterRegistry,
)
from .ast_preprocessor import DSLPreprocessor
from .common import *
@@ -1310,8 +1314,7 @@ class BaseDSL(metaclass=DSLSingletonMeta):
dynamic_args = []
dynamic_kwargs = OrderedDict()
for i, arg in enumerate(args):
if not is_argument_constexpr(
arg,
if not is_arg_spec_constexpr(
args_spec.annotations.get(args_spec.args[i], None),
args_spec.args[i],
i,
@@ -1319,7 +1322,7 @@ class BaseDSL(metaclass=DSLSingletonMeta):
):
dynamic_args.append(arg)
for i, (k, v) in enumerate(kwargs.items()):
if not is_argument_constexpr(v, args_spec.kwonlyargs[i], k, i, funcBody):
if not is_arg_spec_constexpr(args_spec.kwonlyargs[i], k, i, funcBody):
dynamic_kwargs[k] = v
return dynamic_args, dynamic_kwargs

View File

@@ -371,63 +371,6 @@ class MLIRBuilder(MLIRTypeBuilder):
self.const_str_table[content] = symbol
return symbol
def get_or_load_global_func_ptr_from_text(
self,
current_block: ir.Block,
function_name: str,
) -> ir.Value:
"""Get or create a function pointer global in .text section and load it.
This creates a constant global function pointer in the .text section
(for AArch64 ADRP range compatibility) and performs a volatile load
to prevent optimization.
This forces the function pointer to be local to the code, bypassing GOT entry
ADRP lookup issues on AArch64 when GOT and .text section are more than 4GB
apart which can happen when ASLR is applied.
"""
# Check if we've already created this global
if function_name not in self.const_func_ptr_table:
symbol = f"__func_ptr_{function_name}"
module_body = self.module.body
with ir.InsertionPoint(module_body):
# 1. Create the global constant
# We use 'private' linkage so it doesn't conflict across modules
global_ptr = llvm.GlobalOp(
self.ptr_type,
symbol,
ir.Attribute.parse("#llvm.linkage<private>"),
# Initialization via block below
)
# 2. Set the necessary attributes for JIT safety and AArch64 range
# We use 'constant' to mark it as immutable
# We use 'section = ".text"' to force it into the code block
global_ptr.attributes["constant"] = ir.UnitAttr.get()
global_ptr.attributes["section"] = ir.StringAttr.get(".text")
# 3. Add a constructor block to the GlobalOp to initialize it
# with the address of the target function
initializer_block = global_ptr.initializer.blocks.append()
with ir.InsertionPoint(initializer_block):
# Get the address of the external function
func_addr = llvm.AddressOfOp(self.ptr_type, function_name).res
# Return the address as the initial value of the global
llvm.return_(arg=func_addr)
self.const_func_ptr_table[function_name] = symbol
else:
symbol = self.const_func_ptr_table[function_name]
# Load it with volatile semantics in the current block
with ir.InsertionPoint(current_block):
symbol_addr = self.address_of(symbol, self.ptr_type)
# Perform a volatile load to prevent optimization
load_op = llvm.load(self.ptr_type, symbol_addr)
# Set volatile attribute to prevent optimization
load_op.owner.attributes["volatile_"] = ir.UnitAttr.get()
return load_op
# function
def function(

View File

@@ -29,8 +29,14 @@ from .core import (
append_ones,
group_modes,
)
from .atom import MmaAtom, CopyAtom, make_atom
from .atom import (
MmaAtom,
CopyAtom,
make_atom,
_normalize_variadic_tensor_operand,
copy_atom_call,
)
from .nvgpu.common import CacheEvictionPriority
def _normalize_gemm_operand_list(
x: Union["Tensor", List["Tensor"], Tuple["Tensor", ...]], name: str
@@ -156,7 +162,7 @@ def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None:
src.element_type.mlir_type, src.element_type.width
)
simt_copy = make_atom(simt_copy_ty, loc=loc, ip=ip)
return _cute_ir.copy(simt_copy, src.value, dst.value, loc=loc, ip=ip)
return _cute_ir.copy(simt_copy, [src.value], [dst.value], loc=loc, ip=ip)
s = size(dst, loc=loc, ip=ip)
# Always generate an scf.for Op when one of the tensors is dynamic
@@ -210,7 +216,14 @@ def _basic_copy_if_static(
@dsl_user_op
def autovec_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None:
def autovec_copy(
src: Tensor,
dst: Tensor,
*,
l1c_evict_priority: CacheEvictionPriority = CacheEvictionPriority.EVICT_NORMAL,
loc=None,
ip=None,
) -> None:
"""
Auto-vectorization SIMT copy policy.
@@ -263,11 +276,15 @@ def autovec_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None:
# Dispatch to copy with atom
simt_type = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get(
src.element_type.mlir_type, num_bits_per_copy
src.element_type.mlir_type,
num_bits_per_copy,
0,
0,
l1c_evict_priority._to_ir(),
)
simt_copy = make_atom(simt_type, loc=loc, ip=ip)
return _cute_ir.copy(
simt_copy, tiled_src.value, tiled_dst.value, loc=loc, ip=ip
simt_copy, [tiled_src.value], [tiled_dst.value], loc=loc, ip=ip
)
# Failed to vectorize, use a basic copy
@@ -331,8 +348,8 @@ def _parse_auto_multicast_args(
@dsl_user_op
def copy(
atom: CopyAtom,
src: Tensor,
dst: Tensor,
src: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
dst: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
*,
pred: Optional[Tensor] = None,
loc=None,
@@ -343,10 +360,10 @@ def copy(
:param atom: Copy atom specifying the transfer operation
:type atom: CopyAtom
:param src: Source tensor with layout profile ``(V, Rest...)``
:type src: Tensor
:param dst: Destination tensor with layout profile ``(V, Rest...)``
:type dst: Tensor
:param src: Source tensor or list of tensors with layout profile ``(V, Rest...)``
:type src: Union[Tensor, List[Tensor], Tuple[Tensor, ...]]
:param dst: Destination tensor or list of tensors with layout profile ``(V, Rest...)``
:type dst: Union[Tensor, List[Tensor], Tuple[Tensor, ...]]
:param pred: Optional predication tensor for conditional transfers, defaults to None
:type pred: Optional[Tensor], optional
:param loc: Source location information, defaults to None
@@ -374,6 +391,12 @@ def copy(
Source and destination tensors must be partitioned in accordance with the Copy Atom specifications.
Post-partitioning, both tensors will exhibit a ``(V, Rest...)`` layout profile.
The operands `src` and `dst` are variadic, each containing a variable number of tensors:
- For regular copy, `src` and `dst` contain single source and destination tensors respectively.
- For copy with auxiliary operands, `src` and `dst` contain the primary tensors followed by
their respective auxiliary tensors.
**Precondition:** The size of mode 1 must be equal for both source and destination tensors:
``size(src, mode=[1]) == size(dst, mode=[1])``
@@ -399,41 +422,54 @@ def copy(
for future releases.
"""
if isinstance(src.type, _cute_ir.MemRefType) and isinstance(
dst.type, _cute_ir.MemRefType
# Normalize src/dst to lists for variadic IR operands
src_list = _normalize_variadic_tensor_operand(src, "src")
dst_list = _normalize_variadic_tensor_operand(dst, "dst")
# Validate primary tensors (first element)
src_primary = src_list[0]
dst_primary = dst_list[0]
if isinstance(src_primary.type, _cute_ir.MemRefType) and isinstance(
dst_primary.type, _cute_ir.MemRefType
):
if src.element_type.width != dst.element_type.width:
if src_primary.element_type.width != dst_primary.element_type.width:
raise TypeError(
"`copy` currently only supports equal source and destination "
"element type bit width"
)
if rank(src) != rank(dst):
if rank(src_primary) != rank(dst_primary):
raise ValueError(
"Expected source and destination tensors to have the same rank, "
f"but got {rank(src)} and {rank(dst)}"
f"but got {rank(src_primary)} and {rank(dst_primary)}"
)
# Canonicalize to at least rank-2 tensors
src = group_modes(append_ones(src, up_to_rank=2), 1)
dst = group_modes(append_ones(dst, up_to_rank=2), 1)
# Canonicalize all tensors to at least rank-2
src_list = [group_modes(append_ones(t, up_to_rank=2), 1) for t in src_list]
dst_list = [group_modes(append_ones(t, up_to_rank=2), 1) for t in dst_list]
if pred is not None:
pred = group_modes(append_ones(pred, up_to_rank=2), 1)
if is_static(src.shape[1]) and is_static(dst.shape[1]):
if size(src, mode=[1]) != size(dst, mode=[1]):
# Recompute primary references after canonicalization
src_primary = src_list[0]
dst_primary = dst_list[0]
if is_static(src_primary.shape[1]) and is_static(dst_primary.shape[1]):
if size(src_primary, mode=[1]) != size(dst_primary, mode=[1]):
raise ValueError(
"Expected source and destination tensors to have the same size in mode-1, "
f"but got {size(src, mode=[1])} and {size(dst, mode=[1])}"
f"but got {size(src_primary, mode=[1])} and {size(dst_primary, mode=[1])}"
)
multicast_attr_pairs = _parse_auto_multicast_args(kwargs)
value = atom._unpack(loc=loc, ip=ip, **kwargs)
if isinstance(pred, Tensor):
pred = pred.value
pred_value = pred.value if isinstance(pred, Tensor) else pred
op = _cute_ir.copy(value, src.value, dst.value, pred=pred, loc=loc, ip=ip)
src_vals = [t.value for t in src_list]
dst_vals = [t.value for t in dst_list]
op = _cute_ir.copy(value, src_vals, dst_vals, pred=pred_value, loc=loc, ip=ip)
for name, attr in multicast_attr_pairs:
op.attributes[name] = attr

View File

@@ -13,7 +13,7 @@ from functools import partial
from typing import Any, Optional, Tuple, Union, Callable, Literal
from typing_extensions import deprecated
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass.cutlass_dsl import T, dsl_user_op, target_version
import cutlass.cutlass_dsl as cutlass_dsl
@@ -2557,3 +2557,42 @@ def cvt_f4e2m1x8_to_f16x8(src_vec8, *, loc=None, ip=None):
vec_f16x8_type = ir.VectorType.get([8], Float16.mlir_type, loc=loc)
vec_f16x8 = llvm.bitcast(vec_f16x8_type, vec_f32x4, loc=loc, ip=ip)
return vec_f16x8
@dsl_user_op
def mapa(ptr, cta_rank_in_cluster=0, *, loc=None, ip=None):
"""
Map a pointer to distributed shared memory across cluster.
Portable wrapper that uses the appropriate NVVM API based on CUDA version:
- CUDA 13.1+: Uses nvvm.mapa with dsmem address space
- CUDA 12.9: Uses nvvm.mapa_shared_cluster
Args:
ptr: Pointer to shared memory (llvm_ptr attribute will be used)
cta_rank_in_cluster: CTA rank within the cluster (default 0)
Returns:
Mapped LLVM pointer to shared memory
"""
if target_version(min_version="13.1"):
dsmem_ptr_ty = llvm.PointerType.get(7) # dsmem
smem_ptr_ty = llvm.PointerType.get(3) # smem
llvm_ptr = nvvm.mapa(
dsmem_ptr_ty,
ptr.llvm_ptr,
Int32(cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
return llvm.addrspacecast(smem_ptr_ty, llvm_ptr, loc=loc, ip=ip)
else:
llvm_ptr = ptr.llvm_ptr
return nvvm.mapa_shared_cluster(
llvm_ptr.type,
llvm_ptr,
Int32(cta_rank_in_cluster).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)

View File

@@ -10,7 +10,7 @@
# is strictly prohibited.
from abc import ABC, ABCMeta, abstractmethod
from typing import Type, Union, Optional, Any, overload
from typing import Type, Union, Optional, Any, List, Tuple, overload
from .typing import Shape, Layout, Tile, Tensor, Numeric, Int32
from .core import (
@@ -1113,25 +1113,66 @@ def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None):
return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip)
def _normalize_variadic_tensor_operand(
x: Union["Tensor", List["Tensor"], Tuple["Tensor", ...]], name: str
) -> List["Tensor"]:
"""Normalize a Tensor or sequence of Tensors to a list of Tensors.
Helper function for operations with variadic operands.
"""
if isinstance(x, Tensor):
return [x]
if isinstance(x, (list, tuple)):
if len(x) == 0:
raise ValueError(f"`{name}` must contain at least one Tensor")
if not all(isinstance(t, Tensor) for t in x):
raise TypeError(f"All elements of `{name}` must be Tensor")
return list(x) # type: ignore
raise TypeError(f"`{name}` must be a Tensor or a sequence of Tensors")
@dsl_user_op
def copy_atom_call(
atom: CopyAtom,
src: Tensor,
dst: Tensor,
src: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
dst: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
*,
pred: Optional[Tensor] = None,
loc=None,
ip=None,
**kwargs,
) -> None:
"""Executes a single copy atom operation between two tensors.
"""
Execute a single copy atom operation.
The copy_atom_call operation executes a copy atom with the given operands.
Source and destination tensors have layout profile ``(V)``.
The ``V-mode`` represents either:
- A singular mode directly consumable by the provided Copy Atom
- A composite mode requiring recursive decomposition, structured as ``(V, Rest...)``,
For src/dst layout like ``(V, Rest...)``, the layout profile of ``pred`` must match ``(Rest...)``.
- Certain Atoms may require additional operation-specific keyword arguments.
- Current implementation limits ``V-mode`` rank to 2 or less. Support for higher ranks is planned
for future releases.
Both ``src`` and ``dst`` operands are variadic, containing a variable number of tensors:
- For regular copy, ``src`` and ``dst`` each contain a single tensor.
- For copy with auxiliary operands, they contain the main tensor followed by
auxiliary tensors. For example:
:param atom: Copy atom specifying the transfer operation
:type atom: CopyAtom
:param src: Source tensor with layout profile ``(V)``
:type src: Tensor
:param dst: Destination tensor with layout profile ``(V)``
:type dst: Tensor
:param src: Source tensor(s) with layout profile ``(V)``. Can be a single Tensor
or a list/tuple of Tensors for operations with auxiliary source operands.
:type src: Union[Tensor, List[Tensor], Tuple[Tensor, ...]]
:param dst: Destination tensor(s) with layout profile ``(V)``. Can be a single Tensor
or a list/tuple of Tensors for operations with auxiliary destination operands.
:type dst: Union[Tensor, List[Tensor], Tuple[Tensor, ...]]
:param pred: Optional predication tensor for conditional transfers, defaults to None
:type pred: Optional[Tensor], optional
:param loc: Source location information, defaults to None
@@ -1144,54 +1185,43 @@ def copy_atom_call(
:return: None
:rtype: None
The copy_atom_call operation executes a single copy atom with the given operands.
Source and destination tensors with layout profile like ``(V)``.
The ``V-mode`` represents either:
- A singular mode directly consumable by the provided Copy Atom
- A composite mode requiring recursive decomposition, structured as ``(V, Rest...)``,
For src/dst layout like ``(V, Rest...)``, the layout profile of ``pred`` must match ``(Rest...)``.
**Examples**:
.. code-block:: python
# Basic copy atom operation
# Regular copy atom operation
cute.copy_atom_call(copy_atom, src, dst)
# Predicated copy atom operation
cute.copy_atom_call(copy_atom, src, dst, pred=pred)
.. note::
- Certain Atoms may require additional operation-specific keyword arguments.
- Current implementation limits ``V-mode`` rank to 2 or less. Support for higher ranks is planned
for future releases.
"""
if isinstance(src.type, _cute_ir.MemRefType) and isinstance(
dst.type, _cute_ir.MemRefType
# Normalize src/dst to lists for variadic IR operands, while keeping old API working.
src_list = _normalize_variadic_tensor_operand(src, "src")
dst_list = _normalize_variadic_tensor_operand(dst, "dst")
# Validate first src/dst for element type width check
if isinstance(src_list[0].type, _cute_ir.MemRefType) and isinstance(
dst_list[0].type, _cute_ir.MemRefType
):
if src.element_type.width != dst.element_type.width:
if src_list[0].element_type.width != dst_list[0].element_type.width:
raise TypeError(
"`copy_atom_call` currently only supports equal source and destination "
"element type bit width"
)
if rank(src, mode=[0]) > 2 or rank(dst, mode=[0]) > 2:
if rank(src_list[0], mode=[0]) > 2 or rank(dst_list[0], mode=[0]) > 2:
raise NotImplementedError(
"V-mode (mode-0) with rank > 2 is not supported yet, "
f"but got rank(src, mode=[0]) = {rank(src, mode=[0])} and rank(dst, mode=[0]) = {rank(dst, mode=[0])}"
f"but got rank(src, mode=[0]) = {rank(src_list[0], mode=[0])} and rank(dst, mode=[0]) = {rank(dst_list[0], mode=[0])}"
)
value = atom._unpack(loc=loc, ip=ip, **kwargs)
if isinstance(pred, Tensor):
pred = pred.value
return _cute_ir.copy_atom_call(
value, src.value, dst.value, pred=pred, loc=loc, ip=ip
)
src_vals = [t.value for t in src_list]
dst_vals = [t.value for t in dst_list]
return _cute_ir.copy_atom_call(value, src_vals, dst_vals, pred=pred, loc=loc, ip=ip)
@dsl_user_op

View File

@@ -51,4 +51,7 @@
- `get_cta_v_map_ab` — Compute CTA-V map for A/B operands
- `get_cta_v_map_c` — Compute CTA-V map for C operand
- `make_tmem_layout_acc` — Derive TMEM accumulator buffer layout from a tiled MMA
- `make_tmem_layout_a` — Derive TMEM A-operand buffer layout from a tiled MMA
- `make_t2r_rmem_layout` — Derive per-thread RMEM buffer layout for the T2R epilogue copy

View File

@@ -70,10 +70,93 @@ def get_cta_v_map_c(
:param gmem_tensor: Global-memory tensor being stored/loaded by TMA.
:type gmem_tensor: cute.Tensor
:param epi_tile: Epilogue tile layout describing the CTAs output tile shape.
:param epi_tile: Epilogue tile layout describing the CTA's output tile shape.
:type epi_tile: cute.Layout
:returns: A layout suitable to pass as `cta_v_map=...` to `tma_store` / `tma_load`.
:rtype: cute.Layout
"""
ident = cute.core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
return cute.core.composition(ident, epi_tile, loc=loc, ip=ip)
def make_tmem_layout_acc(
tiled_mma,
mnk_tiler,
acc_stage,
*,
loc=None,
ip=None,
):
"""Return TMEM accumulator buffer layout for a tiled MMA.
This is a small helper around ``tiled_mma.make_fragment_C(...).layout`` to
keep example code fragment-free at the call site.
:param tiled_mma: The MMA tiler (``cute.TiledMma``).
:type tiled_mma: cute.TiledMma
:param mnk_tiler: Full MNK tiler; only the MN components are used for C.
:type mnk_tiler: tuple
:param acc_stage: Accumulator pipeline stages.
:param loc: Optional location for DSL ops.
:param ip: Optional insertion point for DSL ops.
:return: Layout for the accumulator TMEM buffer.
:rtype: cute.Layout
"""
acc_shape = tiled_mma.partition_shape_C(mnk_tiler[:2], loc=loc, ip=ip)
acc_shape_staged = cute.append(acc_shape, acc_stage, loc=loc, ip=ip)
return tiled_mma.make_fragment_C(acc_shape_staged, loc=loc, ip=ip).layout
def make_tmem_layout_a(
tiled_mma,
mk_tiler,
stage,
*,
loc=None,
ip=None,
):
"""Return TMEM A operand buffer layout for a tiled MMA.
:param tiled_mma: The MMA tiler (``cute.TiledMma``).
:type tiled_mma: cute.TiledMma
:param mk_tiler: MK tiler used to shape the A operand.
:type mk_tiler: tuple
:param stage: Pipeline stages for the A operand buffer.
:param loc: Optional location for DSL ops.
:param ip: Optional insertion point for DSL ops.
:return: Layout for the A operand TMEM buffer.
:rtype: cute.Layout
"""
a_shape = tiled_mma.partition_shape_A(mk_tiler, loc=loc, ip=ip)
a_shape_staged = cute.append(a_shape, stage, loc=loc, ip=ip)
return tiled_mma.make_fragment_A(a_shape_staged, loc=loc, ip=ip).layout
def make_t2r_rmem_layout(
tiled_copy_t2r,
gC_mnl_epi,
tidx,
*,
loc=None,
ip=None,
):
"""Return RMEM buffer layout for the T2R epilogue destination.
Computes the per-thread RMEM buffer layout produced by a TMEM->RMEM copy
for a single epilogue iteration.
:param tiled_copy_t2r: The TMEM->RMEM tiled copy op (``cute.TiledCopy``).
:type tiled_copy_t2r: cute.TiledCopy
:param gC_mnl_epi: Global C tensor partitioned by epilogue tile.
:type gC_mnl_epi: cute.Tensor
:param tidx: Thread index for the copy slice.
:param loc: Optional location for DSL ops.
:param ip: Optional insertion point for DSL ops.
:return: Layout for the RMEM buffer.
:rtype: cute.Layout
"""
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi, loc=loc, ip=ip)
return cute.make_fragment_like(
tTR_gC[(None, None, None, 0, 0)].layout, loc=loc, ip=ip
)

View File

@@ -19,6 +19,7 @@ __all__ = [
# copy.py
#
"Repetition",
"TmemLoadRedOp",
"Pack",
"Unpack",
"Ld16x64bOp",

View File

@@ -26,6 +26,22 @@ from ...typing import Numeric
from .mma import CtaGroup
class TmemLoadRedOp(enum.Enum):
"""
An enumeration for the possible reduce operations for TMEM load operations.
"""
MAX = _cute_nvgpu_ir.TmemLoadRedOp.max
MAXABS = _cute_nvgpu_ir.TmemLoadRedOp.maxabs
MIN = _cute_nvgpu_ir.TmemLoadRedOp.min
MINABS = _cute_nvgpu_ir.TmemLoadRedOp.minabs
def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}>"
class Repetition(enum.Enum):
"""
An enumeration for the number of repetitions of a given TMEM copy within the instruction.
@@ -390,6 +406,97 @@ class Ld32x32bTrait(Trait):
pass
@dataclass(frozen=True)
class LdRed16x32bx2Op(_LdBase):
"""
16x32bx2 TMEM load Reduce Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
This Operation corresponds to the ``.red`` and ``.16x32bx2`` qualifiers.
"""
redOp: TmemLoadRedOp = TmemLoadRedOp.MAX
nan: bool = False
half_split_off: int = 0
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "LdRed16x32bx2Trait":
"""
Create a trait object for the 16x32bx2 TMEM load Reduce operation.
:param copy_internal_type: The data type for the copy operation
:type copy_internal_type: Type[Numeric]
:param loc: MLIR location information for debugging, defaults to None
:type loc: optional
:param ip: MLIR insertion point for code generation, defaults to None
:type ip: optional
:param kwargs: Additional keyword arguments
:type kwargs: dict
:return: A trait object for this load operation
:rtype: LdRed16x32bx2Trait
"""
ty = _cute_nvgpu_ir.CopyAtomSM10xTmemLoadRedType.get(
copy_internal_type.mlir_type,
16,
32,
self.repeat.value,
self.redOp.value,
ir.UnitAttr.get() if self.nan else None,
ir.IntegerAttr.get(ir.IntegerType.get_signless(32), self.half_split_off),
)
return LdRed16x32bx2Trait(make_atom(ty, loc=loc, ip=ip))
class LdRed16x32bx2Trait(Trait):
pass
@dataclass(frozen=True)
class LdRed32x32bOp(_LdBase):
"""
32x32b TMEM load Reduce Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld>`__.
This Operation corresponds to the ``.red`` and ``.32x32b`` qualifiers.
"""
redOp: TmemLoadRedOp = TmemLoadRedOp.MAX
nan: bool = False
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "LdRed32x32bTrait":
"""
Create a trait object for the 32x32b TMEM load Reduce operation.
:param copy_internal_type: The data type for the copy operation
:type copy_internal_type: Type[Numeric]
:param loc: MLIR location information for debugging, defaults to None
:type loc: optional
:param ip: MLIR insertion point for code generation, defaults to None
:type ip: optional
:param kwargs: Additional keyword arguments
:type kwargs: dict
:return: A trait object for this load operation
:rtype: LdRed32x32bTrait
"""
ty = _cute_nvgpu_ir.CopyAtomSM10xTmemLoadRedType.get(
copy_internal_type.mlir_type,
32,
32,
self.repeat.value,
self.redOp.value,
ir.UnitAttr.get() if self.nan else None,
None,
)
return LdRed32x32bTrait(make_atom(ty, loc=loc, ip=ip))
class LdRed32x32bTrait(Trait):
pass
@dataclass(frozen=True)
class _StBase(CopyOp):
"""

View File

@@ -103,15 +103,20 @@ class LdMatrix8x16x8bOp(BaseOp):
self,
"expects the 'num_matrices' Op parameter to be one of [1,2,4]",
)
if self.unpack_bits not in [4, 6]:
raise OpError(self, "Op unpack bits must be 4 or 6")
if self.unpack_bits not in [None, 4, 6]:
raise OpError(self, "Op unpack bits must be 4 or 6 or None")
def _make_trait(
self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
) -> "LdMatrix8x16x8bTrait":
mode = _pack_shape((8, 16), loc=loc, ip=ip)
sz_pattern = _cute_nvgpu_ir.LdsmSzPattern.u4x16p64to8
if self.unpack_bits == 6:
# LdMatrix8x16x8b without unpacking doesn't exist
# but is equivalent to LdMatrix8x8x16b
mode_n = 8 if self.unpack_bits is None else 16
mode = _pack_shape((8, mode_n), loc=loc, ip=ip)
sz_pattern = _cute_nvgpu_ir.LdsmSzPattern.u16
if self.unpack_bits == 4:
sz_pattern = _cute_nvgpu_ir.LdsmSzPattern.u4x16p64to8
elif self.unpack_bits == 6:
sz_pattern = _cute_nvgpu_ir.LdsmSzPattern.u6x16p32to8
ty = _cute_nvgpu_ir.CopyAtomLdsmType.get(
copy_internal_type.mlir_type,

View File

@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
@@ -141,6 +141,11 @@ class _Tensor(Tensor):
# If tensor is already a DLPack object, use it directly
if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"):
self._dlpack_data = tensor.__dlpack_device__()
elif enable_tvm_ffi:
import tvm_ffi
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor)
self._dlpack_data = self._tvm_ffi_tensor.__dlpack__()
else:
try:
# we expect no stream sync. Because torch has different default behavior
@@ -149,11 +154,6 @@ class _Tensor(Tensor):
self._dlpack_data = tensor.__dlpack__(stream=-1)
except Exception:
self._dlpack_data = tensor.__dlpack__()
if enable_tvm_ffi:
import tvm_ffi
self._tvm_ffi_tensor = tvm_ffi.from_dlpack(tensor)
self._dlpack_data = self._tvm_ffi_tensor.__dlpack__()
self._dltensor_wrapper = None
self._assumed_align = assumed_align

View File

@@ -647,7 +647,7 @@ def _while_execute_dynamic(
def _ifexp_execute_dynamic(
pred: "ir.Value",
generator_targets: tuple,
block_args: tuple,
then_block: Callable,
else_block: Callable,
):
@@ -663,8 +663,8 @@ def _ifexp_execute_dynamic(
----------
pred : ir.Value
The predicate value (a boolean IR value) that determines which branch is executed.
generator_targets : tuple
The generator targets that are passed to the then and else blocks.
block_args : tuple
The block arguments that are passed to the then and else blocks.
then_block : Callable
A Python function that executes the 'then' branch and returns the result(s). This will be
executed if `pred` evaluates to True.
@@ -698,13 +698,13 @@ def _ifexp_execute_dynamic(
with ir.InsertionPoint(execution_region.region.blocks[0]):
# Call the then block and unpack its results to IR values and tree structure
then_results = ScfGenerator._normalize_region_result_to_list(
then_block(*generator_targets)
then_block(*block_args)
)
ir_values, then_tree = cutlass_dsl.unpack_to_irvalue(then_results, "ifexp", 0)
# Call the else block and unpack its results to IR values and tree structure
else_results = ScfGenerator._normalize_region_result_to_list(
else_block(*generator_targets)
else_block(*block_args)
)
_, else_tree = cutlass_dsl.unpack_to_irvalue(else_results, "ifexp", 0)
@@ -739,11 +739,11 @@ def _ifexp_execute_dynamic(
# SCF region builder for then block
def then_builder(*args):
# Just call the then_block as no arguments are passed to it
return then_block(*generator_targets)
return then_block(*block_args)
# SCF region builder for else block
def else_builder(*args):
return else_block(*generator_targets)
return else_block(*block_args)
# Prepare the list of region builders for the SCF IfOp: first for "then", then for "else"
region_builders = [then_builder, else_builder]

View File

@@ -129,16 +129,13 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider):
cuda_global_state_ptr = self.address_of(
self.cuda_global_state_symbol, self.ptr_type
)
cuda_init_ptr = context.builder.get_or_load_global_func_ptr_from_text(
current_block, "cuda_init"
)
cuda_load_to_device_ptr = context.builder.get_or_load_global_func_ptr_from_text(
current_block, "cuda_load_to_device"
)
set_error_ptr = context.builder.get_or_load_global_func_ptr_from_text(
current_block, "TVMFFIErrorSetRaisedFromCStr"
)
cuda_init_ptr = self.address_of("cuda_init", self.ptr_type)
cuda_load_to_device_ptr = self.address_of(
"cuda_load_to_device", self.ptr_type
)
set_error_ptr = self.address_of(
"TVMFFIErrorSetRaisedFromCStr", self.ptr_type
)
with ir.InsertionPoint(current_block):
# Call the callback function with the loaded ptr value

View File

@@ -129,7 +129,7 @@ class ClcDynamicPersistentTileScheduler:
:param num_tiles_executed: Counter for executed tiles.
:type num_tiles_executed: Int32
:param clc_response_ptr: Pointer of the clc response.
:type clc_response_ptr: Tuple[Integer, Integer, Integer, Integer]
:type clc_response_ptr: cute.Pointer
:param block_idx: The block index.
:type block_idx: Tuple[Integer, Integer, Integer]
"""
@@ -238,7 +238,7 @@ class ClcDynamicPersistentTileScheduler:
@dsl_user_op
def work_tile_info_from_clc_response(
self, result_addr: Int32, *, loc=None, ip=None
self, result_addr: cute.Pointer, *, loc=None, ip=None
) -> WorkTileInfo:
"""
Simulates parsing CLC response data in Python.

View File

@@ -293,6 +293,7 @@ def epilogue(
acc_pipeline: pipeline.PipelineAsync,
tCcC_base: cute.Tensor = None,
mC_mnl: cute.Tensor = None,
overlapping_accum: Constexpr = False,
) -> pipeline.PipelineState:
"""
Epilogue function that stores accumulator results directly to global memory.
@@ -310,12 +311,18 @@ def epilogue(
:type epi_tile: cute.Tile
:param epilogue_op: Optional elementwise operation to apply
:type epilogue_op: Constexpr
:param alignment_bytes: Alignment bytes for global memory store
:type alignment_bytes: int
:param mma_tile_coord_mnl: MMA tile coordinates (M, N, L)
:type mma_tile_coord_mnl: Tuple[Int32, Int32, Int32]
:param acc_consumer_state: Accumulator consumer pipeline state
:type acc_consumer_state: pipeline.PipelineState
:param acc_pipeline: Accumulator pipeline for async operations
:type acc_pipeline: pipeline.PipelineAsync
:param tCcC_base: Identity/coordinate tensor C
:type tCcC_base: cute.Tensor
:param mC_mnl: Global memory tensor C (full tensor for predicate computation)
:type mC_mnl: cute.Tensor
:param overlapping_accum: Whether to use overlapping accumulator
:type overlapping_accum: Constexpr
"""
# Layout transformation for tCgC_base
@@ -399,9 +406,16 @@ def epilogue(
]
tTR_cC = cute.group_modes(tTR_cC, 3, cute.rank(tTR_cC))
# Get accumulator stage index
if const_expr(overlapping_accum):
acc_stage_index = acc_consumer_state.phase
reverse_subtile = acc_stage_index == 0
else:
acc_stage_index = acc_consumer_state.index
# Set tensor memory buffer for current tile
# (T2R, T2R_M, T2R_N, EPI_M, EPI_M)
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)]
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
#
# Wait for accumulator buffer full
@@ -415,21 +429,38 @@ def epilogue(
#
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
for subtile_idx in range(subtile_cnt):
# Compute the actual subtile index
real_subtile_idx = subtile_idx
if const_expr(overlapping_accum):
if reverse_subtile:
real_subtile_idx = subtile_cnt - 1 - subtile_idx
#
# Get the destination and coordinate slices for this subtile
#
tTR_gC_subtile = tTR_gC[(None, None, None, subtile_idx)]
tTR_gC_subtile = tTR_gC[(None, None, None, real_subtile_idx)]
#
# Load accumulator from tensor memory buffer to register
#
tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)]
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
#
# Async arrive accumulator buffer empty
# Release early for perf
if subtile_idx == subtile_cnt - 1:
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
#
if const_expr(overlapping_accum):
# Early release when overlapping: release after processing the
# overlapping region (SF columns) so they can be reused
if subtile_idx == gemm_kernel.iter_acc_early_release_in_epilogue:
cute.arch.fence_view_async_tmem_load()
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
else:
# Release early for perf at the last subtile
if subtile_idx == subtile_cnt - 1:
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
#
# Convert to C type
@@ -440,7 +471,7 @@ def epilogue(
if const_expr(use_predication):
# compute predicate
tTR_cC_subtile = tTR_cC[(None, None, None, subtile_idx)]
tTR_cC_subtile = tTR_cC[(None, None, None, real_subtile_idx)]
pred_C_shape = (1, *tTR_cC_subtile.shape[1:])
pred_C = cute.make_rmem_tensor(pred_C_shape, Boolean)
for m_idx in range(tTR_cC_subtile.shape[1]):

View File

@@ -41,7 +41,23 @@ class HardwareInfo:
self.driver_version = self._checkCudaErrors(driver.cuDriverGetVersion())
# Getting the max active clusters for a given cluster size
def get_max_active_clusters(self, cluster_size: int) -> int:
def get_max_active_clusters(
self, cluster_size: int, stream: driver.CUstream = None
) -> int:
"""
Get the maximum number of active clusters for a given cluster size.
When a stream from a green context is provided, the occupancy calculation
will reflect the reduced SM partition of the green context.
:param cluster_size: Number of blocks per cluster (must be between 1 and 32)
:type cluster_size: int
:param stream: Optional CUDA stream handle. If provided (especially from a green context),
the occupancy calculation reflects the stream's SM partition.
:type stream: driver.CUstream, optional
:return: Maximum number of active clusters
:rtype: int
"""
if self._cuda_driver_version_lt(11, 8):
raise RuntimeError(
"CUDA Driver version < 11.8, cannot get _max_active_clusters"
@@ -94,6 +110,13 @@ class HardwareInfo:
launch_config.blockDimY = 1
launch_config.blockDimZ = 1
launch_config.sharedMemBytes = max_dynamic_shared_memory
# IMPORTANT: Set the stream for green context support
# When hStream is set, cuOccupancyMaxActiveClusters will use the context
# associated with that stream, which includes the green context's SM partition
if stream is not None:
launch_config.hStream = stream
launch_config.numAttrs = 1
# max possible cluster size is 32
cluster_dims_attr = driver.CUlaunchAttribute()

View File

@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual

View File

@@ -0,0 +1,3 @@
# Use `pip install -r requirements-cu13.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl[cu13]==4.4.0

View File

@@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.4.0.dev1
nvidia-cutlass-dsl==4.4.0

85
python/CuTeDSL/setup.sh Executable file
View File

@@ -0,0 +1,85 @@
#!/bin/bash
#################################################################################################
#
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

# Abort on the first failing command.
set -e

# Absolute directory containing this script, so it works regardless of the
# caller's working directory.
script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

# CUDA 12 requirements are the default selection.
requirements_file="requirements.txt"

# Pick the requirements file from the (single, optional) command-line flag.
if [ $# -gt 0 ]; then
    if [ "$1" = "--cu12" ]; then
        requirements_file="requirements.txt"
        echo "Installing CUDA 12 requirements..."
    elif [ "$1" = "--cu13" ]; then
        requirements_file="requirements-cu13.txt"
        echo "Installing CUDA 13 requirements..."
    elif [ "$1" = "--help" ] || [ "$1" = "-h" ]; then
        echo "Usage: $0 [--cu12|--cu13]"
        echo "  --cu12    Install requirements for CUDA 12 (default)"
        echo "  --cu13    Install requirements for CUDA 13"
        exit 0
    else
        echo "Error: Unknown argument '$1'"
        echo "Usage: $0 [--cu12|--cu13]"
        exit 1
    fi
else
    echo "Installing default requirements (CUDA 12)..."
fi

# Fail early if the selected requirements file is missing.
requirements_path="${script_dir}/${requirements_file}"
if [ ! -f "$requirements_path" ]; then
    echo "Error: Requirements file not found: $requirements_path"
    exit 1
fi

# Remove any previously installed CUTLASS DSL wheels before installing the
# requested set, so stale variants (base / cu13) do not linger.
echo "Trying to uninstall previous version of the CUTLASS DSL..."
pip uninstall nvidia-cutlass-dsl nvidia-cutlass-dsl-libs-base nvidia-cutlass-dsl-libs-cu13 -y

# Install the selected requirements.
echo "Installing from: $requirements_file"
pip install -r "$requirements_path"

echo "Installation complete!"

View File

@@ -54,6 +54,7 @@ class MatmulHeuristics:
def __init__(self, gpu = None):
import nvMatmulHeuristics
import inspect
self.mmh_lib = nvMatmulHeuristics
self.gpu = gpu
@@ -62,13 +63,63 @@ class MatmulHeuristics:
else:
nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx
self.lh = nvmmhInterfaceEx(
# nvidia-matmul-heuristics 0.1.0.28 changed the API:
# - Constructor: removed 'load_discovery_implicitly' and 'gpu' params
# - GPU: now set via createHardwareDescriptor() + setHardwarePredefinedGpu()
# - setBackendValueProperty renamed to setBackendPropertyValue (simpler signature)
# - getEx: added hardware_descriptor parameter
init_params = set(inspect.signature(self.mmh_lib.NvMatmulHeuristicsInterfaceEx.__init__).parameters.keys())
self._legacy_api = 'load_discovery_implicitly' in init_params
init_kwargs = dict(
backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"],
flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING,
load_discovery_implicitly=True,
gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
)
if self._legacy_api:
# <= 0.1.0.27
init_kwargs['gpu'] = self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
init_kwargs['load_discovery_implicitly'] = True
self.lh = nvmmhInterfaceEx(**init_kwargs)
# >= 0.1.0.28: gpu is set via hardware descriptor after construction,
# and passed to getEx() calls
self.hw_desc = None
if not self._legacy_api and self.gpu:
self.hw_desc = self.lh.createHardwareDescriptor()
if self.hw_desc is None:
raise RuntimeError("Failed to create hardware descriptor for GPU: " + self.gpu)
self.lh.setHardwarePredefinedGpu(self.hw_desc, self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu])
self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"])
if not self._legacy_api:
lh = self.lh
original_del = type(lh).__del__
def _safe_del(self_lh):
try:
original_del(self_lh)
except Exception:
pass
type(lh).__del__ = _safe_del
def __del__(self):
    """Clean up resources in correct order before the library's __del__ runs.

    Tear-down order matters: the backend is destroyed first, then the
    hardware descriptor, and finally the library handle is nulled out so
    that the wrapped library's own finalizer skips its destroy call
    (avoiding a double-free / use-after-destroy during interpreter exit).
    """
    try:
        # hasattr guards: __del__ can run on a partially constructed
        # instance if __init__ raised before these attributes were set.
        if hasattr(self, 'backend') and self.backend:
            self.lh.destroyBackend(self.backend)
            self.backend = None
        if hasattr(self, 'hw_desc') and self.hw_desc:
            self.lh.destroyHardwareDescriptor(self.hw_desc)
            self.hw_desc = None
        # Null out the handle so the library's __del__ skips nvMatmulHeuristicsDestroy
        if hasattr(self, 'lh') and self.lh and hasattr(self.lh, 'handle'):
            self.lh.handle = None
    except Exception:
        # Deliberately broad: exceptions raised in __del__ are swallowed by
        # the interpreter anyway, and interpreter shutdown may have already
        # torn down the modules these calls depend on.
        pass
def _layout_from_cutlass(self, layouts):
assert(len(layouts)==3)
@@ -98,41 +149,45 @@ class MatmulHeuristics:
else:
return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
def _set_backend_property(self, property, value):
    """Compat wrapper: setBackendValueProperty (<=0.1.0.27) vs setBackendPropertyValue (>=0.1.0.28)

    Dispatches on ``self._legacy_api`` (detected at construction time):
    the legacy API takes the value as a ctypes pointer plus size, while
    the newer API accepts the plain Python value directly.

    NOTE(review): the parameter name ``property`` shadows the builtin;
    callers in this file pass it positionally, so renaming would be safe,
    but it is left untouched here to keep this a documentation-only change.
    """
    if self._legacy_api:
        # <= 0.1.0.27: value must be marshalled as a C int passed by
        # reference together with its size.
        c_val = ctypes.c_int(value)
        self.lh.setBackendValueProperty(
            self.backend, property,
            ctypes.byref(c_val), ctypes.sizeof(c_val)
        )
    else:
        # >= 0.1.0.28: renamed method with a simpler (plain-value) signature.
        self.lh.setBackendPropertyValue(self.backend, property, value)
def set_cta_div_n(self, div_n):
cta_n_div_requirement = ctypes.c_int(div_n)
self.lh.setBackendValueProperty(
self.backend,
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT,
ctypes.byref(cta_n_div_requirement),
ctypes.sizeof(cta_n_div_requirement)
)
self._set_backend_property(
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT, div_n)
def set_cta_div_m(self, div_m):
cta_m_div_requirement = ctypes.c_int(div_m)
self.lh.setBackendValueProperty(
self.backend,
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT,
ctypes.byref(cta_m_div_requirement),
ctypes.sizeof(cta_m_div_requirement)
)
self._set_backend_property(
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT, div_m)
def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1):
if use_fast_acc:
disable_fast_acc_for_fp8 = ctypes.c_int(0)
else:
disable_fast_acc_for_fp8 = ctypes.c_int(1)
self.lh.setBackendValueProperty(
self.backend,
self._set_backend_property(
self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8,
ctypes.byref(disable_fast_acc_for_fp8),
ctypes.sizeof(disable_fast_acc_for_fp8)
0 if use_fast_acc else 1
)
precision = self._precision_from_cutlass_dtypes(dtypes)
layout = self._layout_from_cutlass(layouts)
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision)
if self._legacy_api:
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
else:
# >= 0.1.0.28: takes (m,n,k) as a tuple
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem((m, n, k), layout, batch_count)
getEx_kwargs = dict(precision=precision)
if not self._legacy_api:
# >= 0.1.0.28: pass hardware descriptor to getEx
getEx_kwargs['hardware_descriptor'] = self.hw_desc
configs = self.lh.getEx(matmul_problem, count, self.backend, **getEx_kwargs)
ret = []
for c in configs: