v4.5.1 update. (#3238)

This commit is contained in:
Junkai-Wu
2026-05-19 10:33:27 +08:00
committed by GitHub
parent e406c186f5
commit 2e602843e7
42 changed files with 6487 additions and 336 deletions

View File

@@ -720,7 +720,9 @@ class DSLPreprocessor(ast.NodeTransformer):
offset = len(all_args) - len(func_ast.args.defaults)
for i, default_node in enumerate(func_ast.args.defaults):
ast_defaults[all_args[offset + i].arg] = default_node
for kwarg, kw_default in zip(func_ast.args.kwonlyargs, func_ast.args.kw_defaults):
for kwarg, kw_default in zip(
func_ast.args.kwonlyargs, func_ast.args.kw_defaults
):
if kw_default is not None:
ast_defaults[kwarg.arg] = kw_default
for param_name, default_val in params_with_defaults.items():

View File

@@ -1865,7 +1865,7 @@ class BaseDSL(metaclass=DSLSingletonMeta):
sources = set(x.value for x in link_libraries_attributes)
link_libraries = (
link_libraries
+ ("," if len(link_libraries) > 0 else "")
+ ("," if link_libraries and len(sources) > 0 else "")
+ ",".join(sources)
)
self.compile_options.options[LinkLibraries] = LinkLibraries(

View File

@@ -88,6 +88,11 @@ def _get_gpu_arch_info(major: int, minor: int) -> tuple[str, str, list[str]]:
"sm_120a",
["sm_120a"],
), # RTX PRO 6000 / RTX 50 Series
(12, 1): (
"Blackwell",
"sm_121a",
["sm_121a"],
), # DGX Spark
}
return gpu_arch_map.get(
(major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"])

View File

@@ -3330,7 +3330,12 @@ def filter_zeros(
if not isinstance(input, (Layout, Tensor)):
raise TypeError(f"Expected layout or tensor as input, but got {type(input)=}")
if isinstance(input, Tensor):
input = input.value
return _op_wrapper(
partial(_cute_ir.filter_zeros, target_profile=target_profile),
input,
loc=loc,
ip=ip,
)
return _cute_ir.filter_zeros(input, target_profile=target_profile, loc=loc, ip=ip)
@@ -3388,7 +3393,7 @@ def filter(
input.inner, input.offset, filter(input.outer, loc=loc, ip=ip)
)
elif isinstance(input, _Tensor):
return _cute_ir.filter(input.value, loc=loc, ip=ip)
return _op_wrapper(_cute_ir.filter, input, loc=loc, ip=ip)
else:
return _cute_ir.filter(input, loc=loc, ip=ip)
@@ -5020,10 +5025,9 @@ def local_partition(
raise NotImplementedError(
f"Index value should be 32-bit or smaller integer type, but got {index_val.type}"
)
return _cute_ir.local_partition(
input=target.value,
tiler=dice(tiler, proj),
index=index_val,
return _op_wrapper(
partial(_cute_ir.local_partition, tiler=dice(tiler, proj), index=index_val),
target,
loc=loc,
ip=ip,
)
@@ -5114,11 +5118,9 @@ def local_tile(
proj_val = _pack_coord(proj, loc=loc, ip=ip)
proj = proj_val.type.attribute
return _cute_ir.local_tile(
input=input.value,
tile=tiler_val,
coord=coord_val,
proj=proj,
return _op_wrapper(
partial(_cute_ir.local_tile, tile=tiler_val, coord=coord_val, proj=proj),
input,
loc=loc,
ip=ip,
)

View File

@@ -21,6 +21,9 @@ __all__ = [
"MmaFP8Op",
"MmaMXF4Op",
"MmaMXF4NVF4Op",
"MmaMXF8Op",
"MmaMXF8F6F4Op",
"MXF8F6F4_SUPPORTED_PAIRS",
# copy.py
"LdMatrix8x8x16bOp",
"LdMatrix16x8x8bOp",

View File

@@ -224,7 +224,9 @@ class MmaSM120BlockScaledOp(MmaOp):
admissible_archs = [
Arch.sm_120a,
Arch.sm_120f,
Arch.sm_121a,
Arch.sm_121f,
]
def __post_init__(self) -> None:
@@ -239,29 +241,44 @@ class MmaSM120BlockScaledOp(MmaOp):
"CUTE_DSL_ARCH set to sm_120a or sm_121a",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
if self.ab_dtype != Float4E2M1FN:
# (ab_dtype, shape_mnk) consistency: FP4 uses (16,8,64); FP8 uses (16,8,32).
if self.ab_dtype == Float4E2M1FN:
if self.shape_mnk != (16, 8, 64):
raise OpError(
self,
"expects the 'shape_mnk' Op parameter to be (16,8,64) for Float4E2M1FN",
)
elif self.ab_dtype in (Float8E4M3FN, Float8E5M2):
if self.shape_mnk != (16, 8, 32):
raise OpError(
self,
"expects the 'shape_mnk' Op parameter to be (16,8,32) for Float8E4M3FN/Float8E5M2",
)
else:
raise OpError(
self,
"expects the 'ab_dtype' Op parameter to be Float4E2M1FN",
"expects the 'ab_dtype' Op parameter to be Float4E2M1FN, Float8E4M3FN, or Float8E5M2",
)
if self.acc_dtype != Float32:
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be Float32",
)
if self.shape_mnk != (16, 8, 64):
raise OpError(
self,
"expects the 'shape_mnk' Op parameter to be (16,8,64)",
)
if self.sf_vec_size == 16:
# vec_size=16 is only valid for FP4 (NVFP4) with E4M3 scale.
if self.ab_dtype != Float4E2M1FN:
raise OpError(
self,
"expects the 'sf_vec_size' Op parameter to be 32 for Float8E4M3FN/Float8E5M2",
)
if self.sf_type != Float8E4M3FN:
raise OpError(
self,
"expects the 'sf_type' Op parameter to be Float8E4M3FN",
)
elif self.sf_vec_size == 32:
# vec_size=32 path uses UE8M0 scale for both FP4 (MXF4) and FP8 (MXF8).
if self.sf_type != Float8E8M0FNU:
raise OpError(
self,
@@ -275,7 +292,7 @@ class MmaSM120BlockScaledOp(MmaOp):
def __str__(self) -> str:
return (
"warp-level MXF4/MXF4NVF4 MMA Operation"
"warp-level MXF4/MXF4NVF4/MXF8 MMA Operation"
+ f"\n A/B data type = {self.ab_dtype}"
+ f"\n Accumulator data type = {self.acc_dtype}"
+ f"\n Instruction shape MNK = {self.shape_mnk}"
@@ -474,3 +491,214 @@ class MmaMXF4NVF4Op(MmaSM120BlockScaledOp):
class MmaMXF4NVF4Trait(MmaBlockScaledTrait):
pass
#
# MXF8 MMA
#
@dataclass(frozen=True)
class MmaMXF8Op(MmaSM120BlockScaledOp):
"""
MXF8 warp-level MMA Operation.
See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma>`__.
This Operation covers the instructions using the ``.e4m3`` / ``.e5m2`` qualifiers for the input operands.
.kind = {.kind::mxf8};
.scale_vec_size = {.scale_vec::1X};
.stype = {.ue8m0};
"""
descriptive_name = "warp-level MXF8 MMA Operation"
def __init__(
self,
ab_dtype: Type[Numeric],
acc_dtype: Type[Numeric],
sf_type: Type[Numeric],
) -> None:
super().__init__(
ab_dtype,
acc_dtype,
(16, 8, 32),
sf_type,
32,
)
def _make_trait(
self,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
**kwargs: Any,
) -> "MmaMXF8Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get(
shape_mnk.type.attribute,
32,
False,
self.ab_dtype.mlir_type,
self.ab_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.sf_type.mlir_type,
)
return MmaMXF8Trait(make_atom(ty, loc=loc, ip=ip))
class MmaMXF8Trait(MmaBlockScaledTrait):
pass
#
# MXF8F6F4 mixed-precision MMA (independent A/B dtypes)
#
MXF8F6F4_SUPPORTED_PAIRS = frozenset(
{
(Float4E2M1FN, Float8E4M3FN),
(Float4E2M1FN, Float8E5M2),
(Float8E4M3FN, Float4E2M1FN),
(Float8E5M2, Float4E2M1FN),
}
)
@dataclass(frozen=True)
class MmaMXF8F6F4Op(MmaOp):
"""
SM120 MXF8F6F4 mixed-precision warp-level block-scaled MMA Operation.
Covers the PTX instructions using independent ``.<a_type>.<b_type>``
qualifiers (one of e2m1.e4m3, e2m1.e5m2, e4m3.e2m1, e5m2.e2m1):
.kind = {.kind::mxf8f6f4};
.scale_vec_size = {.scale_vec::1X};
.stype = {.ue8m0};
A and B operand dtypes are independent. Same-dtype FP4/FP4 and FP8/FP8
paths remain on ``MmaMXF4Op`` / ``MmaMXF4NVF4Op`` / ``MmaMXF8Op``
respectively. Same-width mixed-FP8 (E4M3 + E5M2) and FP6 mixed pairs
are not supported.
"""
a_dtype: Type[Numeric]
b_dtype: Type[Numeric]
acc_dtype: Type[Numeric]
sf_type: Type[Numeric]
descriptive_name = "warp-level MXF8F6F4 mixed-precision MMA Operation"
shape_mnk = (16, 8, 32)
sf_vec_size = 32
use_sf_layout_TV = False
admissible_archs = [
Arch.sm_120a,
Arch.sm_121a,
]
def __post_init__(self) -> None:
# Verify arch
arch = BaseDSL._get_dsl().get_arch_enum()
if arch not in self.admissible_archs:
raise OpError(
self,
f"expects arch to be one of {self.admissible_archs}, but got {arch}",
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
)
if self.acc_dtype != Float32:
raise OpError(
self,
"expects the 'acc_dtype' Op parameter to be Float32",
)
if self.sf_type != Float8E8M0FNU:
raise OpError(
self,
"expects the 'sf_type' Op parameter to be Float8E8M0FNU",
)
# Reject same-dtype pairs explicitly (route to dedicated ops).
if self.a_dtype == self.b_dtype:
if self.a_dtype == Float4E2M1FN:
raise OpError(
self,
"same-dtype Float4E2M1FN/Float4E2M1FN is not supported by MmaMXF8F6F4Op; "
"use MmaMXF4Op (sf_vec_size=32) or MmaMXF4NVF4Op (sf_vec_size=16) instead",
)
if self.a_dtype in (Float8E4M3FN, Float8E5M2):
raise OpError(
self,
"same-dtype FP8/FP8 is not supported by MmaMXF8F6F4Op; "
"use MmaMXF8Op instead",
)
# Reject same-width mixed-FP8 (E4M3 + E5M2) explicitly.
fp8_dtypes = (Float8E4M3FN, Float8E5M2)
if self.a_dtype in fp8_dtypes and self.b_dtype in fp8_dtypes:
raise OpError(
self,
"same-width mixed-FP8 (Float8E4M3FN + Float8E5M2) is not supported; "
"supported MXF8F6F4 pairs are (Float4E2M1FN x Float8E4M3FN/Float8E5M2) "
"and the reverse",
)
# Final allow-list check (catches FP6 and any other unsupported dtype).
if (self.a_dtype, self.b_dtype) not in MXF8F6F4_SUPPORTED_PAIRS:
raise OpError(
self,
f"unsupported (a_dtype, b_dtype) = ({self.a_dtype}, {self.b_dtype}) "
f"for MmaMXF8F6F4Op; supported pairs are "
f"{sorted(repr(p) for p in MXF8F6F4_SUPPORTED_PAIRS)}. "
f"FP6 mixed pairs are not supported.",
)
def __str__(self) -> str:
return (
"warp-level MXF8F6F4 mixed-precision MMA Operation"
+ f"\n A data type = {self.a_dtype}"
+ f"\n B data type = {self.b_dtype}"
+ f"\n Accumulator data type = {self.acc_dtype}"
+ f"\n Instruction shape MNK = {self.shape_mnk}"
+ f"\n Vector size = {self.sf_vec_size}"
+ f"\n SF data type = {self.sf_type}"
)
def _verify_fragment_A(
self,
input: Tensor,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> None:
pass
def _verify_fragment_B(
self,
input: Tensor,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> None:
pass
def _make_trait(
self,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
**kwargs: Any,
) -> "MmaMXF8F6F4Trait":
shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
ty = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get(
shape_mnk.type.attribute,
self.sf_vec_size,
self.use_sf_layout_TV,
self.a_dtype.mlir_type,
self.b_dtype.mlir_type,
self.acc_dtype.mlir_type,
self.sf_type.mlir_type,
)
return MmaMXF8F6F4Trait(make_atom(ty, loc=loc, ip=ip))
class MmaMXF8F6F4Trait(MmaBlockScaledTrait):
pass

View File

@@ -21,7 +21,11 @@ from jax._src.interpreters import batching
from .compile import get_or_compile_kernel, build_function_spec
from .types import cutlass_to_jax_layout_order, default_tensor_spec, TensorSpec
from .types import (
cutlass_to_jax_layout_order,
default_tensor_spec,
TensorSpec,
)
from .ffi import get_cutlass_call_ffi_name, is_ffi_registered, register_ffi
@@ -77,8 +81,10 @@ def cutlass_call(
objects with ``.shape`` and ``.dtype`` attributes) describing each
output buffer.
input_spec: A :class:`TensorSpec` or list thereof providing
layout/mode/divisibility hints for input tensors. ``None`` infers
defaults from each array.
layout/mode/divisibility hints for input tensors. ``None`` infers
defaults from each array. A ``TensorSpec`` with ``layout=None`` uses
and constrains row-major physical layout; use ``mode`` to remap
physical dimensions to the kernel's logical modes.
output_spec: Same as *input_spec* but applied to output tensors.
input_output_aliases: ``{input_index: output_index}`` mapping that
allows an input buffer to alias an output, avoiding an extra copy.
@@ -308,7 +314,9 @@ def cutlass_call_inner_p_impl(
call_name = get_cutlass_call_ffi_name(allow_cuda_graph)
# Convert layout from CuTeDSL to JAX order as ffi_call expects this.
# Convert explicit layout constraints from CuTeDSL to JAX order. ``None`` is
# passed through intentionally: jax.ffi.ffi_call treats it as default
# row-major layout.
input_layouts = [cutlass_to_jax_layout_order(s.layout) for s in input_spec_flat]
output_layouts = [cutlass_to_jax_layout_order(s.layout) for s in output_spec_flat]

View File

@@ -15,12 +15,18 @@ import jax.numpy as jnp
import cutlass.cute as cute
from cutlass.cutlass_dsl import dsl_user_op
from typing import Optional
from typing import Optional, Sequence
from cutlass._mlir import ir
def reorder_modes(src: str, target: str) -> tuple[int, ...]:
"""Computes the mode given a source and target order."""
def reorder_modes(src: Sequence[str], target: Sequence[str]) -> tuple[int, ...]:
"""Compute a ``TensorSpec.mode`` from physical input order to kernel order.
``src`` names the JAX array's physical dimension order. ``target`` names the
logical mode order that the CuTe kernel expects. The returned tuple can be
passed as ``TensorSpec(mode=...)`` while leaving ``layout`` at its default
row-major value when the JAX buffer is physically row-major.
"""
src = tuple(src)
target = tuple(target)
src_map = {}
@@ -29,52 +35,64 @@ def reorder_modes(src: str, target: str) -> tuple[int, ...]:
return tuple([src_map[d] for d in target])
def gemm_a_major(d: str):
"""Returns order for A tensor major mode."""
def gemm_a_major(d: str) -> str:
"""Return the physical JAX dimension order for an A tensor major mode.
The returned string is not the kernel's canonical logical order. Use
:func:`gemm_a_mode` to map this physical order to kernel logical ``mkl``.
"""
return {"k": "lmk", "m": "lkm"}[d]
def gemm_a_mode(d: str) -> tuple[int, ...]:
"""Returns mode for A tensor major mode."""
"""Return ``TensorSpec.mode`` for A, mapping physical order to logical ``mkl``."""
return reorder_modes(gemm_a_major(d), "mkl")
def gemm_b_major(d: str):
"""Returns order for B tensor major mode."""
def gemm_b_major(d: str) -> str:
"""Return the physical JAX dimension order for a B tensor major mode.
The returned string is not the kernel's canonical logical order. Use
:func:`gemm_b_mode` to map this physical order to kernel logical ``nkl``.
"""
return {"k": "lnk", "n": "lkn"}[d]
def gemm_b_mode(d: str) -> tuple[int, ...]:
"""Returns mode for B tensor major mode."""
"""Return ``TensorSpec.mode`` for B, mapping physical order to logical ``nkl``."""
return reorder_modes(gemm_b_major(d), "nkl")
def gemm_c_major(d: str):
"""Returns order for C tensor major mode."""
def gemm_c_major(d: str) -> str:
"""Return the physical JAX dimension order for a C/D tensor major mode.
The returned string is not the kernel's canonical logical order. Use
:func:`gemm_c_mode` to map this physical order to kernel logical ``mnl``.
"""
return {"n": "lmn", "m": "lnm"}[d]
def gemm_c_mode(d: str) -> tuple[int, ...]:
"""Returns mode for C tensor major mode."""
"""Return ``TensorSpec.mode`` for C/D, mapping physical order to logical ``mnl``."""
return reorder_modes(gemm_c_major(d), "mnl")
def gemm_a_shape(l, m, k, major) -> tuple[int, ...]:
"""Returns shape for A tensor given major mode."""
def gemm_a_shape(l: int, m: int, k: int, major: str) -> tuple[int, ...]:
"""Return the physical row-major JAX shape for A with the requested major mode."""
assert major in ("k", "m")
shape = (l, m, k) if major == "k" else (l, k, m)
return shape
def gemm_b_shape(l, n, k, major) -> tuple[int, ...]:
"""Returns shape for B tensor given major mode."""
def gemm_b_shape(l: int, n: int, k: int, major: str) -> tuple[int, ...]:
"""Return the physical row-major JAX shape for B with the requested major mode."""
assert major in ("k", "n")
shape = (l, n, k) if major == "k" else (l, k, n)
return shape
def gemm_c_shape(l, m, n, major) -> tuple[int, ...]:
"""Returns shape for C tensor given major mode."""
def gemm_c_shape(l: int, m: int, n: int, major: str) -> tuple[int, ...]:
"""Return the physical row-major JAX shape for C/D with the requested major mode."""
assert major in ("m", "n")
shape = (l, m, n) if major == "n" else (l, n, m)
return shape

View File

@@ -9,7 +9,7 @@
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Optional, Sequence
from typing import Any, Optional, Sequence
from dataclasses import dataclass, field
@@ -18,6 +18,7 @@ import jax.numpy as jnp
import cutlass
import cutlass.cute as cute
from cutlass.cute.core import IntValue
from cutlass.cute.runtime import from_dlpack as _from_dlpack
from cutlass.cute import AddressSpace
from cutlass._mlir import ir
@@ -58,35 +59,69 @@ DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT = 256
class TensorSpec:
"""Specifies the layout and metadata for a JAX array passed to a CuTe kernel.
TensorSpec controls how a JAX array's dimensions are mapped to a cute.Tensor
during jit lowering, including stride ordering, mode permutation, and whether
shapes/strides are compiled as static constants.
TensorSpec controls how a JAX array's input dimensions are mapped to a
``cute.Tensor`` during jit lowering, including compact stride ordering,
mode permutation, and whether shapes/strides are compiled as static
constants. The JAX bridge models tensors as compact layouts: runtime
strides are derived from runtime shapes using ``layout`` order rather than
loaded from a strided view descriptor.
A useful way to choose a spec is to separate physical storage from logical
kernel modes:
1. First choose the public JAX array shape and its compact physical memory
order. If the buffer is a standard row-major JAX array, leave
``layout=None``. ``cutlass_call`` will constrain the FFI operand/result
to row-major physical layout, matching the CuTe tensor strides that are
built from the default.
2. Then use ``mode`` only when the kernel should see those input dimensions
in a different logical order. ``mode`` is applied after the compact
layout is built; it is not a request for JAX/XLA to transpose data.
For example, a row-major JAX buffer shaped ``(expert_count, N, K)`` can be
presented to a kernel expecting logical ``(N, K, expert_count)`` with
``TensorSpec(mode=(1, 2, 0))``. No explicit ``layout`` is needed because the
physical buffer is still ordinary row-major, and the FFI call will be
constrained accordingly. Use ``layout`` only when the compact physical
stride order itself differs from the default row-major order, such as a
column-major compact buffer.
Attributes:
layout: A minor-to-major stride ordering in CuTeDSL convention. ``layout[i]``
gives the stride rank of dimension ``i``, where rank 0 means the smallest
(innermost) stride. For example, row-major order for a 3-D tensor is
``(2, 1, 0)``. If ``None``, row-major is assumed. Use
:func:`jax_to_cutlass_layout_order` to convert from JAX's major-to-minor
convention.
mode: A permutation that maps the stride-ordered dimensions to the mode
positions of the resulting ``cute.Layout``. For example, ``mode=(2, 0, 1)``
reorders an ``(M, K, L)`` layout into ``(K, L, M)`` mode order inside the
kernel. If ``None``, modes match the natural dimension order ``(0, 1, ..., N-1)``.
gives the compact physical stride rank of input dimension ``i``,
where rank 0 means the smallest (innermost) stride. For example,
row-major order for a 3-D tensor is ``(2, 1, 0)``. If ``None``,
row-major is assumed. Use :func:`jax_to_cutlass_layout_order` to
convert from JAX's major-to-minor convention. ``layout`` does not
change which logical mode a dimension represents; combine it with
``mode`` when physical order and kernel-logical order differ.
mode: A permutation applied after the compact layout is constructed. It
selects input dimensions into the mode positions seen by the kernel.
For example, ``mode=(2, 0, 1)`` presents an input shaped
``(M, K, L)`` to the kernel as logical ``(L, M, K)``. If ``None``,
modes match the natural input-dimension order ``(0, 1, ..., N-1)``.
``mode`` changes the tensor layout object seen by CuTe code but
does not materialize a transpose or change the underlying buffer.
static: If ``True``, shapes and strides are compiled as static ``constexpr``
values, which may enable additional compiler optimisations. Kernels that
do not support static shapes will raise a compile error. Must be ``False``
when any dimension is symbolic (e.g. under ``jax.export``).
ptr_assumed_align: Assumed byte alignment of the tensor's data pointer.
Overrides the default of 256 bytes. Rarely needs to change.
divisibility: Optional per-mode divisibility hints. If a single int is passed
divisibility will be applied to the leading (stride=1) dimension only.
divisibility: Optional divisibility hints for input dimensions, in the
same order as the JAX array shape and before any ``mode`` reordering.
Positive hints constrain dynamic shape values and are propagated
through compact stride construction: a stride inherits the product
of the divisibilities for dimensions with lower stride rank.
Positive explicit hints take precedence over inferred concrete
extents. If a single int is passed, it is applied to the leading
compact dimension only, where ``layout[i] == 0``.
"""
# Minor-to-major stride ordering in CuTeDSL convention (layout[i] = stride rank
# of dimension i, 0 = innermost). Defaults to row-major if None.
layout: tuple[int, ...] | None = field(metadata=dict(static=True), default=None)
# Permutation from stride-ordered dimensions to cute.Layout mode positions.
# Permutation from input dimensions to cute.Layout mode positions.
# Defaults to identity (0, 1, ..., N-1) if None.
mode: tuple[int, ...] | None = field(metadata=dict(static=True), default=None)
# If True, shapes and strides are embedded as compile-time constants.
@@ -96,7 +131,7 @@ class TensorSpec:
ptr_assumed_align: int = field(
metadata=dict(static=True), default=DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT
)
# Per-mode divisibility hints.
# Per-input-dimension divisibility hints, before mode reordering.
divisibility: tuple[int | None, ...] | int | None = field(
metadata=dict(static=True), default=None
)
@@ -128,9 +163,10 @@ def row_major_layout(shaped):
def default_tensor_mode(shaped):
"""Returns the identity mode permutation for an N-dimensional tensor.
The mode permutation maps stride-ordered dimensions to ``cute.Layout`` mode
positions. The default identity ``(0, 1, ..., N-1)`` leaves the mode order
unchanged relative to the dimension order.
The mode permutation maps JAX input dimensions to ``cute.Layout`` mode
positions after the compact layout has been constructed. The default
identity ``(0, 1, ..., N-1)`` leaves the mode order unchanged relative to
the JAX shape order.
Args:
shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence.
@@ -151,12 +187,22 @@ def default_tensor_spec(shaped) -> TensorSpec:
TensorSpec(layout=(N-1, ..., 1, 0), mode=(0, 1, ..., N-1), divisibility=(D0, D1, ... DN-1))
This is appropriate for standard row-major (C-contiguous) JAX arrays that
do not require dimension reordering inside the kernel.
do not require dimension reordering inside the kernel. The resulting JAX
CuTe tensor is treated as compact: strides are derived from shapes using the
row-major layout order.
Divisibility hints are inferred only for concrete integer dimensions.
Symbolic dimensions always produce ``None`` for their slot; pass an
explicit ``TensorSpec`` with ``divisibility`` set if you need alignment
hints for symbolic shapes.
If the JAX buffer is row-major but the kernel expects a different logical
mode order, use an explicit :class:`TensorSpec` with ``mode`` set and leave
``layout`` unset. ``cutlass_call`` still constrains the FFI buffer to
row-major layout in this case. For example, ``TensorSpec(mode=(1, 2, 0))``
maps a physical ``(L, M, K)`` row-major input to a logical ``(M, K, L)``
tensor.
Divisibility hints are inferred only for concrete integer input dimensions.
Symbolic dimensions always produce ``None`` for their slot; pass an explicit
``TensorSpec`` with ``divisibility`` set if you need alignment hints for
symbolic shapes or want a weaker explicit constraint than the concrete
extent.
Args:
shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence.
@@ -179,11 +225,12 @@ def default_tensor_spec(shaped) -> TensorSpec:
def _expand_divisibility(
divisibility, order: tuple[int, ...], ndim: int
) -> tuple[int | None, ...] | None:
"""Expand a divisibility spec to a full per-dimension tuple.
"""Expand a divisibility spec to a full per-input-dimension tuple.
A bare ``int`` is placed at the leading-dimension slot (where
``order[i] == 0``, i.e. stride == 1) and ``None`` everywhere else.
A tuple is returned unchanged. ``None`` returns ``None``.
A tuple is already in JAX input-dimension order and is returned unchanged.
``None`` returns ``None``.
"""
if divisibility is None or isinstance(divisibility, tuple):
return divisibility
@@ -268,7 +315,20 @@ def from_dlpack(array, assumed_align: int = DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNM
return _from_dlpack(array, assumed_align=assumed_align)
def _validate_permutation(name: str, perm, shape):
def _assume_divisible_int(
value: Any,
divby: int,
*,
loc: ir.Location | None = None,
ip: ir.InsertionPoint | None = None,
) -> Any:
"""Attach a divisibility assumption to an integer value without narrowing it."""
if divby <= 1:
return value
return cute.assume(IntValue(value, loc=loc, ip=ip), divby=divby, loc=loc, ip=ip)
def _validate_permutation(name: str, perm: Sequence[int], shape: Sequence[Any]) -> None:
if len(perm) != len(shape):
raise ValueError(f"{name} must be same length as shape", perm, shape)
for s in perm:
@@ -292,9 +352,11 @@ class JaxArray:
can be concrete or symbolic in the case of jax.export.
3. mem_space: The memory space of the tensor. Defaults to gmem.
4. assumed_align: The alignment of the tensor. Defaults to XLA alignment.
5. order: Specifies the order of the shape to determine strides.
6. mode: Specifies how to map ordered elements to the modes od a cute.Layout.
5. order: Specifies the compact physical stride order of the shape.
6. mode: Specifies how to map input dimensions to the logical modes seen by
the kernel after the compact layout is constructed.
7. static: If True, tensor shapes and strides are compiled statically.
8. divisibility: Optional divisibility hints in input-dimension order.
"""
def __init__(
@@ -381,6 +443,21 @@ class JaxArrayValue(JaxArray):
ip: Optional[ir.InsertionPoint] = None,
):
i32 = ir.IntegerType.get_signless(32)
# Track the divisibility available for each input dimension. Explicit
# hints win; otherwise concrete dimensions contribute their known extent.
dim_divisibility = None
if self.divisibility is not None:
dim_divisibility = []
for div_spec, static_s in zip(self.divisibility, self.shape):
if div_spec is not None and div_spec > 0:
dim_divisibility.append(div_spec)
elif isinstance(static_s, int):
dim_divisibility.append(static_s)
else:
dim_divisibility.append(1)
dim_divisibility = tuple(dim_divisibility)
pairs = sorted(zip(shape, order), key=lambda x: x[1])
# Compute strides for each element in order.
@@ -395,28 +472,29 @@ class JaxArrayValue(JaxArray):
for i in range(len(shape)):
strides_ordered.append(strides[order[i]])
if dim_divisibility is not None:
# A compact stride is the product of all dimensions with a lower
# stride order, so it inherits the product of their divisibility.
stride_divisibility = []
for dim_order in order:
divby = 1
for other_dim, other_order in enumerate(order):
if other_order < dim_order:
divby *= dim_divisibility[other_dim]
stride_divisibility.append(divby)
strides_ordered = [
_assume_divisible_int(s, divby, loc=loc, ip=ip)
for s, divby in zip(strides_ordered, stride_divisibility)
]
# Shapes are expected to be int32 so truncate to that before creating layout
shape_i32 = tuple(arith.trunci(i32, s) for s in shape)
# Apply per-mode divisibility assumptions so the compiler can exploit alignment.
if self.divisibility is not None:
assumed = []
for s32, div_spec, static_s in zip(
shape_i32, self.divisibility, self.shape
):
if isinstance(static_s, int):
# Pure static shape is known even though a dynamic shape is
# used. We can assume the exact shape here. We keep the shape
# as a dynamic value to avoid breaking code that may expect
# a dynamic value.
assumed.append(cute.assume(s32, divby=static_s))
elif div_spec is not None:
# Using a dynamic value so apply the div_spec if its provided.
assumed.append(cute.assume(s32, divby=div_spec))
else:
# No divisibility specification for this shape
assumed.append(s32)
shape_i32 = tuple(assumed)
if dim_divisibility is not None:
shape_i32 = tuple(
_assume_divisible_int(s, divby, loc=loc, ip=ip)
for s, divby in zip(shape_i32, dim_divisibility)
)
return cute.make_layout(shape_i32, stride=tuple(strides_ordered))

View File

@@ -84,6 +84,8 @@ from .tmem_allocator import (
from .layout import LayoutEnum
from .block import block_copy
from .mixed_input_helpers import (
TransformMode,
scale_tma_partition,
@@ -176,6 +178,7 @@ __all__ = [
"sm90",
"sm100",
"gemm",
"block_copy",
"ClcDynamicPersistentTileSchedulerParams",
"ClcDynamicPersistentTileScheduler",
"print_latex",

View File

@@ -612,7 +612,7 @@ def get_tmem_load_op(
def get_smem_layout_atom_ab(
major_mode: OperandMajorMode,
element_type: Type[Numeric],
smem_shape_mn_k: Tuple[int, int],
smem_shape_mn_k: cute.Tile,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
@@ -625,13 +625,16 @@ def get_smem_layout_atom_ab(
:param element_type: The element type for the SMEM tensor.
:type element_type: Type[Numeric]
:param smem_shape_mn_k: The shape of the SMEM tensor.
:type smem_shape_mn_k: Tuple[int, int]
:type smem_shape_mn_k: cute.Tile
:return: The SMEM layout atom kind
:rtype: cutlass.cute.nvgpu.tcgen05.SmemLayoutAtomKind
"""
is_k_major = major_mode == OperandMajorMode.K
major_mode_size = smem_shape_mn_k[1] if is_k_major else smem_shape_mn_k[0]
major_mode_size = (
cute.size(smem_shape_mn_k, mode=[1])
if is_k_major
else cute.size(smem_shape_mn_k, mode=[0])
)
assert major_mode_size % 8 == 0
sw128_num_contiguous_bits = 1024
sw64_num_contiguous_bits = 512
@@ -711,6 +714,7 @@ def make_smem_layout(
cute.append(smem_tile_shape, num_stages),
order=(0, 1, 2) if is_k_major else (1, 0, 2),
)
return cute.coalesce(smem_layout, target_profile=(1, 1, 1), loc=loc, ip=ip)
@@ -1956,12 +1960,35 @@ def thrfrg_SFA(
"""Thread-fragment scale factor A tensor for SM120 block-scaled MMA.
Implements the ThrFrg partitioning for scale factor A according to the
corresponding C++ code.
corresponding C++ code in cutlass/include/cute/atom/mma_traits_sm120.hpp:
SFALayout for SM120 MXF4 16x8x64 uses K=64, SM120 MXF8F6F4 16x8x32 uses
K=32; the stride pattern ``((_8,_0,_1), _16)`` is shared.
"""
assert cute.rank(sfa_tensor) >= 2
atom_shape_mnk = tiled_mma.shape_mnk
atom_sfa_layout = cute.make_layout(shape=((2, 2, 8), 64), stride=((8, 0, 1), 16))
# K-dim of the warp-MMA atom: FP4 -> 64, FP8 -> 32 (per mma_traits_sm120.hpp).
# For FP8 (atom_K=32) where mma_nsf=1, wrap K in a 2-tuple ``(atom_K, 1)``
# so the layout's K mode keeps its 2D structure and the resulting fragment
# has the same rank as the FP4 path. For FP4 (atom_K=64) the original 1D
# layout already produces a 2D K decomposition through SMEM-layout
# composition, so we keep the original shape.
atom_K = atom_shape_mnk[2]
if atom_K == 32:
atom_sfa_layout = cute.make_layout(
shape=((2, 2, 8), (atom_K, 1)),
stride=((8, 0, 1), (16, 0)),
)
elif atom_K == 64:
atom_sfa_layout = cute.make_layout(
shape=((2, 2, 8), atom_K),
stride=((8, 0, 1), 16),
)
else:
raise ValueError(
f"thrfrg_SFA: unsupported atom_K={atom_K}; SM120 block-scaled atoms "
f"use atom_K=32 (mxf8/mxf8f6f4) or atom_K=64 (mxf4/mxf4nvf4)"
)
permutation_mnk = tiled_mma.permutation_mnk
thr_layout_vmnk = tiled_mma.thr_layout_vmnk
@@ -2000,12 +2027,32 @@ def thrfrg_SFB(
"""Thread-fragment scale factor B tensor for SM120 block-scaled MMA.
Implements the ThrFrg partitioning for scale factor B according to the
corresponding C++ code.
corresponding C++ code in cutlass/include/cute/atom/mma_traits_sm120.hpp:
SFBLayout for SM120 MXF4 16x8x64 uses K=64, SM120 MXF8F6F4 16x8x32 uses
K=32; the stride pattern ``((_0,_1), _8)`` is shared.
"""
assert cute.rank(sfb_tensor) >= 2
atom_shape_mnk = tiled_mma.shape_mnk
atom_sfb_layout = cute.make_layout(shape=((4, 8), 64), stride=((0, 1), 8))
# K-dim of the warp-MMA atom: FP4 -> 64, FP8 -> 32 (per mma_traits_sm120.hpp).
# See :func:`thrfrg_SFA` for the rationale behind the FP8-only
# ``(atom_K, 1)`` wrapping.
atom_K = atom_shape_mnk[2]
if atom_K == 32:
atom_sfb_layout = cute.make_layout(
shape=((4, 8), (atom_K, 1)),
stride=((0, 1), (8, 0)),
)
elif atom_K == 64:
atom_sfb_layout = cute.make_layout(
shape=((4, 8), atom_K),
stride=((0, 1), 8),
)
else:
raise ValueError(
f"thrfrg_SFB: unsupported atom_K={atom_K}; SM120 block-scaled atoms "
f"use atom_K=32 (mxf8/mxf8f6f4) or atom_K=64 (mxf4/mxf4nvf4)"
)
permutation_mnk = tiled_mma.permutation_mnk
thr_layout_vmnk = tiled_mma.thr_layout_vmnk

View File

@@ -0,0 +1,248 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from cutlass.cutlass_dsl import dsl_user_op, CuTeDSL
from cutlass.cute.typing import Tensor
from cutlass.cute.core import make_layout, filter_zeros
from cutlass.cute.atom import TiledCopy
from cutlass.cute.algorithm import copy
from cutlass.cute.nvgpu import tcgen05
from cutlass.cute.nvgpu.cpasync.copy import (
TmaCopyOp,
CopyBulkTensorTileG2SOp,
CopyBulkTensorTileG2SMulticastOp,
)
from cutlass.cute.nvgpu.cpasync.helpers import tma_partition
from cutlass.cute.nvgpu.tcgen05.copy import _S2TCopyBase
from typing import Any, Optional
from cutlass._mlir import ir
def _check_required_args(
required_args: list[str], kwargs: dict, condition: bool = True
) -> None:
if not condition:
return
for arg in required_args:
if arg not in kwargs:
raise ValueError(f"Argument {arg} is required.")
def _tma_copy_impl(
tiled_copy: TiledCopy,
src: Tensor,
dst: Tensor,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
**kwargs: Any,
) -> None:
"""Internal implementation for TMA-based block-level copy."""
#
# Handle tma_multicast argument
#
if "tma_multicast" in kwargs:
if not isinstance(
tiled_copy.op,
(
CopyBulkTensorTileG2SOp,
),
):
raise ValueError(
"block_copy with tma_multicast expects a non-multicast G2S TMA copy atom "
"(CopyBulkTensorTileG2SOp) for compiler-driven multicast"
)
# Mark as coming from block API
kwargs["tma_multicast"]["from_block_api"] = True
#
# Check if required arguments are provided
#
is_bar_ptr_required = isinstance(
tiled_copy.op,
(
CopyBulkTensorTileG2SOp,
CopyBulkTensorTileG2SMulticastOp,
),
)
_check_required_args(["tma_bar_ptr"], kwargs, is_bar_ptr_required)
#
# TMA bulk tensor copies: partition via tma_partition
#
is_g2s = isinstance(
tiled_copy.op,
(
CopyBulkTensorTileG2SOp,
),
)
stensor = dst if is_g2s else src
gtensor = src if is_g2s else dst
cta_coord = 0
cta_layout = make_layout(1, loc=loc, ip=ip)
s_ptn, g_ptn = tma_partition(
tiled_copy, cta_coord, cta_layout, stensor, gtensor, loc=loc, ip=ip
)
s_ptn = filter_zeros(s_ptn)
g_ptn = filter_zeros(g_ptn)
src_arg = g_ptn if is_g2s else s_ptn
dst_arg = s_ptn if is_g2s else g_ptn
return copy(tiled_copy, src_arg, dst_arg, loc=loc, ip=ip, **kwargs)
def _utccp_copy_impl(
tiled_copy: TiledCopy,
src: Tensor,
dst: Tensor,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
**kwargs: Any,
) -> None:
"""Internal implementation for S2T (SMEM to TMEM) copy operations.
This function abstracts the S2T copy pattern which involves:
1. Filtering zeros from src (smem) and dst (tmem) tensors
2. Creating a tiled copy using make_s2t_copy
3. Partitioning source and destination
4. Getting the SMEM descriptor tensor
5. Executing the copy
:param tiled_copy: The tiled copy for S2T operations.
:type tiled_copy: TiledCopy
:param src: The source tensor in shared memory.
:type src: Tensor
:param dst: The destination tensor in TMEM.
:type dst: Tensor
"""
# Filter zeros from src (smem) and dst (tmem) tensors
src_compact = filter_zeros(src)
dst_compact = filter_zeros(dst)
# S2T has a single thread slice; election handled automatically in lowering
thr_copy = tiled_copy.get_slice(0)
# Partition source and destination
src_partitioned = thr_copy.partition_S(src_compact, loc=loc, ip=ip)
dst_partitioned = thr_copy.partition_D(dst_compact, loc=loc, ip=ip)
# Get SMEM descriptor tensor for the source
smem_desc_tensor = tcgen05.get_s2t_smem_desc_tensor(
tiled_copy, src_partitioned, loc=loc, ip=ip
)
# Execute the copy
return copy(tiled_copy, smem_desc_tensor, dst_partitioned, loc=loc, ip=ip, **kwargs)
@dsl_user_op
@CuTeDSL.jit
def block_copy(
tiled_copy: TiledCopy,
src: Tensor,
dst: Tensor,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
**kwargs: Any,
) -> None:
"""Performs a block-level copy operation.
This function adds an abstraction layer over the `cute.copy` usage model by
allowing operands with layouts shaped like tiles to be passed directly. This
removes the need to manually partition. The API is designed to support multiple
copy kinds; currently TMA-based copies and S2T (SMEM to TMEM) copies are supported.
**TMA copy requirements**:
When using TMA-based tiled copies, the ``src`` and ``dst`` tensors must have
their first mode representing the TMATile, i.e. tensors shaped as ``(TMATile, Rest...)``.
For a rank-2 tensor with logical layout (e.g., ``(TILE_M, TILE_N)``), call
``group_modes(tensor, 0, 2)`` before passing it to this function.
**TMA multicast support**:
For TMA-based copies that enable compiler-driven multicast in a 2D cluster, pass the
``tma_multicast`` argument as a dict with the following keys:
- ``cluster_shape``: a tuple of 2 integers ``(cluster_m, cluster_n)``
representing the **2D cluster shape**.
- ``multicast_dim``: either ``"M"`` or ``"N"`` indicating which
cluster dimension the multicast happens along.
- ``use_2cta_mma_inst`` (optional): a ``bool`` indicating whether to
use 2CTA MMA instructions when the loaded data is consumed by MMA.
Defaults to ``False`` when omitted.
**S2T (SMEM to TMEM) copy**:
When using S2T copy operations (e.g., ``tcgen05.Cp4x32x128bOp``), the function
automatically handles the filtering, partitioning, and SMEM descriptor creation.
Pass a copy atom created with ``cute.make_copy_atom(tcgen05.Cp*Op(...), dtype)``
along with source (SMEM) and destination (TMEM) tensors.
Examples:
.. code-block:: python
# 1) TMA load without compiler-driven multicast
# Note: group_modes is called to make the first mode TMATile
block_copy(tma_atom_a, group_modes(tCgA_, 0, 2), group_modes(tCsA_, 0, 2),
tma_bar_ptr=tma_bar_ptr)
# 2) TMA load with compiler-driven multicast along M in a (4,2) cluster
block_copy(
tma_atom_a,
group_modes(tCgA_, 0, 2),
group_modes(tCsA_, 0, 2),
tma_multicast={
"cluster_shape": (4, 2),
"multicast_dim": "M",
"use_2cta_mma_inst": True,
},
tma_bar_ptr=tma_bar_ptr,
)
# 3) TMA store
# Note that `tma_bar_ptr` and CTA params (`cta_coord` and `cta_layout`)
# are not needed for TMA store
block_copy(tma_atom_c, group_modes(tCsC_, 0, 2), group_modes(tCgC_, 0, 2))
# 4) S2T copy (SMEM to TMEM)
copy_atom_s2t = cute.make_copy_atom(
tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), sf_dtype
)
block_copy(copy_atom_s2t, tCsSF, tCtSF)
:param tiled_copy: The tiled_copy or copy_atom of the current copy operation.
:type tiled_copy: TiledCopy
:param src: The source tensor.
:type src: Tensor
:param dst: The destination tensor.
:type dst: Tensor
:param tma_multicast: Optional dict for TMA multicast configuration with keys
``cluster_shape``, ``multicast_dim``, and optionally
``use_2cta_mma_inst``.
:type tma_multicast: dict, optional
"""
import cutlass # local import to avoid circular import at module load time
if cutlass.const_expr(isinstance(tiled_copy.op, TmaCopyOp)):
return _tma_copy_impl(tiled_copy, src, dst, loc=loc, ip=ip, **kwargs)
elif cutlass.const_expr(isinstance(tiled_copy.op, _S2TCopyBase)):
return _utccp_copy_impl(tiled_copy, src, dst, loc=loc, ip=ip, **kwargs)
else:
raise NotImplementedError(
f"Copy op {type(tiled_copy.op).__name__} is not supported yet."
)

View File

@@ -1,3 +1,3 @@
# Use `pip install -r requirements-cu13.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl[cu13]==4.4.2
nvidia-cutlass-dsl[cu13]==4.5.1

View File

@@ -1,3 +1,3 @@
# Use `pip install -r requirements.txt` with the present file to install a
# wheel consistent with the present state of the github repository
nvidia-cutlass-dsl==4.4.2
nvidia-cutlass-dsl==4.5.1

View File

@@ -133,7 +133,7 @@ def get_option_registry():
this._option_registry = OptionRegistry(device_cc())
return this._option_registry
this.__version__ = '4.5.0'
this.__version__ = '4.5.1'
from cutlass_cppgen.backend import create_memory_pool
from cutlass_cppgen.emit.pytorch import pytorch

View File

@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
setup(
name='cutlass_cppgen',
version='4.5.0',
version='4.5.1',
description='CUTLASS Pythonic Interface',
package_dir={'': '.'},
packages=[

View File

@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='cutlass_library',
version='4.5.0',
version='4.5.1',
description='CUTLASS library generation scripts',
packages=['cutlass_library']
)

View File

@@ -36,7 +36,7 @@ from setuptools import setup
def perform_setup():
setup(
name='pycute',
version='4.5.0',
version='4.5.1',
description='Python implementation of CuTe',
packages=['pycute'],
)