mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-20 12:59:01 +00:00
v4.5.1 update. (#3238)
This commit is contained in:
@@ -720,7 +720,9 @@ class DSLPreprocessor(ast.NodeTransformer):
|
||||
offset = len(all_args) - len(func_ast.args.defaults)
|
||||
for i, default_node in enumerate(func_ast.args.defaults):
|
||||
ast_defaults[all_args[offset + i].arg] = default_node
|
||||
for kwarg, kw_default in zip(func_ast.args.kwonlyargs, func_ast.args.kw_defaults):
|
||||
for kwarg, kw_default in zip(
|
||||
func_ast.args.kwonlyargs, func_ast.args.kw_defaults
|
||||
):
|
||||
if kw_default is not None:
|
||||
ast_defaults[kwarg.arg] = kw_default
|
||||
for param_name, default_val in params_with_defaults.items():
|
||||
|
||||
@@ -1865,7 +1865,7 @@ class BaseDSL(metaclass=DSLSingletonMeta):
|
||||
sources = set(x.value for x in link_libraries_attributes)
|
||||
link_libraries = (
|
||||
link_libraries
|
||||
+ ("," if len(link_libraries) > 0 else "")
|
||||
+ ("," if link_libraries and len(sources) > 0 else "")
|
||||
+ ",".join(sources)
|
||||
)
|
||||
self.compile_options.options[LinkLibraries] = LinkLibraries(
|
||||
|
||||
@@ -88,6 +88,11 @@ def _get_gpu_arch_info(major: int, minor: int) -> tuple[str, str, list[str]]:
|
||||
"sm_120a",
|
||||
["sm_120a"],
|
||||
), # RTX PRO 6000 / RTX 50 Series
|
||||
(12, 1): (
|
||||
"Blackwell",
|
||||
"sm_121a",
|
||||
["sm_121a"],
|
||||
), # DGX Spark
|
||||
}
|
||||
return gpu_arch_map.get(
|
||||
(major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"])
|
||||
|
||||
@@ -3330,7 +3330,12 @@ def filter_zeros(
|
||||
if not isinstance(input, (Layout, Tensor)):
|
||||
raise TypeError(f"Expected layout or tensor as input, but got {type(input)=}")
|
||||
if isinstance(input, Tensor):
|
||||
input = input.value
|
||||
return _op_wrapper(
|
||||
partial(_cute_ir.filter_zeros, target_profile=target_profile),
|
||||
input,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return _cute_ir.filter_zeros(input, target_profile=target_profile, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@@ -3388,7 +3393,7 @@ def filter(
|
||||
input.inner, input.offset, filter(input.outer, loc=loc, ip=ip)
|
||||
)
|
||||
elif isinstance(input, _Tensor):
|
||||
return _cute_ir.filter(input.value, loc=loc, ip=ip)
|
||||
return _op_wrapper(_cute_ir.filter, input, loc=loc, ip=ip)
|
||||
else:
|
||||
return _cute_ir.filter(input, loc=loc, ip=ip)
|
||||
|
||||
@@ -5020,10 +5025,9 @@ def local_partition(
|
||||
raise NotImplementedError(
|
||||
f"Index value should be 32-bit or smaller integer type, but got {index_val.type}"
|
||||
)
|
||||
return _cute_ir.local_partition(
|
||||
input=target.value,
|
||||
tiler=dice(tiler, proj),
|
||||
index=index_val,
|
||||
return _op_wrapper(
|
||||
partial(_cute_ir.local_partition, tiler=dice(tiler, proj), index=index_val),
|
||||
target,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
@@ -5114,11 +5118,9 @@ def local_tile(
|
||||
proj_val = _pack_coord(proj, loc=loc, ip=ip)
|
||||
proj = proj_val.type.attribute
|
||||
|
||||
return _cute_ir.local_tile(
|
||||
input=input.value,
|
||||
tile=tiler_val,
|
||||
coord=coord_val,
|
||||
proj=proj,
|
||||
return _op_wrapper(
|
||||
partial(_cute_ir.local_tile, tile=tiler_val, coord=coord_val, proj=proj),
|
||||
input,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,9 @@ __all__ = [
|
||||
"MmaFP8Op",
|
||||
"MmaMXF4Op",
|
||||
"MmaMXF4NVF4Op",
|
||||
"MmaMXF8Op",
|
||||
"MmaMXF8F6F4Op",
|
||||
"MXF8F6F4_SUPPORTED_PAIRS",
|
||||
# copy.py
|
||||
"LdMatrix8x8x16bOp",
|
||||
"LdMatrix16x8x8bOp",
|
||||
|
||||
@@ -224,7 +224,9 @@ class MmaSM120BlockScaledOp(MmaOp):
|
||||
|
||||
admissible_archs = [
|
||||
Arch.sm_120a,
|
||||
Arch.sm_120f,
|
||||
Arch.sm_121a,
|
||||
Arch.sm_121f,
|
||||
]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -239,29 +241,44 @@ class MmaSM120BlockScaledOp(MmaOp):
|
||||
"CUTE_DSL_ARCH set to sm_120a or sm_121a",
|
||||
suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
|
||||
)
|
||||
if self.ab_dtype != Float4E2M1FN:
|
||||
# (ab_dtype, shape_mnk) consistency: FP4 uses (16,8,64); FP8 uses (16,8,32).
|
||||
if self.ab_dtype == Float4E2M1FN:
|
||||
if self.shape_mnk != (16, 8, 64):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'shape_mnk' Op parameter to be (16,8,64) for Float4E2M1FN",
|
||||
)
|
||||
elif self.ab_dtype in (Float8E4M3FN, Float8E5M2):
|
||||
if self.shape_mnk != (16, 8, 32):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'shape_mnk' Op parameter to be (16,8,32) for Float8E4M3FN/Float8E5M2",
|
||||
)
|
||||
else:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'ab_dtype' Op parameter to be Float4E2M1FN",
|
||||
"expects the 'ab_dtype' Op parameter to be Float4E2M1FN, Float8E4M3FN, or Float8E5M2",
|
||||
)
|
||||
if self.acc_dtype != Float32:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'acc_dtype' Op parameter to be Float32",
|
||||
)
|
||||
if self.shape_mnk != (16, 8, 64):
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'shape_mnk' Op parameter to be (16,8,64)",
|
||||
)
|
||||
|
||||
if self.sf_vec_size == 16:
|
||||
# vec_size=16 is only valid for FP4 (NVFP4) with E4M3 scale.
|
||||
if self.ab_dtype != Float4E2M1FN:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'sf_vec_size' Op parameter to be 32 for Float8E4M3FN/Float8E5M2",
|
||||
)
|
||||
if self.sf_type != Float8E4M3FN:
|
||||
raise OpError(
|
||||
self,
|
||||
"expects the 'sf_type' Op parameter to be Float8E4M3FN",
|
||||
)
|
||||
elif self.sf_vec_size == 32:
|
||||
# vec_size=32 path uses UE8M0 scale for both FP4 (MXF4) and FP8 (MXF8).
|
||||
if self.sf_type != Float8E8M0FNU:
|
||||
raise OpError(
|
||||
self,
|
||||
@@ -275,7 +292,7 @@ class MmaSM120BlockScaledOp(MmaOp):
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"warp-level MXF4/MXF4NVF4 MMA Operation"
|
||||
"warp-level MXF4/MXF4NVF4/MXF8 MMA Operation"
|
||||
+ f"\n A/B data type = {self.ab_dtype}"
|
||||
+ f"\n Accumulator data type = {self.acc_dtype}"
|
||||
+ f"\n Instruction shape MNK = {self.shape_mnk}"
|
||||
@@ -474,3 +491,214 @@ class MmaMXF4NVF4Op(MmaSM120BlockScaledOp):
|
||||
|
||||
class MmaMXF4NVF4Trait(MmaBlockScaledTrait):
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# MXF8 MMA
|
||||
#
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MmaMXF8Op(MmaSM120BlockScaledOp):
    """
    Warp-level MMA Operation for the MXF8 block-scaled kind.

    Reference: `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma>`__.
    The instructions covered are those whose input operands carry the
    ``.e4m3`` / ``.e5m2`` qualifiers, with:

        .kind           = {.kind::mxf8};
        .scale_vec_size = {.scale_vec::1X};
        .stype          = {.ue8m0};

    The constructor pins the instruction shape to ``(16, 8, 32)`` and passes
    ``32`` as the last base-class argument (presumably the scale-factor
    vector size -- confirm against ``MmaSM120BlockScaledOp``).
    """

    descriptive_name = "warp-level MXF8 MMA Operation"

    def __init__(
        self,
        ab_dtype: Type[Numeric],
        acc_dtype: Type[Numeric],
        sf_type: Type[Numeric],
    ) -> None:
        # Only the dtypes are caller-selectable; shape and vector size are fixed.
        super().__init__(ab_dtype, acc_dtype, (16, 8, 32), sf_type, 32)

    def _make_trait(
        self,
        *,
        loc: Optional[ir.Location] = None,
        ip: Optional[ir.InsertionPoint] = None,
        **kwargs: Any,
    ) -> "MmaMXF8Trait":
        """Build the SM120 block-scaled atom type and wrap it in a trait."""
        packed_shape = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        # A and B share one dtype for this op, so it is passed twice.
        operand_type = self.ab_dtype.mlir_type
        atom_type = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get(
            packed_shape.type.attribute,
            32,
            False,
            operand_type,
            operand_type,
            self.acc_dtype.mlir_type,
            self.sf_type.mlir_type,
        )
        atom = make_atom(atom_type, loc=loc, ip=ip)
        return MmaMXF8Trait(atom)
|
||||
|
||||
|
||||
class MmaMXF8Trait(MmaBlockScaledTrait):
    """Trait returned by ``MmaMXF8Op._make_trait``; adds nothing to its base."""
|
||||
|
||||
|
||||
#
|
||||
# MXF8F6F4 mixed-precision MMA (independent A/B dtypes)
|
||||
#
|
||||
|
||||
|
||||
# Allow-list of (a_dtype, b_dtype) combinations accepted by MmaMXF8F6F4Op:
# FP4 paired with either FP8 flavour, in both operand orders.
MXF8F6F4_SUPPORTED_PAIRS = frozenset(
    (
        (Float4E2M1FN, Float8E4M3FN),
        (Float4E2M1FN, Float8E5M2),
        (Float8E4M3FN, Float4E2M1FN),
        (Float8E5M2, Float4E2M1FN),
    )
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MmaMXF8F6F4Op(MmaOp):
    """
    SM120 warp-level block-scaled MMA Operation with mixed A/B precision (MXF8F6F4).

    Targets the PTX instructions that take independent ``.<a_type>.<b_type>``
    qualifiers, where the pair is one of e2m1.e4m3, e2m1.e5m2, e4m3.e2m1 or
    e5m2.e2m1:

        .kind           = {.kind::mxf8f6f4};
        .scale_vec_size = {.scale_vec::1X};
        .stype          = {.ue8m0};

    The A and B operand dtypes are chosen independently. Same-dtype FP4/FP4
    and FP8/FP8 combinations are rejected here and stay on ``MmaMXF4Op`` /
    ``MmaMXF4NVF4Op`` / ``MmaMXF8Op``. Same-width mixed-FP8 (E4M3 + E5M2) and
    FP6 mixed pairs are not supported.
    """

    a_dtype: Type[Numeric]
    b_dtype: Type[Numeric]
    acc_dtype: Type[Numeric]
    sf_type: Type[Numeric]

    descriptive_name = "warp-level MXF8F6F4 mixed-precision MMA Operation"

    # Fixed instruction parameters (plain class attributes, not dataclass fields).
    shape_mnk = (16, 8, 32)
    sf_vec_size = 32
    use_sf_layout_TV = False

    admissible_archs = [
        Arch.sm_120a,
        Arch.sm_121a,
    ]

    def __post_init__(self) -> None:
        """Validate arch, accumulator/scale types and the (A, B) dtype pair.

        Raises:
            OpError: on the first violated constraint, checked in the order
                arch, ``acc_dtype``, ``sf_type``, same-dtype pair, mixed-FP8
                pair, allow-list membership.
        """
        current_arch = BaseDSL._get_dsl().get_arch_enum()
        if current_arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {current_arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )

        if self.acc_dtype != Float32:
            raise OpError(
                self,
                "expects the 'acc_dtype' Op parameter to be Float32",
            )

        if self.sf_type != Float8E8M0FNU:
            raise OpError(
                self,
                "expects the 'sf_type' Op parameter to be Float8E8M0FNU",
            )

        fp8_dtypes = (Float8E4M3FN, Float8E5M2)

        # Identical operand dtypes belong to the dedicated single-dtype ops.
        operands_match = self.a_dtype == self.b_dtype
        if operands_match and self.a_dtype == Float4E2M1FN:
            raise OpError(
                self,
                "same-dtype Float4E2M1FN/Float4E2M1FN is not supported by MmaMXF8F6F4Op; "
                "use MmaMXF4Op (sf_vec_size=32) or MmaMXF4NVF4Op (sf_vec_size=16) instead",
            )
        if operands_match and self.a_dtype in fp8_dtypes:
            raise OpError(
                self,
                "same-dtype FP8/FP8 is not supported by MmaMXF8F6F4Op; "
                "use MmaMXF8Op instead",
            )

        # Two different FP8 flavours (same width) are rejected with a dedicated message.
        if self.a_dtype in fp8_dtypes and self.b_dtype in fp8_dtypes:
            raise OpError(
                self,
                "same-width mixed-FP8 (Float8E4M3FN + Float8E5M2) is not supported; "
                "supported MXF8F6F4 pairs are (Float4E2M1FN x Float8E4M3FN/Float8E5M2) "
                "and the reverse",
            )

        # Anything left (FP6, other dtypes) must appear in the allow-list.
        dtype_pair = (self.a_dtype, self.b_dtype)
        if dtype_pair not in MXF8F6F4_SUPPORTED_PAIRS:
            raise OpError(
                self,
                f"unsupported (a_dtype, b_dtype) = ({self.a_dtype}, {self.b_dtype}) "
                "for MmaMXF8F6F4Op; supported pairs are "
                f"{sorted(repr(p) for p in MXF8F6F4_SUPPORTED_PAIRS)}. "
                "FP6 mixed pairs are not supported.",
            )

    def __str__(self) -> str:
        """Return a multi-line, human-readable summary of the op parameters."""
        summary_lines = [
            "warp-level MXF8F6F4 mixed-precision MMA Operation",
            f" A data type = {self.a_dtype}",
            f" B data type = {self.b_dtype}",
            f" Accumulator data type = {self.acc_dtype}",
            f" Instruction shape MNK = {self.shape_mnk}",
            f" Vector size = {self.sf_vec_size}",
            f" SF data type = {self.sf_type}",
        ]
        return "\n".join(summary_lines)

    def _verify_fragment_A(
        self,
        input: Tensor,
        *,
        loc: Optional[ir.Location] = None,
        ip: Optional[ir.InsertionPoint] = None,
    ) -> None:
        """Accept any A fragment; no verification is performed."""
        return None

    def _verify_fragment_B(
        self,
        input: Tensor,
        *,
        loc: Optional[ir.Location] = None,
        ip: Optional[ir.InsertionPoint] = None,
    ) -> None:
        """Accept any B fragment; no verification is performed."""
        return None

    def _make_trait(
        self,
        *,
        loc: Optional[ir.Location] = None,
        ip: Optional[ir.InsertionPoint] = None,
        **kwargs: Any,
    ) -> "MmaMXF8F6F4Trait":
        """Build the SM120 block-scaled atom type and wrap it in a trait."""
        packed_shape = _pack_shape(self.shape_mnk, loc=loc, ip=ip)
        atom_type = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get(
            packed_shape.type.attribute,
            self.sf_vec_size,
            self.use_sf_layout_TV,
            self.a_dtype.mlir_type,
            self.b_dtype.mlir_type,
            self.acc_dtype.mlir_type,
            self.sf_type.mlir_type,
        )
        atom = make_atom(atom_type, loc=loc, ip=ip)
        return MmaMXF8F6F4Trait(atom)
|
||||
|
||||
|
||||
class MmaMXF8F6F4Trait(MmaBlockScaledTrait):
    """Trait returned by ``MmaMXF8F6F4Op._make_trait``; adds nothing to its base."""
|
||||
|
||||
@@ -21,7 +21,11 @@ from jax._src.interpreters import batching
|
||||
|
||||
|
||||
from .compile import get_or_compile_kernel, build_function_spec
|
||||
from .types import cutlass_to_jax_layout_order, default_tensor_spec, TensorSpec
|
||||
from .types import (
|
||||
cutlass_to_jax_layout_order,
|
||||
default_tensor_spec,
|
||||
TensorSpec,
|
||||
)
|
||||
from .ffi import get_cutlass_call_ffi_name, is_ffi_registered, register_ffi
|
||||
|
||||
|
||||
@@ -77,8 +81,10 @@ def cutlass_call(
|
||||
objects with ``.shape`` and ``.dtype`` attributes) describing each
|
||||
output buffer.
|
||||
input_spec: A :class:`TensorSpec` or list thereof providing
|
||||
layout/mode/divisibility hints for input tensors. ``None`` infers
|
||||
defaults from each array.
|
||||
layout/mode/divisibility hints for input tensors. ``None`` infers
|
||||
defaults from each array. A ``TensorSpec`` with ``layout=None`` uses
|
||||
and constrains row-major physical layout; use ``mode`` to remap
|
||||
physical dimensions to the kernel's logical modes.
|
||||
output_spec: Same as *input_spec* but applied to output tensors.
|
||||
input_output_aliases: ``{input_index: output_index}`` mapping that
|
||||
allows an input buffer to alias an output, avoiding an extra copy.
|
||||
@@ -308,7 +314,9 @@ def cutlass_call_inner_p_impl(
|
||||
|
||||
call_name = get_cutlass_call_ffi_name(allow_cuda_graph)
|
||||
|
||||
# Convert layout from CuTeDSL to JAX order as ffi_call expects this.
|
||||
# Convert explicit layout constraints from CuTeDSL to JAX order. ``None`` is
|
||||
# passed through intentionally: jax.ffi.ffi_call treats it as default
|
||||
# row-major layout.
|
||||
input_layouts = [cutlass_to_jax_layout_order(s.layout) for s in input_spec_flat]
|
||||
output_layouts = [cutlass_to_jax_layout_order(s.layout) for s in output_spec_flat]
|
||||
|
||||
|
||||
@@ -15,12 +15,18 @@ import jax.numpy as jnp
|
||||
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cutlass_dsl import dsl_user_op
|
||||
from typing import Optional
|
||||
from typing import Optional, Sequence
|
||||
from cutlass._mlir import ir
|
||||
|
||||
|
||||
def reorder_modes(src: str, target: str) -> tuple[int, ...]:
|
||||
"""Computes the mode given a source and target order."""
|
||||
def reorder_modes(src: Sequence[str], target: Sequence[str]) -> tuple[int, ...]:
|
||||
"""Compute a ``TensorSpec.mode`` from physical input order to kernel order.
|
||||
|
||||
``src`` names the JAX array's physical dimension order. ``target`` names the
|
||||
logical mode order that the CuTe kernel expects. The returned tuple can be
|
||||
passed as ``TensorSpec(mode=...)`` while leaving ``layout`` at its default
|
||||
row-major value when the JAX buffer is physically row-major.
|
||||
"""
|
||||
src = tuple(src)
|
||||
target = tuple(target)
|
||||
src_map = {}
|
||||
@@ -29,52 +35,64 @@ def reorder_modes(src: str, target: str) -> tuple[int, ...]:
|
||||
return tuple([src_map[d] for d in target])
|
||||
|
||||
|
||||
def gemm_a_major(d: str) -> str:
    """Return the physical JAX dimension order of A for major mode ``d``.

    ``d`` must be ``"k"`` or ``"m"``; any other value raises ``KeyError``.
    The result names the physical order, not the kernel's canonical logical
    order; :func:`gemm_a_mode` maps it to the kernel's logical ``mkl``.
    """
    physical_orders = {"k": "lmk", "m": "lkm"}
    return physical_orders[d]
|
||||
|
||||
|
||||
def gemm_a_mode(d: str) -> tuple[int, ...]:
    """Return the ``TensorSpec.mode`` for A: physical order for ``d`` mapped to logical ``mkl``."""
    physical_order = gemm_a_major(d)
    return reorder_modes(physical_order, "mkl")
|
||||
|
||||
|
||||
def gemm_b_major(d: str) -> str:
    """Return the physical JAX dimension order of B for major mode ``d``.

    ``d`` must be ``"k"`` or ``"n"``; any other value raises ``KeyError``.
    The result names the physical order, not the kernel's canonical logical
    order; :func:`gemm_b_mode` maps it to the kernel's logical ``nkl``.
    """
    physical_orders = {"k": "lnk", "n": "lkn"}
    return physical_orders[d]
|
||||
|
||||
|
||||
def gemm_b_mode(d: str) -> tuple[int, ...]:
    """Return the ``TensorSpec.mode`` for B: physical order for ``d`` mapped to logical ``nkl``."""
    physical_order = gemm_b_major(d)
    return reorder_modes(physical_order, "nkl")
|
||||
|
||||
|
||||
def gemm_c_major(d: str) -> str:
    """Return the physical JAX dimension order of C/D for major mode ``d``.

    ``d`` must be ``"n"`` or ``"m"``; any other value raises ``KeyError``.
    The result names the physical order, not the kernel's canonical logical
    order; :func:`gemm_c_mode` maps it to the kernel's logical ``mnl``.
    """
    physical_orders = {"n": "lmn", "m": "lnm"}
    return physical_orders[d]
|
||||
|
||||
|
||||
def gemm_c_mode(d: str) -> tuple[int, ...]:
    """Return the ``TensorSpec.mode`` for C/D: physical order for ``d`` mapped to logical ``mnl``."""
    physical_order = gemm_c_major(d)
    return reorder_modes(physical_order, "mnl")
|
||||
|
||||
|
||||
def gemm_a_shape(l: int, m: int, k: int, major: str) -> tuple[int, ...]:
    """Return the physical row-major JAX shape of A for the given major mode.

    ``major="k"`` yields ``(l, m, k)``; ``major="m"`` yields ``(l, k, m)``.
    """
    assert major in ("k", "m")
    if major == "k":
        return (l, m, k)
    return (l, k, m)
|
||||
|
||||
|
||||
def gemm_b_shape(l: int, n: int, k: int, major: str) -> tuple[int, ...]:
    """Return the physical row-major JAX shape of B for the given major mode.

    ``major="k"`` yields ``(l, n, k)``; ``major="n"`` yields ``(l, k, n)``.
    """
    assert major in ("k", "n")
    if major == "k":
        return (l, n, k)
    return (l, k, n)
|
||||
|
||||
|
||||
def gemm_c_shape(l: int, m: int, n: int, major: str) -> tuple[int, ...]:
    """Return the physical row-major JAX shape of C/D for the given major mode.

    ``major="n"`` yields ``(l, m, n)``; ``major="m"`` yields ``(l, n, m)``.
    """
    assert major in ("m", "n")
    if major == "n":
        return (l, m, n)
    return (l, n, m)
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import Optional, Sequence
|
||||
from typing import Any, Optional, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import jax.numpy as jnp
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.core import IntValue
|
||||
from cutlass.cute.runtime import from_dlpack as _from_dlpack
|
||||
from cutlass.cute import AddressSpace
|
||||
from cutlass._mlir import ir
|
||||
@@ -58,35 +59,69 @@ DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT = 256
|
||||
class TensorSpec:
|
||||
"""Specifies the layout and metadata for a JAX array passed to a CuTe kernel.
|
||||
|
||||
TensorSpec controls how a JAX array's dimensions are mapped to a cute.Tensor
|
||||
during jit lowering, including stride ordering, mode permutation, and whether
|
||||
shapes/strides are compiled as static constants.
|
||||
TensorSpec controls how a JAX array's input dimensions are mapped to a
|
||||
``cute.Tensor`` during jit lowering, including compact stride ordering,
|
||||
mode permutation, and whether shapes/strides are compiled as static
|
||||
constants. The JAX bridge models tensors as compact layouts: runtime
|
||||
strides are derived from runtime shapes using ``layout`` order rather than
|
||||
loaded from a strided view descriptor.
|
||||
|
||||
A useful way to choose a spec is to separate physical storage from logical
|
||||
kernel modes:
|
||||
|
||||
1. First choose the public JAX array shape and its compact physical memory
|
||||
order. If the buffer is a standard row-major JAX array, leave
|
||||
``layout=None``. ``cutlass_call`` will constrain the FFI operand/result
|
||||
to row-major physical layout, matching the CuTe tensor strides that are
|
||||
built from the default.
|
||||
2. Then use ``mode`` only when the kernel should see those input dimensions
|
||||
in a different logical order. ``mode`` is applied after the compact
|
||||
layout is built; it is not a request for JAX/XLA to transpose data.
|
||||
|
||||
For example, a row-major JAX buffer shaped ``(expert_count, N, K)`` can be
|
||||
presented to a kernel expecting logical ``(N, K, expert_count)`` with
|
||||
``TensorSpec(mode=(1, 2, 0))``. No explicit ``layout`` is needed because the
|
||||
physical buffer is still ordinary row-major, and the FFI call will be
|
||||
constrained accordingly. Use ``layout`` only when the compact physical
|
||||
stride order itself differs from the default row-major order, such as a
|
||||
column-major compact buffer.
|
||||
|
||||
Attributes:
|
||||
layout: A minor-to-major stride ordering in CuTeDSL convention. ``layout[i]``
|
||||
gives the stride rank of dimension ``i``, where rank 0 means the smallest
|
||||
(innermost) stride. For example, row-major order for a 3-D tensor is
|
||||
``(2, 1, 0)``. If ``None``, row-major is assumed. Use
|
||||
:func:`jax_to_cutlass_layout_order` to convert from JAX's major-to-minor
|
||||
convention.
|
||||
mode: A permutation that maps the stride-ordered dimensions to the mode
|
||||
positions of the resulting ``cute.Layout``. For example, ``mode=(2, 0, 1)``
|
||||
reorders an ``(M, K, L)`` layout into ``(K, L, M)`` mode order inside the
|
||||
kernel. If ``None``, modes match the natural dimension order ``(0, 1, ..., N-1)``.
|
||||
gives the compact physical stride rank of input dimension ``i``,
|
||||
where rank 0 means the smallest (innermost) stride. For example,
|
||||
row-major order for a 3-D tensor is ``(2, 1, 0)``. If ``None``,
|
||||
row-major is assumed. Use :func:`jax_to_cutlass_layout_order` to
|
||||
convert from JAX's major-to-minor convention. ``layout`` does not
|
||||
change which logical mode a dimension represents; combine it with
|
||||
``mode`` when physical order and kernel-logical order differ.
|
||||
mode: A permutation applied after the compact layout is constructed. It
|
||||
selects input dimensions into the mode positions seen by the kernel.
|
||||
For example, ``mode=(2, 0, 1)`` presents an input shaped
|
||||
``(M, K, L)`` to the kernel as logical ``(L, M, K)``. If ``None``,
|
||||
modes match the natural input-dimension order ``(0, 1, ..., N-1)``.
|
||||
``mode`` changes the tensor layout object seen by CuTe code but
|
||||
does not materialize a transpose or change the underlying buffer.
|
||||
static: If ``True``, shapes and strides are compiled as static ``constexpr``
|
||||
values, which may enable additional compiler optimisations. Kernels that
|
||||
do not support static shapes will raise a compile error. Must be ``False``
|
||||
when any dimension is symbolic (e.g. under ``jax.export``).
|
||||
ptr_assumed_align: Assumed byte alignment of the tensor's data pointer.
|
||||
Overrides the default of 256 bytes. Rarely needs to change.
|
||||
divisibility: Optional per-mode divisibility hints. If a single int is passed
|
||||
divisibility will be applied to the leading (stride=1) dimension only.
|
||||
divisibility: Optional divisibility hints for input dimensions, in the
|
||||
same order as the JAX array shape and before any ``mode`` reordering.
|
||||
Positive hints constrain dynamic shape values and are propagated
|
||||
through compact stride construction: a stride inherits the product
|
||||
of the divisibilities for dimensions with lower stride rank.
|
||||
Positive explicit hints take precedence over inferred concrete
|
||||
extents. If a single int is passed, it is applied to the leading
|
||||
compact dimension only, where ``layout[i] == 0``.
|
||||
"""
|
||||
|
||||
# Minor-to-major stride ordering in CuTeDSL convention (layout[i] = stride rank
|
||||
# of dimension i, 0 = innermost). Defaults to row-major if None.
|
||||
layout: tuple[int, ...] | None = field(metadata=dict(static=True), default=None)
|
||||
# Permutation from stride-ordered dimensions to cute.Layout mode positions.
|
||||
# Permutation from input dimensions to cute.Layout mode positions.
|
||||
# Defaults to identity (0, 1, ..., N-1) if None.
|
||||
mode: tuple[int, ...] | None = field(metadata=dict(static=True), default=None)
|
||||
# If True, shapes and strides are embedded as compile-time constants.
|
||||
@@ -96,7 +131,7 @@ class TensorSpec:
|
||||
ptr_assumed_align: int = field(
|
||||
metadata=dict(static=True), default=DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT
|
||||
)
|
||||
# Per-mode divisibility hints.
|
||||
# Per-input-dimension divisibility hints, before mode reordering.
|
||||
divisibility: tuple[int | None, ...] | int | None = field(
|
||||
metadata=dict(static=True), default=None
|
||||
)
|
||||
@@ -128,9 +163,10 @@ def row_major_layout(shaped):
|
||||
def default_tensor_mode(shaped):
|
||||
"""Returns the identity mode permutation for an N-dimensional tensor.
|
||||
|
||||
The mode permutation maps stride-ordered dimensions to ``cute.Layout`` mode
|
||||
positions. The default identity ``(0, 1, ..., N-1)`` leaves the mode order
|
||||
unchanged relative to the dimension order.
|
||||
The mode permutation maps JAX input dimensions to ``cute.Layout`` mode
|
||||
positions after the compact layout has been constructed. The default
|
||||
identity ``(0, 1, ..., N-1)`` leaves the mode order unchanged relative to
|
||||
the JAX shape order.
|
||||
|
||||
Args:
|
||||
shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence.
|
||||
@@ -151,12 +187,22 @@ def default_tensor_spec(shaped) -> TensorSpec:
|
||||
TensorSpec(layout=(N-1, ..., 1, 0), mode=(0, 1, ..., N-1), divisibility=(D0, D1, ... DN-1))
|
||||
|
||||
This is appropriate for standard row-major (C-contiguous) JAX arrays that
|
||||
do not require dimension reordering inside the kernel.
|
||||
do not require dimension reordering inside the kernel. The resulting JAX
|
||||
CuTe tensor is treated as compact: strides are derived from shapes using the
|
||||
row-major layout order.
|
||||
|
||||
Divisibility hints are inferred only for concrete integer dimensions.
|
||||
Symbolic dimensions always produce ``None`` for their slot; pass an
|
||||
explicit ``TensorSpec`` with ``divisibility`` set if you need alignment
|
||||
hints for symbolic shapes.
|
||||
If the JAX buffer is row-major but the kernel expects a different logical
|
||||
mode order, use an explicit :class:`TensorSpec` with ``mode`` set and leave
|
||||
``layout`` unset. ``cutlass_call`` still constrains the FFI buffer to
|
||||
row-major layout in this case. For example, ``TensorSpec(mode=(1, 2, 0))``
|
||||
maps a physical ``(L, M, K)`` row-major input to a logical ``(M, K, L)``
|
||||
tensor.
|
||||
|
||||
Divisibility hints are inferred only for concrete integer input dimensions.
|
||||
Symbolic dimensions always produce ``None`` for their slot; pass an explicit
|
||||
``TensorSpec`` with ``divisibility`` set if you need alignment hints for
|
||||
symbolic shapes or want a weaker explicit constraint than the concrete
|
||||
extent.
|
||||
|
||||
Args:
|
||||
shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence.
|
||||
@@ -179,11 +225,12 @@ def default_tensor_spec(shaped) -> TensorSpec:
|
||||
def _expand_divisibility(
|
||||
divisibility, order: tuple[int, ...], ndim: int
|
||||
) -> tuple[int | None, ...] | None:
|
||||
"""Expand a divisibility spec to a full per-dimension tuple.
|
||||
"""Expand a divisibility spec to a full per-input-dimension tuple.
|
||||
|
||||
A bare ``int`` is placed at the leading-dimension slot (where
|
||||
``order[i] == 0``, i.e. stride == 1) and ``None`` everywhere else.
|
||||
A tuple is returned unchanged. ``None`` returns ``None``.
|
||||
A tuple is already in JAX input-dimension order and is returned unchanged.
|
||||
``None`` returns ``None``.
|
||||
"""
|
||||
if divisibility is None or isinstance(divisibility, tuple):
|
||||
return divisibility
|
||||
@@ -268,7 +315,20 @@ def from_dlpack(array, assumed_align: int = DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNM
|
||||
return _from_dlpack(array, assumed_align=assumed_align)
|
||||
|
||||
|
||||
def _validate_permutation(name: str, perm, shape):
|
||||
def _assume_divisible_int(
|
||||
value: Any,
|
||||
divby: int,
|
||||
*,
|
||||
loc: ir.Location | None = None,
|
||||
ip: ir.InsertionPoint | None = None,
|
||||
) -> Any:
|
||||
"""Attach a divisibility assumption to an integer value without narrowing it."""
|
||||
if divby <= 1:
|
||||
return value
|
||||
return cute.assume(IntValue(value, loc=loc, ip=ip), divby=divby, loc=loc, ip=ip)
|
||||
|
||||
|
||||
def _validate_permutation(name: str, perm: Sequence[int], shape: Sequence[Any]) -> None:
|
||||
if len(perm) != len(shape):
|
||||
raise ValueError(f"{name} must be same length as shape", perm, shape)
|
||||
for s in perm:
|
||||
@@ -292,9 +352,11 @@ class JaxArray:
|
||||
can be concrete or symbolic in the case of jax.export.
|
||||
3. mem_space: The memory space of the tensor. Defaults to gmem.
|
||||
4. assumed_align: The alignment of the tensor. Defaults to XLA alignment.
|
||||
5. order: Specifies the order of the shape to determine strides.
|
||||
6. mode: Specifies how to map ordered elements to the modes od a cute.Layout.
|
||||
5. order: Specifies the compact physical stride order of the shape.
|
||||
6. mode: Specifies how to map input dimensions to the logical modes seen by
|
||||
the kernel after the compact layout is constructed.
|
||||
7. static: If True, tensor shapes and strides are compiled statically.
|
||||
8. divisibility: Optional divisibility hints in input-dimension order.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -381,6 +443,21 @@ class JaxArrayValue(JaxArray):
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
|
||||
# Track the divisibility available for each input dimension. Explicit
|
||||
# hints win; otherwise concrete dimensions contribute their known extent.
|
||||
dim_divisibility = None
|
||||
if self.divisibility is not None:
|
||||
dim_divisibility = []
|
||||
for div_spec, static_s in zip(self.divisibility, self.shape):
|
||||
if div_spec is not None and div_spec > 0:
|
||||
dim_divisibility.append(div_spec)
|
||||
elif isinstance(static_s, int):
|
||||
dim_divisibility.append(static_s)
|
||||
else:
|
||||
dim_divisibility.append(1)
|
||||
dim_divisibility = tuple(dim_divisibility)
|
||||
|
||||
pairs = sorted(zip(shape, order), key=lambda x: x[1])
|
||||
|
||||
# Compute strides for each element in order.
|
||||
@@ -395,28 +472,29 @@ class JaxArrayValue(JaxArray):
|
||||
for i in range(len(shape)):
|
||||
strides_ordered.append(strides[order[i]])
|
||||
|
||||
if dim_divisibility is not None:
|
||||
# A compact stride is the product of all dimensions with a lower
|
||||
# stride order, so it inherits the product of their divisibility.
|
||||
stride_divisibility = []
|
||||
for dim_order in order:
|
||||
divby = 1
|
||||
for other_dim, other_order in enumerate(order):
|
||||
if other_order < dim_order:
|
||||
divby *= dim_divisibility[other_dim]
|
||||
stride_divisibility.append(divby)
|
||||
|
||||
strides_ordered = [
|
||||
_assume_divisible_int(s, divby, loc=loc, ip=ip)
|
||||
for s, divby in zip(strides_ordered, stride_divisibility)
|
||||
]
|
||||
|
||||
# Shapes are expected to be int32 so truncate to that before creating layout
|
||||
shape_i32 = tuple(arith.trunci(i32, s) for s in shape)
|
||||
|
||||
# Apply per-mode divisibility assumptions so the compiler can exploit alignment.
|
||||
if self.divisibility is not None:
|
||||
assumed = []
|
||||
for s32, div_spec, static_s in zip(
|
||||
shape_i32, self.divisibility, self.shape
|
||||
):
|
||||
if isinstance(static_s, int):
|
||||
# Pure static shape is known even though a dynamic shape is
|
||||
# used. We can assume the exact shape here. We keep the shape
|
||||
# as a dynamic value to avoid breaking code that may expect
|
||||
# a dynamic value.
|
||||
assumed.append(cute.assume(s32, divby=static_s))
|
||||
elif div_spec is not None:
|
||||
# Using a dynamic value so apply the div_spec if its provided.
|
||||
assumed.append(cute.assume(s32, divby=div_spec))
|
||||
else:
|
||||
# No divisibility specification for this shape
|
||||
assumed.append(s32)
|
||||
shape_i32 = tuple(assumed)
|
||||
if dim_divisibility is not None:
|
||||
shape_i32 = tuple(
|
||||
_assume_divisible_int(s, divby, loc=loc, ip=ip)
|
||||
for s, divby in zip(shape_i32, dim_divisibility)
|
||||
)
|
||||
|
||||
return cute.make_layout(shape_i32, stride=tuple(strides_ordered))
|
||||
|
||||
|
||||
@@ -84,6 +84,8 @@ from .tmem_allocator import (
|
||||
|
||||
from .layout import LayoutEnum
|
||||
|
||||
from .block import block_copy
|
||||
|
||||
from .mixed_input_helpers import (
|
||||
TransformMode,
|
||||
scale_tma_partition,
|
||||
@@ -176,6 +178,7 @@ __all__ = [
|
||||
"sm90",
|
||||
"sm100",
|
||||
"gemm",
|
||||
"block_copy",
|
||||
"ClcDynamicPersistentTileSchedulerParams",
|
||||
"ClcDynamicPersistentTileScheduler",
|
||||
"print_latex",
|
||||
|
||||
@@ -612,7 +612,7 @@ def get_tmem_load_op(
|
||||
def get_smem_layout_atom_ab(
|
||||
major_mode: OperandMajorMode,
|
||||
element_type: Type[Numeric],
|
||||
smem_shape_mn_k: Tuple[int, int],
|
||||
smem_shape_mn_k: cute.Tile,
|
||||
*,
|
||||
loc: Optional[ir.Location] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
@@ -625,13 +625,16 @@ def get_smem_layout_atom_ab(
|
||||
:param element_type: The element type for the SMEM tensor.
|
||||
:type element_type: Type[Numeric]
|
||||
:param smem_shape_mn_k: The shape of the SMEM tensor.
|
||||
:type smem_shape_mn_k: Tuple[int, int]
|
||||
:type smem_shape_mn_k: cute.Tile
|
||||
:return: The SMEM layout atom kind
|
||||
:rtype: cutlass.cute.nvgpu.tcgen05.SmemLayoutAtomKind
|
||||
"""
|
||||
is_k_major = major_mode == OperandMajorMode.K
|
||||
major_mode_size = smem_shape_mn_k[1] if is_k_major else smem_shape_mn_k[0]
|
||||
|
||||
major_mode_size = (
|
||||
cute.size(smem_shape_mn_k, mode=[1])
|
||||
if is_k_major
|
||||
else cute.size(smem_shape_mn_k, mode=[0])
|
||||
)
|
||||
assert major_mode_size % 8 == 0
|
||||
sw128_num_contiguous_bits = 1024
|
||||
sw64_num_contiguous_bits = 512
|
||||
@@ -711,6 +714,7 @@ def make_smem_layout(
|
||||
cute.append(smem_tile_shape, num_stages),
|
||||
order=(0, 1, 2) if is_k_major else (1, 0, 2),
|
||||
)
|
||||
|
||||
return cute.coalesce(smem_layout, target_profile=(1, 1, 1), loc=loc, ip=ip)
|
||||
|
||||
|
||||
@@ -1956,12 +1960,35 @@ def thrfrg_SFA(
|
||||
"""Thread-fragment scale factor A tensor for SM120 block-scaled MMA.
|
||||
|
||||
Implements the ThrFrg partitioning for scale factor A according to the
|
||||
corresponding C++ code.
|
||||
corresponding C++ code in cutlass/include/cute/atom/mma_traits_sm120.hpp:
|
||||
SFALayout for SM120 MXF4 16x8x64 uses K=64, SM120 MXF8F6F4 16x8x32 uses
|
||||
K=32; the stride pattern ``((_8,_0,_1), _16)`` is shared.
|
||||
"""
|
||||
assert cute.rank(sfa_tensor) >= 2
|
||||
|
||||
atom_shape_mnk = tiled_mma.shape_mnk
|
||||
atom_sfa_layout = cute.make_layout(shape=((2, 2, 8), 64), stride=((8, 0, 1), 16))
|
||||
# K-dim of the warp-MMA atom: FP4 -> 64, FP8 -> 32 (per mma_traits_sm120.hpp).
|
||||
# For FP8 (atom_K=32) where mma_nsf=1, wrap K in a 2-tuple ``(atom_K, 1)``
|
||||
# so the layout's K mode keeps its 2D structure and the resulting fragment
|
||||
# has the same rank as the FP4 path. For FP4 (atom_K=64) the original 1D
|
||||
# layout already produces a 2D K decomposition through SMEM-layout
|
||||
# composition, so we keep the original shape.
|
||||
atom_K = atom_shape_mnk[2]
|
||||
if atom_K == 32:
|
||||
atom_sfa_layout = cute.make_layout(
|
||||
shape=((2, 2, 8), (atom_K, 1)),
|
||||
stride=((8, 0, 1), (16, 0)),
|
||||
)
|
||||
elif atom_K == 64:
|
||||
atom_sfa_layout = cute.make_layout(
|
||||
shape=((2, 2, 8), atom_K),
|
||||
stride=((8, 0, 1), 16),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"thrfrg_SFA: unsupported atom_K={atom_K}; SM120 block-scaled atoms "
|
||||
f"use atom_K=32 (mxf8/mxf8f6f4) or atom_K=64 (mxf4/mxf4nvf4)"
|
||||
)
|
||||
permutation_mnk = tiled_mma.permutation_mnk
|
||||
thr_layout_vmnk = tiled_mma.thr_layout_vmnk
|
||||
|
||||
@@ -2000,12 +2027,32 @@ def thrfrg_SFB(
|
||||
"""Thread-fragment scale factor B tensor for SM120 block-scaled MMA.
|
||||
|
||||
Implements the ThrFrg partitioning for scale factor B according to the
|
||||
corresponding C++ code.
|
||||
corresponding C++ code in cutlass/include/cute/atom/mma_traits_sm120.hpp:
|
||||
SFBLayout for SM120 MXF4 16x8x64 uses K=64, SM120 MXF8F6F4 16x8x32 uses
|
||||
K=32; the stride pattern ``((_0,_1), _8)`` is shared.
|
||||
"""
|
||||
assert cute.rank(sfb_tensor) >= 2
|
||||
|
||||
atom_shape_mnk = tiled_mma.shape_mnk
|
||||
atom_sfb_layout = cute.make_layout(shape=((4, 8), 64), stride=((0, 1), 8))
|
||||
# K-dim of the warp-MMA atom: FP4 -> 64, FP8 -> 32 (per mma_traits_sm120.hpp).
|
||||
# See :func:`thrfrg_SFA` for the rationale behind the FP8-only
|
||||
# ``(atom_K, 1)`` wrapping.
|
||||
atom_K = atom_shape_mnk[2]
|
||||
if atom_K == 32:
|
||||
atom_sfb_layout = cute.make_layout(
|
||||
shape=((4, 8), (atom_K, 1)),
|
||||
stride=((0, 1), (8, 0)),
|
||||
)
|
||||
elif atom_K == 64:
|
||||
atom_sfb_layout = cute.make_layout(
|
||||
shape=((4, 8), atom_K),
|
||||
stride=((0, 1), 8),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"thrfrg_SFB: unsupported atom_K={atom_K}; SM120 block-scaled atoms "
|
||||
f"use atom_K=32 (mxf8/mxf8f6f4) or atom_K=64 (mxf4/mxf4nvf4)"
|
||||
)
|
||||
permutation_mnk = tiled_mma.permutation_mnk
|
||||
thr_layout_vmnk = tiled_mma.thr_layout_vmnk
|
||||
|
||||
|
||||
248
python/CuTeDSL/cutlass/utils/block.py
Normal file
248
python/CuTeDSL/cutlass/utils/block.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Use of this software is governed by the terms and conditions of the
|
||||
# NVIDIA End User License Agreement (EULA), available at:
|
||||
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
|
||||
#
|
||||
# Any use, reproduction, disclosure, or distribution of this software
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
|
||||
from cutlass.cutlass_dsl import dsl_user_op, CuTeDSL
|
||||
|
||||
from cutlass.cute.typing import Tensor
|
||||
from cutlass.cute.core import make_layout, filter_zeros
|
||||
from cutlass.cute.atom import TiledCopy
|
||||
from cutlass.cute.algorithm import copy
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass.cute.nvgpu.cpasync.copy import (
|
||||
TmaCopyOp,
|
||||
CopyBulkTensorTileG2SOp,
|
||||
CopyBulkTensorTileG2SMulticastOp,
|
||||
)
|
||||
from cutlass.cute.nvgpu.cpasync.helpers import tma_partition
|
||||
from cutlass.cute.nvgpu.tcgen05.copy import _S2TCopyBase
|
||||
from typing import Any, Optional
|
||||
from cutlass._mlir import ir
|
||||
|
||||
|
||||
def _check_required_args(
|
||||
required_args: list[str], kwargs: dict, condition: bool = True
|
||||
) -> None:
|
||||
if not condition:
|
||||
return
|
||||
for arg in required_args:
|
||||
if arg not in kwargs:
|
||||
raise ValueError(f"Argument {arg} is required.")
|
||||
|
||||
|
||||
def _tma_copy_impl(
    tiled_copy: TiledCopy,
    src: Tensor,
    dst: Tensor,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
    **kwargs: Any,
) -> None:
    """Internal implementation for TMA-based block-level copy.

    Validates the keyword arguments against the kind of TMA op carried by
    ``tiled_copy``, partitions ``src``/``dst`` with ``tma_partition``, removes
    zero-stride modes with ``filter_zeros`` and finally issues ``copy``.

    :param tiled_copy: The TMA copy atom / tiled copy to execute.
    :type tiled_copy: TiledCopy
    :param src: The source tensor.
    :type src: Tensor
    :param dst: The destination tensor.
    :type dst: Tensor
    :param loc: Optional source location forwarded to the emitted ops.
    :param ip: Optional insertion point forwarded to the emitted ops.
    :param kwargs: Extra arguments forwarded to ``copy``. ``tma_bar_ptr`` is
        required for global-to-shared ops; ``tma_multicast`` (a dict) is
        accepted only with ``CopyBulkTensorTileG2SOp``.
    :raises ValueError: If ``tma_multicast`` is passed with an op other than
        ``CopyBulkTensorTileG2SOp``, or if ``tma_bar_ptr`` is missing for a
        global-to-shared op.
    """
    op = tiled_copy.op

    #
    # Handle tma_multicast argument
    #
    if "tma_multicast" in kwargs:
        if not isinstance(op, CopyBulkTensorTileG2SOp):
            raise ValueError(
                "block_copy with tma_multicast expects a non-multicast G2S TMA copy atom "
                "(CopyBulkTensorTileG2SOp) for compiler-driven multicast"
            )
        # Mark as coming from block API. Work on a shallow copy so the dict
        # object owned by the caller is not mutated as a side effect.
        kwargs["tma_multicast"] = {
            **kwargs["tma_multicast"],
            "from_block_api": True,
        }

    #
    # Check if required arguments are provided
    #
    # Both the plain and the multicast tile ops are treated as
    # global-to-shared loads: they need a barrier pointer, and `dst` is the
    # shared-memory side. The same predicate is used for both decisions so
    # they cannot drift apart.
    # NOTE(review): assumes CopyBulkTensorTileG2SMulticastOp is a G2S load,
    # as its name and the barrier requirement suggest - confirm.
    is_g2s = isinstance(
        op,
        (
            CopyBulkTensorTileG2SOp,
            CopyBulkTensorTileG2SMulticastOp,
        ),
    )
    _check_required_args(["tma_bar_ptr"], kwargs, is_g2s)

    #
    # TMA bulk tensor copies: partition via tma_partition
    #
    stensor = dst if is_g2s else src
    gtensor = src if is_g2s else dst
    cta_coord = 0
    cta_layout = make_layout(1, loc=loc, ip=ip)
    s_ptn, g_ptn = tma_partition(
        tiled_copy, cta_coord, cta_layout, stensor, gtensor, loc=loc, ip=ip
    )

    # Forward loc/ip like every other op emitted by this function.
    s_ptn = filter_zeros(s_ptn, loc=loc, ip=ip)
    g_ptn = filter_zeros(g_ptn, loc=loc, ip=ip)

    # Restore the caller's source/destination orientation for the copy.
    src_arg = g_ptn if is_g2s else s_ptn
    dst_arg = s_ptn if is_g2s else g_ptn
    return copy(tiled_copy, src_arg, dst_arg, loc=loc, ip=ip, **kwargs)
|
||||
|
||||
|
||||
def _utccp_copy_impl(
    tiled_copy: TiledCopy,
    src: Tensor,
    dst: Tensor,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
    **kwargs: Any,
) -> None:
    """Internal implementation for S2T (SMEM to TMEM) copy operations.

    This function abstracts the S2T copy pattern which involves:
    1. Filtering zeros from src (smem) and dst (tmem) tensors
    2. Taking slice 0 of the provided ``tiled_copy``
    3. Partitioning source and destination
    4. Getting the SMEM descriptor tensor
    5. Executing the copy

    :param tiled_copy: The tiled copy for S2T operations.
    :type tiled_copy: TiledCopy
    :param src: The source tensor in shared memory.
    :type src: Tensor
    :param dst: The destination tensor in TMEM.
    :type dst: Tensor
    :param loc: Optional source location forwarded to the emitted ops.
    :param ip: Optional insertion point forwarded to the emitted ops.
    :param kwargs: Extra arguments forwarded unchanged to ``copy``.
    """
    # Filter zeros from src (smem) and dst (tmem) tensors; loc/ip are
    # forwarded so these ops are placed like the rest of the sequence.
    src_compact = filter_zeros(src, loc=loc, ip=ip)
    dst_compact = filter_zeros(dst, loc=loc, ip=ip)

    # S2T has a single thread slice; election handled automatically in lowering
    thr_copy = tiled_copy.get_slice(0)

    # Partition source and destination
    src_partitioned = thr_copy.partition_S(src_compact, loc=loc, ip=ip)
    dst_partitioned = thr_copy.partition_D(dst_compact, loc=loc, ip=ip)

    # Get SMEM descriptor tensor for the source
    smem_desc_tensor = tcgen05.get_s2t_smem_desc_tensor(
        tiled_copy, src_partitioned, loc=loc, ip=ip
    )

    # Execute the copy: the descriptor tensor stands in for the SMEM source.
    return copy(tiled_copy, smem_desc_tensor, dst_partitioned, loc=loc, ip=ip, **kwargs)
|
||||
|
||||
|
||||
@dsl_user_op
@CuTeDSL.jit
def block_copy(
    tiled_copy: TiledCopy,
    src: Tensor,
    dst: Tensor,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
    **kwargs: Any,
) -> None:
    """Performs a block-level copy operation.

    This function adds an abstraction layer over the `cute.copy` usage model by
    allowing operands with layouts shaped like tiles to be passed directly. This
    removes the need to manually partition. The API is designed to support multiple
    copy kinds; currently TMA-based copies and S2T (SMEM to TMEM) copies are supported.

    The copy kind is selected from the type of ``tiled_copy.op``: ``TmaCopyOp``
    instances are handled by ``_tma_copy_impl`` and ``_S2TCopyBase`` instances by
    ``_utccp_copy_impl``. Any other op type is rejected.

    **TMA copy requirements**:

    When using TMA-based tiled copies, the ``src`` and ``dst`` tensors must have
    their first mode representing the TMATile, i.e. tensors shaped as ``(TMATile, Rest...)``.
    For a rank-2 tensor with logical layout (e.g., ``(TILE_M, TILE_N)``), call
    ``group_modes(tensor, 0, 2)`` before passing it to this function.

    **TMA multicast support**:

    For TMA-based copies that enable compiler-driven multicast in a 2D cluster, pass the
    ``tma_multicast`` argument as a dict with the following keys:

    - ``cluster_shape``: a tuple of 2 integers ``(cluster_m, cluster_n)``
      representing the **2D cluster shape**.
    - ``multicast_dim``: either ``"M"`` or ``"N"`` indicating which
      cluster dimension the multicast happens along.
    - ``use_2cta_mma_inst`` (optional): a ``bool`` indicating whether to
      use 2CTA MMA instructions when the loaded data is consumed by MMA.
      Defaults to ``False`` when omitted.

    **S2T (SMEM to TMEM) copy**:

    When using S2T copy operations (e.g., ``tcgen05.Cp4x32x128bOp``), the function
    automatically handles the filtering, partitioning, and SMEM descriptor creation.
    Pass a copy atom created with ``cute.make_copy_atom(tcgen05.Cp*Op(...), dtype)``
    along with source (SMEM) and destination (TMEM) tensors.

    Examples:

    .. code-block:: python

        # 1) TMA load without compiler-driven multicast
        # Note: group_modes is called to make the first mode TMATile
        block_copy(tma_atom_a, group_modes(tCgA_, 0, 2), group_modes(tCsA_, 0, 2),
                   tma_bar_ptr=tma_bar_ptr)

        # 2) TMA load with compiler-driven multicast along M in a (4,2) cluster
        block_copy(
            tma_atom_a,
            group_modes(tCgA_, 0, 2),
            group_modes(tCsA_, 0, 2),
            tma_multicast={
                "cluster_shape": (4, 2),
                "multicast_dim": "M",
                "use_2cta_mma_inst": True,
            },
            tma_bar_ptr=tma_bar_ptr,
        )

        # 3) TMA store
        # Note that `tma_bar_ptr` and CTA params (`cta_coord` and `cta_layout`)
        # are not needed for TMA store
        block_copy(tma_atom_c, group_modes(tCsC_, 0, 2), group_modes(tCgC_, 0, 2))

        # 4) S2T copy (SMEM to TMEM)
        copy_atom_s2t = cute.make_copy_atom(
            tcgen05.Cp4x32x128bOp(tcgen05.CtaGroup.ONE), sf_dtype
        )
        block_copy(copy_atom_s2t, tCsSF, tCtSF)

    :param tiled_copy: The tiled_copy or copy_atom of the current copy operation.
    :type tiled_copy: TiledCopy
    :param src: The source tensor.
    :type src: Tensor
    :param dst: The destination tensor.
    :type dst: Tensor
    :param loc: Optional source location forwarded to the selected implementation.
    :type loc: ir.Location, optional
    :param ip: Optional insertion point forwarded to the selected implementation.
    :type ip: ir.InsertionPoint, optional
    :param tma_multicast: Optional dict for TMA multicast configuration with keys
                          ``cluster_shape``, ``multicast_dim``, and optionally
                          ``use_2cta_mma_inst``.
    :type tma_multicast: dict, optional
    :param tma_bar_ptr: Barrier pointer passed through ``kwargs``; required by the
                        TMA path for global-to-shared loads (see examples 1 and 2).
    :raises NotImplementedError: If ``tiled_copy.op`` is neither a ``TmaCopyOp`` nor
                                 an ``_S2TCopyBase`` instance.
    """
    import cutlass  # local import to avoid circular import at module load time

    # The op type is wrapped in const_expr before branching; this function is
    # compiled through CuTeDSL.jit, so the branch shape is kept exactly as is.
    # NOTE(review): presumably const_expr makes this a trace-time decision - confirm.
    if cutlass.const_expr(isinstance(tiled_copy.op, TmaCopyOp)):
        return _tma_copy_impl(tiled_copy, src, dst, loc=loc, ip=ip, **kwargs)
    elif cutlass.const_expr(isinstance(tiled_copy.op, _S2TCopyBase)):
        return _utccp_copy_impl(tiled_copy, src, dst, loc=loc, ip=ip, **kwargs)
    else:
        raise NotImplementedError(
            f"Copy op {type(tiled_copy.op).__name__} is not supported yet."
        )
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements-cu13.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl[cu13]==4.4.2
|
||||
nvidia-cutlass-dsl[cu13]==4.5.1
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Use `pip install -r requirements.txt` with the present file to install a
|
||||
# wheel consistent with the present state of the github repository
|
||||
nvidia-cutlass-dsl==4.4.2
|
||||
nvidia-cutlass-dsl==4.5.1
|
||||
|
||||
@@ -133,7 +133,7 @@ def get_option_registry():
|
||||
this._option_registry = OptionRegistry(device_cc())
|
||||
return this._option_registry
|
||||
|
||||
this.__version__ = '4.5.0'
|
||||
this.__version__ = '4.5.1'
|
||||
|
||||
from cutlass_cppgen.backend import create_memory_pool
|
||||
from cutlass_cppgen.emit.pytorch import pytorch
|
||||
|
||||
@@ -51,7 +51,7 @@ setup_pycute.perform_setup()
|
||||
|
||||
setup(
|
||||
name='cutlass_cppgen',
|
||||
version='4.5.0',
|
||||
version='4.5.1',
|
||||
description='CUTLASS Pythonic Interface',
|
||||
package_dir={'': '.'},
|
||||
packages=[
|
||||
|
||||
@@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='cutlass_library',
|
||||
version='4.5.0',
|
||||
version='4.5.1',
|
||||
description='CUTLASS library generation scripts',
|
||||
packages=['cutlass_library']
|
||||
)
|
||||
|
||||
@@ -36,7 +36,7 @@ from setuptools import setup
|
||||
def perform_setup():
|
||||
setup(
|
||||
name='pycute',
|
||||
version='4.5.0',
|
||||
version='4.5.1',
|
||||
description='Python implementation of CuTe',
|
||||
packages=['pycute'],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user