2026-01-12 updates

This commit is contained in:
jkosaian
2026-01-12 18:51:25 -08:00
parent 7c09485e25
commit 87cab7bae2
27 changed files with 7619 additions and 234 deletions

View File

@@ -69,8 +69,8 @@ See CUTLASS's [Compatibility section](https://github.com/NVIDIA/cutlass?tab=read
### Current support
* Dense GEMM: `out = A @ B`
- Compute capabilities: 100, 103 (WIP to expand to more)
- Input precisions (A and B must be of same type): F16, BF16, TF32, INT8
- Compute capabilities: 100, 103
- Input precisions (A and B must be of same type): F16, BF16, INT8
- Output precisions: F32, F16, BF16, INT32
- Epilogue operations:
- Auxiliary load of tensors equal in rank and shape of `out`
@@ -79,10 +79,20 @@ See CUTLASS's [Compatibility section](https://github.com/NVIDIA/cutlass?tab=read
- Tensor-tensor elementwise or tensor-scalar addition, multiplication, subtraction, division
- Elementwise tensor exponent, relu, sigmoid, tanh
- Note: Partial support exists on CC 80/89/90 (limited dtypes/tilings coverage)
* Block-scaled dense GEMM:
- Compute capabilities: 100, 103
- Input precisions: MXFP8 (F8E4M3, F8E5M2)
- Output precisions: F32, F16, BF16, F8E4M3, F8E5M2
- Scale factor precisions: F8E8M0
* Grouped GEMM
- Contiguous offset 2d-3d grouped GEMM: `(TotalM, K) @ (G, K, N) -> (TotalM, N)`
- Compute capability: 100
- Input precisions: F16, BF16, F8E4M3, F8E5M2
- Output precisions: F32, F16, BF16
* Planned additions
* Block-scaled GEMMs
* Grouped GEMMs
* Block-scaled GEMMs (additional precisions)
* Grouped GEMMs (additional variants and precisions)
* Additional epilogue operations
* Reductions
* Row/column broadcasts

View File

@@ -28,9 +28,11 @@
from __future__ import annotations
from abc import ABC
import copy
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, get_type_hints
from typing import TYPE_CHECKING, Any, get_type_hints
if TYPE_CHECKING:
from collections.abc import Callable
@@ -45,10 +47,10 @@ import cutlass.cute as cute
from cutlass_api.fusion import EmptyTensor, trace, trace_in_out
from cutlass_api.fusion.library import LayoutType
from cutlass_api.library import ScaleMode, ScaleSwizzleMode
from cutlass_api.typing import NumericLike, TensorLike
from cutlass_api.utils import (
TensorWrapper,
add_batch_mode,
is_torch_tensor,
to_cutlass_type,
)
@@ -172,10 +174,6 @@ class EpilogueArguments:
"""Returns the list of names of the input and output parameters of the epilogue"""
return list(self.tensors.keys())
def add_batch_modes(self):
for name in self.tensors:
self.tensors[name] = add_batch_mode(self.tensors[name])
def to_tensor_wrappers(self, permute: list[int] | None = None):
"""Converts the input and output parameters of the epilogue to TensorWrappers"""
for k, v in self.tensors.items():
@@ -215,6 +213,42 @@ class EpilogueArguments:
)
def convert_to_internal_types(caller, metadata: dict[str, Any] | None = None):
    """
    Converts fields of the caller to internal types. Current fields that are converted:

    * TensorLike -> TensorWrapper
    * NumericLike -> cutlass.Numeric
    * Classes that implement _convert_to_internal_types -> their internal types

    :param caller: The caller object to convert the fields of. Must be a
        dataclass instance (its fields are discovered via dataclasses.fields()).
    :type caller: Any
    :param metadata: Additional metadata to be used for conversion
    :type metadata: dict[str, Any] | None
    """
    # Resolve (possibly string-form, PEP 563) annotations into real type
    # objects so they can be compared by identity below.
    type_hints = get_type_hints(type(caller))
    for f in fields(caller):
        hint = type_hints.get(f.name)
        value = getattr(caller, f.name)
        # Per-field metadata overrides the caller-wide metadata. Deep-copy so
        # one field's conversion cannot mutate another field's metadata dict.
        global_metadata = {} if metadata is None else copy.deepcopy(metadata)
        global_metadata.update(f.metadata)
        if hint is TensorLike:
            # Find all fields that are annotated as TensorLike,
            # and wrap them in TensorWrapper
            setattr(caller, f.name, TensorWrapper(value, **global_metadata))
        elif hint is NumericLike:
            # Find all fields that are annotated as NumericLike,
            # and convert them to cutlass.Numeric
            setattr(caller, f.name, to_cutlass_type(value))
        elif hasattr(value, "_convert_to_internal_types"):
            # If the field is an instance of a class that implements
            # _convert_to_internal_types, convert it to internal types
            # (conversion happens in place; the setattr keeps the same object).
            value._convert_to_internal_types(metadata=global_metadata)
            setattr(caller, f.name, value)
@dataclass
class RuntimeArguments:
"""
@@ -231,7 +265,6 @@ class RuntimeArguments:
:type performance: Optional[PerformanceControls]
"""
epilogue: EpilogueArguments | None = field(default=None, kw_only=True)
performance: PerformanceControls | None = field(default=None, kw_only=True)
def _validate(self):
@@ -241,24 +274,92 @@ class RuntimeArguments:
"""
def __post_init__(self):
self._validate()
self._convert_to_internal_types()
convert_to_internal_types(self)
def _convert_to_internal_types(self):
"""
Converts all fields to their internal types.
"""
type_hints = get_type_hints(type(self))
for f in fields(self):
hint = type_hints.get(f.name)
value = getattr(self, f.name)
# Find all fields that are annotated as TensorLike,
# and wrap them in TensorWrapper
if hint is TensorLike:
setattr(self, f.name, TensorWrapper(value))
elif hint is NumericLike:
setattr(self, f.name, to_cutlass_type(value))
class KernelOperand(ABC):
    """
    Base class for all operands to kernels.
    """

    def _convert_to_internal_types(self, metadata: dict[str, Any] | None = None):
        # Delegate to the module-level helper, which converts TensorLike /
        # NumericLike annotated fields (and nested operands implementing
        # _convert_to_internal_types) to their internal representations.
        convert_to_internal_types(self, metadata=metadata)
@dataclass
class DenseTensor(KernelOperand):
    """
    Representation of a dense tensor operand.

    Wraps a raw tensor and forwards any attribute lookup it does not define
    itself (shape, stride, dtype, ...) to the wrapped tensor, so a
    ``DenseTensor`` can stand in for the tensor it wraps.
    """

    # The wrapped tensor object.
    tensor: TensorLike

    def __getattr__(self, attr: str) -> Any:
        # Single getattr() instead of hasattr()+getattr(): the original pair
        # evaluated the attribute twice (properties on the wrapped tensor ran
        # twice, and the value could change between the two lookups).
        try:
            return getattr(self.tensor, attr)
        except AttributeError:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{attr}'"
            ) from None
def kernel_operand_or_default(operand: TensorLike | KernelOperand) -> KernelOperand:
    """
    Normalize an operand to a KernelOperand.

    Plain tensors are wrapped in a DenseTensor; anything that is already a
    KernelOperand passes through untouched. Convenience interfaces use this so
    callers need not wrap dense tensors themselves.

    :param operand: The operand to normalize
    :type operand: TensorLike | KernelOperand
    :return: A DenseTensor wrapping the operand if it is a TensorLike,
        otherwise the operand itself
    :rtype: KernelOperand
    """
    return DenseTensor(operand) if isinstance(operand, TensorLike) else operand
@dataclass
class ScaledTensor(KernelOperand):
    """
    Representation of a scaled tensor operand. This includes:

    * A base tensor. This is a KernelOperand subclass.
    * A tensor containing scale factors
    * Scale mode and swizzle mode

    An example of its creation is the following:

    ```python
    A = torch.rand(...)
    scale_A = torch.rand(...)
    arg = ScaledTensor(A, scale_A, ScaleMode.Blockwise1x32, ScaleSwizzleMode.Swizzle32x4x4)
    ```
    """

    # NOTE(review): @dataclass does not overwrite the explicit __init__ defined
    # below; these field declarations exist so dataclasses.fields() (used by
    # convert_to_internal_types) sees the operand's fields — verify that is the
    # intent.
    base: KernelOperand
    scale: DenseTensor
    mode: ScaleMode | tuple[int, ...]
    swizzle: ScaleSwizzleMode

    def __init__(
        self,
        base: KernelOperand | TensorLike,
        scale: DenseTensor | TensorLike,
        mode: ScaleMode | tuple[int, ...],
        swizzle: ScaleSwizzleMode,
    ):
        # Plain tensors are wrapped in DenseTensor for caller convenience.
        self.base = kernel_operand_or_default(base)
        self.scale = kernel_operand_or_default(scale)
        self.mode = mode
        self.swizzle = swizzle

    def __getattr__(self, attr: str) -> Any:
        # Forward unknown attribute lookups to the base operand so a
        # ScaledTensor can stand in for its base tensor.
        if hasattr(self.base, attr):
            return getattr(self.base, attr)
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{attr}'"
            )
@dataclass
@@ -273,24 +374,70 @@ class GemmArguments(RuntimeArguments):
N: Number of columns in B and out
:param A: Input tensor A of shape (L, M, K) or (M, K)
:type A: TensorLike
:type A: KernelOperand
:param B: Input tensor B of shape (L, K, N) or (K, N)
:type B: TensorLike
:type B: KernelOperand
:param out: Output tensor C of shape (L, M, N) or (M, N)
:type out: TensorLike
:type out: KernelOperand
:param accumulator_type: Data type of the accumulator
:type accumulator_type: NumericLike
"""
A: TensorLike
B: TensorLike
out: TensorLike
A: KernelOperand
B: KernelOperand
out: KernelOperand
accumulator_type: NumericLike
epilogue: EpilogueArguments | None
def to_operands_metadata(self) -> GemmOperandsMetadata:
from cutlass_api.metadata import GemmOperandsMetadata
def __init__(
self,
A: TensorLike | KernelOperand,
B: TensorLike | KernelOperand,
out: TensorLike | KernelOperand,
accumulator_type: NumericLike,
epilogue: EpilogueArguments | None = None,
):
"""
Construct a GemmArguments object.
return GemmOperandsMetadata.from_args(self)
For convenience, construction for `A`, `B`, and `out` that are operands
to a dense GEMM can be passed in without wrapping them in `DenseTensor`.
```python
GemmArguments(A, B, out, accumulator_type)
# is equivalent to:
GemmArguments(DenseTensor(A), DenseTensor(B), DenseTensor(out), accumulator_type)
```
Other operand types must explicitly wrap tensors in a `KernelOperand` subclass.
For example, a scaled GEMM can be constructed as:
```python
GemmArguments(
ScaledTensor(A, ScaleATensor, scale_mode, scale_swizzle),
ScaledTensor(B, ScaleBTensor, scale_mode, scale_swizzle),
out, # No need to wrap out in a `DenseTensor`
accumulator_type,
)
```
:param A: Input tensor A of shape (L, M, K) or (M, K)
:type A: TensorLike | KernelOperand
:param B: Input tensor B of shape (L, K, N) or (K, N)
:type B: TensorLike | KernelOperand
:param out: Output tensor C of shape (L, M, N) or (M, N)
:type out: TensorLike | KernelOperand
:param accumulator_type: Data type of the accumulator
:type accumulator_type: NumericLike
:param epilogue: Optional custom epilogue fusion to be performed after the GEMM.
:type epilogue: Optional[EpilogueArguments]
"""
self.A = kernel_operand_or_default(A)
self.B = kernel_operand_or_default(B)
self.out = kernel_operand_or_default(out)
self.accumulator_type = accumulator_type
self.epilogue = epilogue
super().__init__()
def _validate(self):
"""
@@ -329,13 +476,79 @@ class GemmArguments(RuntimeArguments):
f"out & A must have the same rank and batch dimension (if any). out shape (L, M, N): {self.out.shape}, A shape (L, M, K): {self.A.shape}"
)
if isinstance(self.epilogue, EpilogueArguments):
def _convert_epilogue(self):
"""Converts the epilogue to an internal representation using internal types."""
if self.epilogue is not None:
L = self.A.shape[0] if len(self.A.shape) == 3 else 1
M, N = self.A.shape[-2], self.B.shape[-1]
accum_shape = (L, M, N)
self.epilogue.trace(accum_shape, self.accumulator_type)
self.epilogue.to_tensor_wrappers()
def __post_init__(self):
self._validate()
self._convert_epilogue()
super().__post_init__()
@dataclass
class GroupedGemmArguments(RuntimeArguments):
    """
    Arguments for a grouped GEMM operation. A grouped GEMM performs a series
    of independent GEMM operations.

    The most basic formulation of a grouped GEMM is one in which one has lists
    of A tensors and lists of B tensors, and computes the following:

    ```python
    for i in range(problems_in_group):
        out[i] = A[i] @ B[i]
    ```

    Note that each constituent GEMM operation can have different sizes.

    Though the abstract formulation of a grouped GEMM treats operands as parts
    of a list, in practice, the operands are often packed together contiguously in memory.
    In this case, an `offsets` tensor delineates the ending positions of each problem in the group.

    Currently-supported variants of a grouped GEMM are:

    * Contiguous offset 2D/3D:
        * A is a tensor of shape (TotalM, K) or (1, TotalM, K)
        * B is a tensor of shape (problems_in_group, K, N)
        * out is a tensor of shape (TotalM, N) or (1, TotalM, N)
        * offsets is a tensor delineating the ending positions of each problem in the group.

    ```python
    start = 0
    for i in range(problems_in_group):
        end = offsets[i]
        out[start:end, :] = A[start:end, :] @ B[i, :, :]
        start = end
    ```
    """

    A: KernelOperand
    B: KernelOperand
    out: KernelOperand
    accumulator_type: NumericLike
    # The alignment metadata is consumed per-field by convert_to_internal_types()
    # via dataclasses.fields(); offsets entries must be 4-byte aligned.
    offsets: KernelOperand = field(metadata={"alignment_bytes": 4})
    epilogue: EpilogueArguments | None

    def __init__(
        self,
        A: TensorLike | KernelOperand,
        B: TensorLike | KernelOperand,
        out: TensorLike | KernelOperand,
        accumulator_type: NumericLike,
        offsets: TensorLike | KernelOperand,
        epilogue: EpilogueArguments | None = None,
    ):
        # Plain tensors are wrapped in DenseTensor for caller convenience.
        self.A = kernel_operand_or_default(A)
        self.B = kernel_operand_or_default(B)
        self.out = kernel_operand_or_default(out)
        self.accumulator_type = accumulator_type
        self.offsets = kernel_operand_or_default(offsets)
        self.epilogue = epilogue
        # NOTE(review): if RuntimeArguments keeps its dataclass-generated
        # __init__, calling it with no arguments re-assigns its kw_only fields
        # (epilogue, performance) to their None defaults — verify that the
        # epilogue assigned above is not clobbered here.
        super().__init__()
@dataclass
class ElementwiseArguments(RuntimeArguments):
@@ -343,21 +556,27 @@ class ElementwiseArguments(RuntimeArguments):
Arguments needed for an elementwise operation.
:param A: The input tensor A.
:type A: TensorLike
:type A: TensorLike | KernelOperand
:param B: The input tensor B.
:type B: TensorLike
:type B: TensorLike | KernelOperand
:param out: The output tensor.
:type out: TensorLike
:type out: TensorLike | KernelOperand
"""
A: TensorLike
B: TensorLike
out: TensorLike
A: KernelOperand
B: KernelOperand
out: KernelOperand
def to_operands_metadata(self) -> ElementwiseOperandsMetadata:
from cutlass_api.metadata import ElementwiseOperandsMetadata
return ElementwiseOperandsMetadata.from_args(self)
def __init__(
self,
A: TensorLike | KernelOperand,
B: TensorLike | KernelOperand,
out: TensorLike | KernelOperand,
):
self.A = kernel_operand_or_default(A)
self.B = kernel_operand_or_default(B)
self.out = kernel_operand_or_default(out)
super().__init__()
def _validate(self):
"""
@@ -371,3 +590,7 @@ class ElementwiseArguments(RuntimeArguments):
raise ValueError(
f"out.shape ({self.out.shape}) must be equal to A.shape ({self.A.shape})"
)
def __post_init__(self):
self._validate()
super().__post_init__()

View File

@@ -0,0 +1,82 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import enum
from enum import auto as enum_auto
from typing import Self
class ScaleMode(enum.Enum):
"""
Type of scaling used for scaled operations. Each scaling enum corresponds to
a tuple indicating the units the scale covers in the (L, M, K) modes.
For example, Blockwise1x16 corresponds to (1, 1, 16), meaning the scale operates
over 16 elements in the K mode.
"""
Blockwise1x16 = (1, 1, 16)
Blockwise1x32 = (1, 1, 32)
@staticmethod
def compare(mode1: Self | tuple[int, ...], mode2: Self | tuple[int, ...]) -> bool:
if isinstance(mode1, ScaleMode) and isinstance(mode2, ScaleMode):
return mode1.value == mode2.value
# One of the two modes is a tuple. Use tuple comparison and allow
# for different lengths as long as long as the longer tuple contains
# only 1s for the extra leading positions (i.e., allow (1,1,16) == (1, 16))
tuple1 = mode1 if isinstance(mode1, tuple) else mode1.value
tuple2 = mode2 if isinstance(mode2, tuple) else mode2.value
if len(tuple1) == len(tuple2):
return tuple1 == tuple2
elif len(tuple1) < len(tuple2):
padding = (1,) * (len(tuple2) - len(tuple1))
return padding + tuple1 == tuple2
else:
padding = (1,) * (len(tuple1) - len(tuple2))
return padding + tuple2 == tuple1
def __eq__(self, other: Self | tuple[int, ...]) -> bool:
return ScaleMode.compare(self, other)
def __ne__(self, other: Self | tuple[int, ...]) -> bool:
return not ScaleMode.compare(self, other)
def __getitem__(self, index: int) -> int:
return self.value[index]
def __len__(self) -> int:
return len(self.value)
class ScaleSwizzleMode(enum.Enum):
    """Swizzle mode applied to scale-factor tensors."""

    # Scale factors stored without swizzling.
    SwizzleNone = enum.auto()
    # Scale factors stored in the 32x4x4 swizzled layout.
    Swizzle32x4x4 = enum.auto()

View File

@@ -90,7 +90,7 @@ class Manifest:
return False
return True
epilogue_args = None if args is None else args.epilogue
epilogue_args = getattr(args, "epilogue", None)
kernels = [
k
# Generate kernels from all providers

View File

@@ -38,13 +38,17 @@ if TYPE_CHECKING:
import cutlass.cute as cute
from cutlass_api.arguments import (
DenseTensor,
ElementwiseArguments,
EpilogueArguments,
GemmArguments,
GroupedGemmArguments,
KernelOperand,
RuntimeArguments,
ScaledTensor,
)
from cutlass_api.library import ScaleMode, ScaleSwizzleMode
from cutlass_api.status import Status
from cutlass_api.utils import TensorWrapper
def _convert_stride(shape: tuple[int, ...], stride: tuple[int, ...]) -> tuple[int, ...]:
@@ -65,6 +69,11 @@ def _convert_stride(shape: tuple[int, ...], stride: tuple[int, ...]) -> tuple[in
new_stride.append(0)
else:
new_stride.append(stride[i])
# If every mode has stride 0 (i.e., all extents are 1), assign the last mode a stride of 1 so one contiguous mode remains.
if all(x == 0 for x in new_stride):
new_stride[-1] = 1
return new_stride
@@ -110,7 +119,7 @@ class TensorAttributes:
:param dtype: The data type of the tensor.
:type dtype: cutlass.Numeric
:param stride: The stride of the tensor.
:type stride: tuple[int, ...]
:type stride: tuple[int, ...] | None
:param divisibility: The divisibility requirement on a tensor's stride & shape elements
:type divisibility: int
:param ptr_alignment_bytes: The alignment of the tensor's data pointer, in bytes.
@@ -119,7 +128,7 @@ class TensorAttributes:
"""
dtype: cutlass.Numeric # F32, F16, etc.
stride: tuple[int, ...]
stride: tuple[int, ...] | None
divisibility: int
ptr_alignment_bytes: int
@@ -137,19 +146,26 @@ class TensorAttributes:
(divisibility * dtype.width) // 8
)
def supports(self, operand: TensorWrapper | Self) -> Status:
def supports(self, operand: KernelOperand | Self) -> Status:
"""
Checks whether the provided args satisfy the properties described by
these TensorAttributes.
:param operand: The operand to check support for.
:type operand: TensorWrapper | Self
:type operand: KernelOperand | Self
:return: Whether the provided operand satisfies the properties described by
these TensorAttributes.
:rtype: Status
"""
if isinstance(operand, TensorWrapper):
if not isinstance(operand, KernelOperand) and not isinstance(
operand, TensorAttributes
):
return Status.fail(
f"Expected KernelOperand or TensorAttributes, got {type(operand)}"
)
if isinstance(operand, KernelOperand):
if operand.element_type != self.dtype:
return Status.fail(
f"Expected element type {self.dtype}, got {operand.element_type}"
@@ -162,37 +178,37 @@ class TensorAttributes:
# Normalize the operand stride to have 0's where operand.shape is 1
normalized_operand_stride = _convert_stride(operand.shape, operand.stride)
# If the operand is batched, the expected stride is the same as the self.stride
if len(self.stride) == len(normalized_operand_stride):
expected_stride = self.stride
# We can try if strides match without the batch mode
elif len(self.stride) - 1 == len(normalized_operand_stride):
expected_stride = self.stride[1:]
if self.stride is None:
contiguous_dim = normalized_operand_stride.index(1)
else:
return Status.fail(
f"Expected tensor with strides of rank {len(self.stride)} (batched) or {len(self.stride) - 1} (unbatched), got {len(normalized_operand_stride)} ({normalized_operand_stride})"
)
# If the operand is batched, the expected stride is the same as the self.stride
if len(self.stride) == len(normalized_operand_stride):
expected_stride = self.stride
# We can try if strides match without the batch mode
elif len(self.stride) - 1 == len(normalized_operand_stride):
expected_stride = self.stride[1:]
else:
return Status.fail(
f"Expected tensor with strides of rank {len(self.stride)} (batched) or {len(self.stride) - 1} (unbatched), got {len(normalized_operand_stride)} ({normalized_operand_stride})"
)
# Strides are considered compatible if:
# 1. They have the same rank (checked above)
# 2. The leading mode is the same (or both are all 0)
all_zeros = all(x == 0 for x in expected_stride) and all(
x == 0 for x in normalized_operand_stride
)
# Strides are considered compatible if:
# 1. They have the same rank (checked above)
# 2. The leading mode is the same
# When setting stride from args, any modes of stride 1 and shape 1
# are changed to have stride 0. Thus, there will only be one mode
# with stride 1.
contiguous_dim = expected_stride.index(1)
if not all_zeros and normalized_operand_stride[contiguous_dim] != 1:
return Status.fail(
f"Expected stride[{contiguous_dim}] to be 1, got "
f"{normalized_operand_stride[contiguous_dim]} "
f"(strides: {normalized_operand_stride})"
)
# When setting stride from args, any modes of stride 1 and shape 1
# are changed to have stride 0. Thus, there will only be one mode
# with stride 1.
contiguous_dim = expected_stride.index(1)
if normalized_operand_stride[contiguous_dim] != 1:
return Status.fail(
f"Expected stride[{contiguous_dim}] to be 1, got "
f"{normalized_operand_stride[contiguous_dim]} "
f"(strides: {normalized_operand_stride})"
)
# Check that divisibility constraints are met
if isinstance(operand, TensorWrapper):
if isinstance(operand, KernelOperand):
if not _is_tuple_aligned(
normalized_operand_stride, self.divisibility, contiguous_dim
):
@@ -209,7 +225,7 @@ class TensorAttributes:
# Check data ptr alignment, if available
if (
isinstance(operand, TensorWrapper)
isinstance(operand, KernelOperand)
and operand.data_ptr % self.ptr_alignment_bytes != 0
):
return Status.fail(
@@ -218,23 +234,104 @@ class TensorAttributes:
return Status.success()
@staticmethod
def from_tensor(tensor) -> Self:
"""
Create a TensorAttributes from a tensor.
:param tensor: The tensor to create a TensorAttributes from.
:type tensor: cute.Tensor
@dataclass
class DenseTensorAttributes:
"""
Description of a dense tensor. This includes the data type, stride, and alignment.
"""
:return: The TensorAttributes corresponding to the provided tensor.
:rtype: TensorAttributes
"""
stride = _convert_stride(tensor.shape, tensor.stride)
max_divisibility = _get_max_pow2_divisibility(tensor.shape, stride)
return TensorAttributes(
dtype=tensor.element_type, stride=stride, divisibility=max_divisibility
_tensor: TensorAttributes
def __init__(
self,
dtype: cutlass.Numeric,
stride: tuple[int, ...],
divisibility: int,
ptr_alignment_bytes: int | None = None,
):
self._tensor = TensorAttributes(
dtype, stride, divisibility, ptr_alignment_bytes
)
def supports(self, operand: DenseTensor | Self) -> Status:
"""
Checks whether the provided args satisfy the properties described by
these TensorAttributes.
:param operand: The operand to check support for.
:type operand: DenseTensor | Self
:return: Whether the provided operand satisfies the properties described by
these TensorAttributes.
:rtype: Status
"""
if not isinstance(operand, DenseTensor) and not isinstance(
operand, DenseTensorAttributes
):
return Status.fail(
f"Expected DenseTensor or DenseTensorAttributes, got {type(operand)}"
)
return self._tensor.supports(operand)
def __getattr__(self, attr: str) -> Any:
if hasattr(self._tensor, attr):
return getattr(self._tensor, attr)
else:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
)
@dataclass
class ScaledTensorAttributes:
"""
Description of a scaled tensor. This includes the base tensor and the
scale tensor + scale mode + scale swizzle combination.
"""
base: TensorAttributes
scale: DenseTensorAttributes
mode: ScaleMode | tuple[int, ...]
swizzle: ScaleSwizzleMode
def supports(self, operand: ScaledTensor | Self) -> Status:
"""
Checks whether the provided args satisfy the properties described by
these ScaledTensorAttributes.
"""
if not isinstance(operand, ScaledTensor) and not isinstance(
operand, ScaledTensorAttributes
):
return Status.fail(
f"Expected ScaledTensor or ScaledTensorAttributes, got {type(operand)}"
)
if not (status := self.base.supports(operand.base)):
return status
if not (status := self.scale.supports(operand.scale)):
return status
if not ScaleMode.compare(self.mode, operand.mode):
return Status.fail(f"Expected scale mode {self.mode}, got {operand.mode}")
if self.swizzle != operand.swizzle:
return Status.fail(
f"Expected scale swizzle mode {self.swizzle}, got {operand.swizzle}"
)
return Status.success()
def __getattr__(self, attr: str) -> Any:
if hasattr(self.base, attr):
return getattr(self.base, attr)
else:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
)
@dataclass
class OperandsMetadata(ABC):
@@ -326,11 +423,11 @@ class GemmOperandsMetadata(OperandsMetadata):
"""
Metadata for the operands of a GEMM operation.
:param A: Metadata for the input tensor A.
:param A: Metadata for operand `A` of the GEMM.
:type A: TensorAttributes
:param B: Metadata for the input tensor B.
:param B: Metadata for operand `B` of the GEMM.
:type B: TensorAttributes
:param out: Metadata for the output tensor.
:param out: Metadata for operand `out` of the GEMM.
:type out: TensorAttributes
:param accumulator_type: The data type of the accumulator tensor.
:type accumulator_type: cutlass.Numeric
@@ -372,23 +469,65 @@ class GemmOperandsMetadata(OperandsMetadata):
return Status.success()
@staticmethod
def from_args(args: GemmArguments) -> Self:
"""
Create a GemmOperandsMetadata from a GemmArguments.
:param args: The GemmArguments to create a GemmOperandsMetadata from.
:type args: GemmArguments
@dataclass
class GroupedGemmOperandsMetadata(OperandsMetadata):
"""
Metadata for the operands of a grouped GEMM operation.
:return: The GemmOperandsMetadata corresponding to the provided GemmArguments.
:rtype: GemmOperandsMetadata
:param A: Metadata for operand `A` of the grouped GEMM.
:type A: TensorAttributes
:param B: Metadata for operand `B` of the grouped GEMM.
:type B: TensorAttributes
:param out: Metadata for operand `out` of the grouped GEMM.
:type out: TensorAttributes
:param accumulator_type: The data type of the accumulator tensor.
:type accumulator_type: cutlass.Numeric
:param offsets: Metadata for operand `offsets` of the grouped GEMM.
:type offsets: TensorAttributes
"""
A: TensorAttributes
B: TensorAttributes
out: TensorAttributes
offsets: TensorAttributes
accumulator_type: cutlass.Numeric
def supports(self, other: GroupedGemmArguments | Self) -> Status:
"""
return GemmOperandsMetadata(
A=TensorAttributes.from_tensor(args.A),
B=TensorAttributes.from_tensor(args.B),
out=TensorAttributes.from_tensor(args.out),
accumulator_type=args.accumulator_type,
)
Checks whether the provided args satisfy the properties described by
the operands in this metadata.
:param other: The arguments to check support for.
:type other: GroupedGemmArguments | Self
:return: Whether the provided args satisfy the properties described by
the operands in this metadata.
:rtype: Status
"""
if isinstance(other, RuntimeArguments) and not isinstance(
other, GroupedGemmArguments
):
return Status.fail(f"Expected GroupedGemmArguments, got {type(other)}")
if not (status := self.A.supports(other.A)):
return Status.fail(f"Operand `A` is unsupported: {status.error}")
if not (status := self.B.supports(other.B)):
return Status.fail(f"Operand `B` is unsupported: {status.error}")
if not (status := self.out.supports(other.out)):
return Status.fail(f"Operand `out` is unsupported: {status.error}")
if not (status := self.offsets.supports(other.offsets)):
return Status.fail(f"Operand `offsets` is unsupported: {status.error}")
if self.accumulator_type != other.accumulator_type:
return Status.fail(
f"Expected accumulator type {self.accumulator_type}, got {other.accumulator_type}"
)
return Status.success()
@dataclass
@@ -436,23 +575,6 @@ class ElementwiseOperandsMetadata(OperandsMetadata):
return Status.success()
@staticmethod
def from_args(args: ElementwiseArguments) -> Self:
"""
Create a ElementwiseOperandsMetadata from a ElementwiseArguments.
:param args: The ElementwiseArguments to create a ElementwiseOperandsMetadata from.
:type args: ElementwiseArguments
:return: The ElementwiseOperandsMetadata corresponding to the provided ElementwiseArguments.
:rtype: ElementwiseOperandsMetadata
"""
return ElementwiseOperandsMetadata(
A=TensorAttributes.from_tensor(args.A),
B=TensorAttributes.from_tensor(args.B),
out=TensorAttributes.from_tensor(args.out),
)
class EpilogueMetadata:
def __init__(self, epilogue_args: EpilogueArguments):
@@ -535,7 +657,11 @@ class KernelMetadata:
if not (status := supports_or_none(self.design, args.performance, "design")):
return status
if not (status := supports_or_none(self.epilogue, args.epilogue, "epilogue")):
if not (
status := supports_or_none(
self.epilogue, getattr(args, "epilogue", None), "epilogue"
)
):
return status
return Status.success()

View File

@@ -38,9 +38,9 @@ import cutlass.cute as cute
from cutlass_api.arguments import ElementwiseArguments, EpilogueArguments
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
DenseTensorAttributes,
ElementwiseOperandsMetadata,
KernelMetadata,
TensorAttributes,
)
from cutlass_api.providers.cutedsl import CuTeDSLProvider
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
@@ -65,7 +65,9 @@ class ElementwiseAddKernel(CuteDslKernel):
def compile(self, args: ElementwiseArguments, cc: int = None) -> CompiledArtifact:
stream = cutlass.cute.runtime.make_fake_stream()
compiled_kernel = self.cute_compile(self.impl, args.A, args.B, args.out, stream)
compiled_kernel = self.cute_compile(
self.impl, args.A.tensor, args.B.tensor, args.out.tensor, stream
)
return CompiledArtifact(compiled_kernel, self)
def _run(
@@ -77,7 +79,9 @@ class ElementwiseAddKernel(CuteDslKernel):
) -> None:
stream = to_cuda_stream(stream)
compiled_kernel = compiled_artifact.compiled_obj
self.cute_run(compiled_kernel, args.A, args.B, args.out, stream)
self.cute_run(
compiled_kernel, args.A.tensor, args.B.tensor, args.out.tensor, stream
)
@staticmethod
def generate_kernels(
@@ -111,13 +115,13 @@ class ElementwiseAddKernel(CuteDslKernel):
)
operands = ElementwiseOperandsMetadata(
A=TensorAttributes(
A=DenseTensorAttributes(
dtype=dtype, stride=stride_A, divisibility=divisibility
),
B=TensorAttributes(
B=DenseTensorAttributes(
dtype=dtype, stride=stride_B, divisibility=divisibility
),
out=TensorAttributes(
out=DenseTensorAttributes(
dtype=dtype, stride=stride_out, divisibility=divisibility
),
)

View File

@@ -31,4 +31,6 @@
# ruff: noqa: F401
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent_efc
import cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent
import cutlass_api.providers.cutedsl.gemm.sm80_tensorop_gemm
import cutlass_api.providers.cutedsl.gemm.sm100_contiguous_offset_2d3d_dense_gemm

View File

@@ -0,0 +1,602 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import Tuple
from cutlass.cutlass_dsl import (
Boolean,
Integer,
Int32,
min,
extract_mlir_values,
new_from_mlir_values,
dsl_user_op,
const_expr,
)
from cutlass._mlir import ir
import cutlass.cute as cute
class WorkTileInfo:
    """A class to represent information about a work tile.

    :ivar tile_idx: The index of the tile.
    :type tile_idx: cute.Coord
    :ivar is_valid_tile: Whether the tile is valid.
    :type is_valid_tile: Boolean
    """

    def __init__(self, tile_idx: cute.Coord, is_valid_tile: Boolean):
        self._tile_idx = tile_idx
        # Wrap in Boolean so the flag is always a DSL value even when a plain
        # Python bool is passed in.
        self._is_valid_tile = Boolean(is_valid_tile)

    def __extract_mlir_values__(self) -> list[ir.Value]:
        # Serialization order: tile_idx components first, validity flag last.
        # __new_from_mlir_values__ depends on this exact ordering.
        values = extract_mlir_values(self.tile_idx)
        values.extend(extract_mlir_values(self.is_valid_tile))
        return values

    def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo":
        # Expects 3 values for tile_idx plus 1 for the validity flag.
        # (Assumes a rank-3 (m, n, l) tile coordinate, which matches what the
        # schedulers in this file produce.)
        assert len(values) == 4
        new_tile_idx = new_from_mlir_values(self._tile_idx, values[:-1])
        new_is_valid_tile = new_from_mlir_values(self._is_valid_tile, [values[-1]])
        return WorkTileInfo(new_tile_idx, new_is_valid_tile)

    @property
    def is_valid_tile(self) -> Boolean:
        """Check latest tile returned by the scheduler is valid or not. Any scheduling
        requests after all tasks completed will return an invalid tile.

        :return: The validity of the tile.
        :rtype: Boolean
        """
        return self._is_valid_tile

    @property
    def tile_idx(self) -> cute.Coord:
        """
        Get the index of the tile.

        :return: The index of the tile.
        :rtype: cute.Coord
        """
        return self._tile_idx
class PersistentTileSchedulerParams:
    """A class to represent parameters for a persistent tile scheduler.

    This class is designed to manage and compute the layout of clusters and tiles
    in a batched gemm problem.

    :ivar cluster_shape_mn: Shape of the cluster in (m, n) dimensions (K dimension cta count must be 1).
    :type cluster_shape_mn: tuple
    :ivar problem_layout_ncluster_mnl: Layout of the problem in terms of
        number of clusters in (m, n, l) dimensions.
    :type problem_layout_ncluster_mnl: cute.Layout
    """

    @dsl_user_op
    def __init__(
        self,
        problem_shape_ntile_mnl: cute.Shape,
        cluster_shape_mnk: cute.Shape,
        swizzle_size: int = 1,
        raster_along_m: bool = True,
        *,
        loc=None,
        ip=None,
    ):
        """
        Initializes the PersistentTileSchedulerParams with the given parameters.

        :param problem_shape_ntile_mnl: The shape of the problem in terms of
            number of CTA (Cooperative Thread Array) in (m, n, l) dimensions.
        :type problem_shape_ntile_mnl: cute.Shape
        :param cluster_shape_mnk: The shape of the cluster in (m, n) dimensions.
        :type cluster_shape_mnk: cute.Shape
        :param swizzle_size: Swizzling size in the unit of cluster. 1 means no swizzle
        :type swizzle_size: int
        :param raster_along_m: Rasterization order of clusters. Only used when swizzle_size > 1.
            True means along M, false means along N.
        :type raster_along_m: bool
        :raises ValueError: If cluster_shape_k is not 1.
        """
        if cluster_shape_mnk[2] != 1:
            raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}")
        if swizzle_size < 1:
            raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}")
        self.problem_shape_ntile_mnl = problem_shape_ntile_mnl
        # cluster_shape_mnk is kept for reconstruction
        # (__new_from_mlir_values__ re-invokes __init__ with it).
        self._cluster_shape_mnk = cluster_shape_mnk
        self.cluster_shape_mn = cluster_shape_mnk[:2]
        self.swizzle_size = swizzle_size
        self.raster_along_m = raster_along_m
        self._loc = loc
        # By default, we follow m major (col-major) raster order, so make a col-major layout
        self.problem_layout_ncluster_mnl = cute.make_layout(
            cute.ceil_div(
                self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip
            ),
            loc=loc,
            ip=ip,
        )
        # Apply swizzle if swizzle_size > 1
        if swizzle_size > 1:
            # Pad the dimension that gets split up to a multiple of swizzle_size
            # so the (swizzle_size, rest) factorization below is exact.
            problem_shape_ncluster_mnl = cute.round_up(
                self.problem_layout_ncluster_mnl.shape,
                (1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1),
            )
            if raster_along_m:
                # n is split into (swizzle_size, rest); with these strides,
                # consecutive linear indices sweep a swizzle_size-wide band of
                # n before advancing along m.
                self.problem_layout_ncluster_mnl = cute.make_layout(
                    (
                        problem_shape_ncluster_mnl[0],
                        (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size),
                        problem_shape_ncluster_mnl[2],
                    ),
                    stride=(
                        swizzle_size,
                        (1, swizzle_size * problem_shape_ncluster_mnl[0]),
                        problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
                    ),
                    loc=loc,
                    ip=ip,
                )
            else:
                # Mirror of the branch above with the roles of m and n swapped.
                self.problem_layout_ncluster_mnl = cute.make_layout(
                    (
                        (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size),
                        problem_shape_ncluster_mnl[1],
                        problem_shape_ncluster_mnl[2],
                    ),
                    stride=(
                        (1, swizzle_size * problem_shape_ncluster_mnl[1]),
                        swizzle_size,
                        problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
                    ),
                    loc=loc,
                    ip=ip,
                )
        # Create FastDivmod divisors (only when swizzle_size == 1 for correctness)
        # FastDivmod assumes simple col-major layout, incompatible with swizzled layouts
        if swizzle_size == 1:
            problem_layout_size = cute.size(
                self.problem_layout_ncluster_mnl, loc=loc, ip=ip
            )
            cluster_count_m = self.problem_layout_ncluster_mnl.shape[0]
            cluster_count_n = self.problem_layout_ncluster_mnl.shape[1]
            # batch_fdd: Used to map linear_idx to work_unit_id (handles persistent scheduling)
            self.batch_fdd = cute.fast_divmod_create_divisor(
                problem_layout_size, loc=loc, ip=ip
            )
            if raster_along_m:
                cluster_count_major = cluster_count_m
                cluster_count_minor = cluster_count_n
            else:
                cluster_count_major = cluster_count_n
                cluster_count_minor = cluster_count_m
            # cluster_shape_major_fdd: Used to decode work_unit_id to cluster coordinates
            self.cluster_shape_major_fdd = cute.fast_divmod_create_divisor(
                cluster_count_major, loc=loc, ip=ip
            )
            # cluster_shape_minor_fdd: Used for the second level decomposition
            self.cluster_shape_minor_fdd = cute.fast_divmod_create_divisor(
                cluster_count_minor, loc=loc, ip=ip
            )
        else:
            # FastDivmod not applicable with swizzling, set to None
            self.batch_fdd = None
            self.cluster_shape_major_fdd = None
            self.cluster_shape_minor_fdd = None

    def __extract_mlir_values__(self):
        # NOTE: this method stashes bookkeeping on self (_values_pos,
        # _fastdivmod_indices) that __new_from_mlir_values__ reads back, so the
        # two calls must be paired on the same host-side instance.
        values, self._values_pos = [], []
        for obj in [
            self.problem_shape_ntile_mnl,
            self._cluster_shape_mnk,
            self.swizzle_size,
            self.raster_along_m,
        ]:
            obj_values = extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        # Add FastDivmod divisors to MLIR values for Host->Device transfer
        # Only add non-None values to avoid MLIR type errors
        fastdivmod_values = []
        fastdivmod_indices = []  # Track which FastDivmod objects are present
        for i, (fdd_name, fdd_obj) in enumerate(
            [
                ("batch_fdd", self.batch_fdd),
                ("cluster_shape_major_fdd", self.cluster_shape_major_fdd),
                ("cluster_shape_minor_fdd", self.cluster_shape_minor_fdd),
            ]
        ):
            if fdd_obj is not None:
                # Extract MLIR values from FastDivmodDivisor objects
                fdd_values = extract_mlir_values(fdd_obj)
                fastdivmod_values.extend(fdd_values)
                fastdivmod_indices.append(i)
        values += fastdivmod_values
        self._values_pos.append(
            len(fastdivmod_indices)
        )  # Store count of FastDivmod objects, not values
        self._fastdivmod_indices = fastdivmod_indices  # Store for reconstruction
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        values_copy = list(values)  # Make a copy to avoid modifying original
        # Reconstruct original objects from MLIR values
        for obj, n_items in zip(
            [
                self.problem_shape_ntile_mnl,
                self._cluster_shape_mnk,
                self.swizzle_size,
                self.raster_along_m,
            ],
            self._values_pos[:-1],  # Exclude FastDivmod count
        ):
            obj_list.append(new_from_mlir_values(obj, values_copy[:n_items]))
            values_copy = values_copy[n_items:]
        # Create new params object by calling __init__ with reconstructed values
        # This properly recreates layouts and other derived attributes in the device context
        new_params = PersistentTileSchedulerParams(*(tuple(obj_list)), loc=self._loc)
        # Restore FastDivmod divisors from remaining values
        fdd_names = ["batch_fdd", "cluster_shape_major_fdd", "cluster_shape_minor_fdd"]
        if hasattr(self, "_fastdivmod_indices") and len(self._fastdivmod_indices) > 0:
            # Override the FastDivmod divisors created by __init__ with reconstructed ones
            for j, original_index in enumerate(self._fastdivmod_indices):
                fdd_name = fdd_names[original_index]
                # Get the original FastDivmodDivisor object
                original_fdd = getattr(self, fdd_name)
                if original_fdd is not None and j < len(values_copy):
                    # Each FastDivmodDivisor has 1 MLIR value
                    # (NOTE(review): relies on that 1:1 assumption -- confirm
                    # against the FastDivmodDivisor implementation.)
                    reconstructed_fdd = new_from_mlir_values(
                        original_fdd, [values_copy[j]]
                    )
                    setattr(new_params, fdd_name, reconstructed_fdd)
        return new_params

    @dsl_user_op
    def get_grid_shape(
        self, max_active_clusters: Int32, *, loc=None, ip=None
    ) -> Tuple[Integer, Integer, Integer]:
        """
        Computes the grid shape based on the maximum active clusters allowed.

        :param max_active_clusters: The maximum number of active clusters that
            can run in one wave.
        :type max_active_clusters: Int32
        :return: A tuple containing the grid shape in (m, n, persistent_clusters).

            - m: self.cluster_shape_m.
            - n: self.cluster_shape_n.
            - persistent_clusters: Number of persistent clusters that can run.
        """
        # Total ctas in problem size.
        # zip truncates to the (m, n) modes; the l mode is appended unscaled.
        num_ctas_mnl = tuple(
            cute.size(x) * y
            for x, y in zip(
                self.problem_layout_ncluster_mnl.shape, self.cluster_shape_mn
            )
        ) + (self.problem_layout_ncluster_mnl.shape[2],)
        num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
        num_ctas_per_cluster = cute.size(self.cluster_shape_mn, loc=loc, ip=ip)
        # Total ctas that can run in one wave
        num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster
        # Never launch more clusters than there is work for.
        num_persistent_ctas = min(num_ctas_in_problem, num_ctas_per_wave)
        num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster
        return (*self.cluster_shape_mn, num_persistent_clusters)
class StaticPersistentTileScheduler:
    """A scheduler for static persistent tile execution in CUTLASS/CuTe kernels.

    Each persistent cluster starts at the work unit equal to its linear cluster
    id (bidz) and strides through the work units by the total number of
    persistent clusters until the linear index runs past the problem.

    :ivar params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl
    :type params: PersistentTileSchedulerParams
    :ivar num_persistent_clusters: Number of persistent clusters that can be launched
    :type num_persistent_clusters: Int32
    :ivar cta_id_in_cluster: ID of the CTA within its cluster
    :type cta_id_in_cluster: cute.Coord
    :ivar _num_tiles_executed: Counter for executed tiles
    :type _num_tiles_executed: Int32
    :ivar _current_work_linear_idx: Current cluster index
    :type _current_work_linear_idx: Int32
    """

    def __init__(
        self,
        params: PersistentTileSchedulerParams,
        num_persistent_clusters: Int32,
        current_work_linear_idx: Int32,
        cta_id_in_cluster: cute.Coord,
        num_tiles_executed: Int32,
    ):
        """
        Initializes the StaticPersistentTileScheduler with the given parameters.

        :param params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl.
        :type params: PersistentTileSchedulerParams
        :param num_persistent_clusters: Number of persistent clusters that can be launched.
        :type num_persistent_clusters: Int32
        :param current_work_linear_idx: Current cluster index.
        :type current_work_linear_idx: Int32
        :param cta_id_in_cluster: ID of the CTA within its cluster.
        :type cta_id_in_cluster: cute.Coord
        :param num_tiles_executed: Counter for executed tiles.
        :type num_tiles_executed: Int32
        """
        self.params = params
        self.num_persistent_clusters = num_persistent_clusters
        self._current_work_linear_idx = current_work_linear_idx
        self.cta_id_in_cluster = cta_id_in_cluster
        self._num_tiles_executed = num_tiles_executed

    def __extract_mlir_values__(self) -> list[ir.Value]:
        # Value layout (consumed by __new_from_mlir_values__):
        # [0] num_persistent_clusters, [1] current_work_linear_idx,
        # [2:5] cta_id_in_cluster (3 components), [5] num_tiles_executed,
        # [6:] params (including its FastDivmod divisors).
        values = extract_mlir_values(self.num_persistent_clusters)
        values.extend(extract_mlir_values(self._current_work_linear_idx))
        values.extend(extract_mlir_values(self.cta_id_in_cluster))
        values.extend(extract_mlir_values(self._num_tiles_executed))
        # CRITICAL: Also extract FastDivmod divisors from params
        values.extend(extract_mlir_values(self.params))
        return values

    def __new_from_mlir_values__(
        self, values: list[ir.Value]
    ) -> "StaticPersistentTileScheduler":
        # At least the 6 scalar values above must be present; the remainder
        # belongs to params.
        assert len(values) >= 6
        new_num_persistent_clusters = new_from_mlir_values(
            self.num_persistent_clusters, [values[0]]
        )
        new_current_work_linear_idx = new_from_mlir_values(
            self._current_work_linear_idx, [values[1]]
        )
        new_cta_id_in_cluster = new_from_mlir_values(
            self.cta_id_in_cluster, values[2:5]
        )
        new_num_tiles_executed = new_from_mlir_values(
            self._num_tiles_executed, [values[5]]
        )
        # Reconstruct params with FastDivmod divisors
        params_values = values[6:]  # Remaining values are from params
        new_params = new_from_mlir_values(self.params, params_values)
        return StaticPersistentTileScheduler(
            new_params,  # Use reconstructed params with FastDivmod divisors
            new_num_persistent_clusters,
            new_current_work_linear_idx,
            new_cta_id_in_cluster,
            new_num_tiles_executed,
        )

    @staticmethod
    @dsl_user_op
    def create(
        params: PersistentTileSchedulerParams,
        block_idx: Tuple[Integer, Integer, Integer],
        grid_dim: Tuple[Integer, Integer, Integer],
        *,
        loc=None,
        ip=None,
    ):
        """Initialize the static persistent tile scheduler.

        :param params: Parameters for the persistent
            tile scheduler.
        :type params: PersistentTileSchedulerParams
        :param block_idx: The 3d block index in the format (bidx, bidy, bidz).
        :type block_idx: Tuple[Integer, Integer, Integer]
        :param grid_dim: The 3d grid dimensions for kernel launch.
        :type grid_dim: Tuple[Integer, Integer, Integer]
        :return: A StaticPersistentTileScheduler object.
        :rtype: StaticPersistentTileScheduler
        """
        # Calculate the number of persistent clusters by dividing the total grid size
        # by the number of CTAs per cluster
        num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size(
            params.cluster_shape_mn, loc=loc, ip=ip
        )
        bidx, bidy, bidz = block_idx
        # Initialize workload index equals to the cluster index in the grid
        current_work_linear_idx = Int32(bidz)
        # CTA id in the cluster
        # (assumes the grid's x/y dims equal the cluster's m/n extents, per
        # get_grid_shape above)
        cta_id_in_cluster = (
            Int32(bidx % params.cluster_shape_mn[0]),
            Int32(bidy % params.cluster_shape_mn[1]),
            Int32(0),
        )
        # Initialize number of tiles executed to zero
        num_tiles_executed = Int32(0)
        return StaticPersistentTileScheduler(
            params,
            num_persistent_clusters,
            current_work_linear_idx,
            cta_id_in_cluster,
            num_tiles_executed,
        )

    # called by host
    @staticmethod
    def get_grid_shape(
        params: PersistentTileSchedulerParams,
        max_active_clusters: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Integer, Integer, Integer]:
        """Calculates the grid shape to be launched on GPU using problem shape,
        threadblock shape, and active cluster size.

        :param params: Parameters for grid shape calculation.
        :type params: PersistentTileSchedulerParams
        :param max_active_clusters: Maximum active clusters allowed.
        :type max_active_clusters: Int32
        :return: The calculated 3d grid shape.
        :rtype: Tuple[Integer, Integer, Integer]
        """
        return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip)

    # private method
    def _get_current_work_for_linear_idx(
        self, current_work_linear_idx: Int32, *, loc=None, ip=None
    ) -> WorkTileInfo:
        """Compute current tile coord given current_work_linear_idx and cta_id_in_cluster.

        :param current_work_linear_idx: The linear index of the current work.
        :type current_work_linear_idx: Int32
        :return: An object containing information about the current tile coordinates
            and validity status.
        :rtype: WorkTileInfo
        """
        is_valid = current_work_linear_idx < cute.size(
            self.params.problem_layout_ncluster_mnl, loc=loc, ip=ip
        )
        # Choose coordinate calculation method based on swizzle configuration
        if self.params.swizzle_size == 1:
            # Use FastDivmod optimization for non-swizzled layouts
            cur_cluster_coord = self._get_cluster_work_idx_with_fastdivmod(
                current_work_linear_idx, loc=loc, ip=ip
            )
        else:
            # Use get_flat_coord for swizzled layouts (FastDivmod doesn't support them)
            cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_flat_coord(
                current_work_linear_idx, loc=loc, ip=ip
            )
        # cur_tile_coord is a tuple of i32 values:
        # tile = cluster_coord * cluster_shape + cta_offset (l mode scales by 1)
        cur_tile_coord = tuple(
            Int32(x) * Int32(z) + Int32(y)
            for x, y, z in zip(
                cur_cluster_coord,
                self.cta_id_in_cluster,
                (*self.params.cluster_shape_mn, Int32(1)),
            )
        )
        return WorkTileInfo(cur_tile_coord, is_valid)

    def _get_cluster_work_idx_with_fastdivmod(
        self, current_work_linear_idx: Int32, *, loc=None, ip=None
    ) -> Tuple[Int32, Int32, Int32]:
        """
        FastDivmod optimized CLUSTER coordinate calculation.

        CRITICAL: This should mimic problem_layout_ncluster_mnl.get_hier_coord()
        which returns CLUSTER coordinates, not tile coordinates!

        :param current_work_linear_idx: Linear index in the work space
        :type current_work_linear_idx: Int32
        :return: Cluster coordinates (m, n, l)
        :rtype: Tuple[Int32, Int32, Int32]
        """
        # Step 1: Handle persistent scheduling - map linear_idx to work_unit_id
        # (the quotient work_iteration is intentionally unused here)
        work_iteration, work_unit_id = divmod(
            current_work_linear_idx, self.params.batch_fdd
        )
        # Step 2: Decode work_unit_id using FastDivmod objects
        # The layout structure is: problem_layout_ncluster_mnl has shape (cluster_count_m, cluster_count_n, batch_count)
        # work_unit_id needs to be decomposed into (batch_l, cluster_minor, cluster_major) in little-endian order
        # First, get cluster_major using cluster_shape_major_fdd
        cluster_minor_batch, cluster_major = divmod(
            work_unit_id, self.params.cluster_shape_major_fdd
        )
        # Then decode cluster_minor_batch to get cluster_minor and batch_l using FastDivmod
        batch_l, cluster_minor = divmod(
            cluster_minor_batch, self.params.cluster_shape_minor_fdd
        )
        # Map major/minor back to (m, n) according to the raster order.
        if self.params.raster_along_m:
            cluster_m = cluster_major
            cluster_n = cluster_minor
        else:
            cluster_m = cluster_minor
            cluster_n = cluster_major
        return (cluster_m, cluster_n, batch_l)

    @dsl_user_op
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Return the work tile for the scheduler's current linear index."""
        return self._get_current_work_for_linear_idx(
            self._current_work_linear_idx, loc=loc, ip=ip
        )

    @dsl_user_op
    def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Return the first work tile assigned to this cluster."""
        return self.get_current_work(loc=loc, ip=ip)

    @dsl_user_op
    def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None):
        """Advance by ``advance_count`` scheduling steps; each step strides the
        linear index by the number of persistent clusters in flight.

        NOTE(review): the executed-tile counter increments by exactly 1 here
        regardless of ``advance_count`` -- confirm this is intended.
        """
        self._current_work_linear_idx += Int32(advance_count) * Int32(
            self.num_persistent_clusters
        )
        self._num_tiles_executed += Int32(1)

    @property
    def num_tiles_executed(self) -> Int32:
        """Number of tiles this scheduler instance has advanced past."""
        return self._num_tiles_executed

View File

@@ -0,0 +1,347 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import itertools
from collections.abc import Callable, Generator
import cutlass
import cutlass.utils as utils
import cutlass.cute as cute
from cutlass.cute.runtime import make_ptr
import cutlass_api
from cutlass_api.arguments import (
EpilogueArguments,
GroupedGemmArguments,
)
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
DesignMetadata,
EpilogueMetadata,
GroupedGemmOperandsMetadata,
KernelMetadata,
Sm100DesignMetadata,
)
from cutlass_api.providers.cutedsl import CuTeDSLProvider
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
from cutlass_api.providers.cutedsl.utils import get_max_active_clusters
from cutlass_api.status import Status
from cutlass_api.utils import (
ceil_div,
round_up,
strides_to_layout_string,
to_cuda_stream,
tuple_to_string,
)
from .implementations.sm100_contiguous_offset_2d3d_dense_gemm_impl import (
ContiguousOffset2D3DGemmDenseKernelImpl,
)
from .utils import (
tensor_rank_2_or_3,
tensor_reduction_mode_matches,
tensor_output_shape_matches,
)
@CuTeDSLProvider.register
class ContiguousOffset2D3DGemmDenseKernel(CuteDslKernel):
    """Contiguous-offset 2D-3D grouped dense GEMM kernel:
    ``(TotalM, K) @ (G, K, N) -> (TotalM, N)``.

    Rows of A (and of out) are partitioned contiguously among the ``G`` groups
    by the 1D ``offsets`` tensor; group ``g`` multiplies its row slice of A by
    ``B[g]``. (NOTE(review): exact offsets semantics -- the error message in
    ``_supports`` refers to an "End" offset mode -- should be confirmed against
    GroupedGemmArguments.)
    """

    def __init__(self, metadata: KernelMetadata):
        """Build the kernel wrapper and its CuTe DSL implementation from the
        metadata's accumulator dtype, 2-CTA MMA flag, MMA tiler, and cluster shape."""
        super().__init__(metadata)
        mma_tiler_mn = (metadata.design.tile_shape[0], metadata.design.tile_shape[1])
        cluster_shape_mn = (
            metadata.design.cluster_shape[0],
            metadata.design.cluster_shape[1],
        )
        self.impl = ContiguousOffset2D3DGemmDenseKernelImpl(
            metadata.operands.accumulator_type,
            metadata.design.use_2cta_mma,
            mma_tiler_mn,
            cluster_shape_mn,
        )

    def _supports(self, args: GroupedGemmArguments) -> Status:
        """Validate argument shapes for this kernel.

        :return: ``Status.success()`` when the shapes are consistent, otherwise
            a ``Status.fail`` carrying a diagnostic message.
        """
        if not (status := tensor_reduction_mode_matches(args)):
            return status
        if not (status := tensor_output_shape_matches(args)):
            return status
        # A and out may be rank 2, or rank 3 with a leading batch of exactly 1.
        if len(args.A.shape) == 3 and args.A.shape[0] != 1:
            return Status.fail(
                "Operand A must have batch size 1."
            )
        if len(args.out.shape) == 3 and args.out.shape[0] != 1:
            return Status.fail("out must have batch size 1.")
        # ValidM: total number of rows across all groups (TotalM).
        ValidM, N = args.out.shape[-2:]
        K = args.A.shape[-1]
        group_count = args.B.shape[0] if len(args.B.shape) == 3 else 1
        # (args.A.shape[-1] == K holds by construction; the check is kept for symmetry.)
        if args.A.shape[-2] != ValidM or args.A.shape[-1] != K:
            return Status.fail(f"A must have shape ({ValidM}, {K}). Got {args.A.shape}")
        if args.B.shape != (group_count, K, N):
            return Status.fail(
                f"B must have shape ({group_count}, {K}, {N}). Got {args.B.shape}"
            )
        if args.out.shape[-2] != ValidM or args.out.shape[-1] != N:
            return Status.fail(
                f"out must have shape ({ValidM}, {N}). Got {args.out.shape}"
            )
        #
        # Check offsets
        #
        if len(args.offsets.shape) != 1:
            return Status.fail(f"offsets must be a 1D tensor. Got {args.offsets.shape}")
        if args.offsets.numel() != group_count:
            return Status.fail(
                f"offsets must have {group_count} elements when offset mode is End and group count is {group_count}."
            )
        return Status.success()

    def compile(self, args: GroupedGemmArguments, cc: int | None = None) -> CompiledArtifact:
        """Compile the kernel against the argument tensors' dtypes and layouts.

        A fake stream is used: compilation bakes in types/layouts/tiling, not
        the stream or the problem data.

        :param cc: Target compute capability. (NOTE(review): accepted but not
            read in this method.)
        :return: The compiled kernel wrapped in a ``CompiledArtifact``.
        """
        stream = cutlass.cute.runtime.make_fake_stream()
        max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)
        compiled_kernel = self.cute_compile(
            self.impl,
            args.A.tensor,
            args.B.tensor,
            args.out.tensor,
            args.offsets.tensor,
            max_active_clusters,
            stream,
        )
        return CompiledArtifact(compiled_kernel, self)

    def _run(
        self,
        args: GroupedGemmArguments,
        compiled_artifact: CompiledArtifact,
        stream,
        workspace=None,
    ) -> None:
        """Launch the precompiled kernel on ``stream``.

        ``workspace`` is accepted for interface compatibility but unused here.
        """
        stream = to_cuda_stream(stream)
        compiled_gemm = compiled_artifact.compiled_obj
        self.cute_run(
            compiled_gemm,
            args.A.tensor,
            args.B.tensor,
            args.out.tensor,
            args.offsets.tensor,
            stream,
        )

    @staticmethod
    def _valid_operands(operands: GroupedGemmOperandsMetadata) -> bool:
        """Return True when the operand dtypes and strides match this kernel's
        fixed layout requirements (F32 accumulation, unit strides as below)."""
        if operands.accumulator_type != cutlass.Float32:
            return False
        if operands.A.stride[-2:].index(1) != 1:
            # A must be k-major
            return False
        if operands.B.stride[-2:].index(1) != 0:
            # B must be n-major
            # (NOTE(review): this requires the unit stride in the K position of
            # B's trailing (K, N) dims -- confirm the "n-major" label.)
            return False
        if operands.out.stride[-2:].index(1) != 1:
            # out must be n-major
            return False
        return True

    @staticmethod
    def _valid_design_metadata(design: DesignMetadata) -> bool:
        """Return True when the (2-CTA MMA, tiler, cluster) combination is
        supported by the underlying implementation."""
        use_2cta_mma = design.use_2cta_mma
        mma_tiler_mn = design.tile_shape[:2]
        cluster_shape_mn = design.cluster_shape[:2]
        impl = ContiguousOffset2D3DGemmDenseKernelImpl
        return impl.is_valid_mma_tiler_and_cluster_shape(
            use_2cta_mma, mma_tiler_mn, cluster_shape_mn
        )

    @staticmethod
    def _valid_epilogue_metadata(epilogue: EpilogueMetadata | None) -> bool:
        # This kernel supports no fused epilogue.
        return epilogue is None

    @staticmethod
    def _valid_metadata(metadata: KernelMetadata) -> bool:
        """Return True when operands, design, and epilogue metadata are all
        supported by this kernel."""
        if not ContiguousOffset2D3DGemmDenseKernel._valid_operands(metadata.operands):
            return False
        if not ContiguousOffset2D3DGemmDenseKernel._valid_design_metadata(
            metadata.design
        ):
            return False
        if not ContiguousOffset2D3DGemmDenseKernel._valid_epilogue_metadata(
            metadata.epilogue
        ):
            return False
        return True

    @staticmethod
    def _metadata_operand_combinations() -> Generator[
        GroupedGemmOperandsMetadata, None, None
    ]:
        """
        Generator that yields all valid GroupedGemmOperandsMetadata combinations
        based on the validation rules in _valid_operands.
        """
        # Supported A/B data types (must be the same)
        ab_dtypes = [
            cutlass.Float8E5M2,
            cutlass.Float8E4M3FN,
            cutlass.Float16,
            cutlass.BFloat16,
        ]
        out_dtypes = [
            cutlass.Float32,
            cutlass.Float16,
            cutlass.BFloat16,
        ]
        row_major_stride = (0, 0, 1)
        col_major_stride = (0, 1, 0)
        # Alignment assumptions used below to derive per-tensor element divisibility.
        alignment_bytes = 16
        offsets_alignment_bytes = 4
        stride_A = row_major_stride
        stride_B = col_major_stride
        stride_out = row_major_stride
        acc_dtype = cutlass.Float32
        Impl = ContiguousOffset2D3DGemmDenseKernelImpl
        for ab_dtype, out_dtype in itertools.product(ab_dtypes, out_dtypes):
            if not Impl.is_valid_dtypes(ab_dtype, acc_dtype, out_dtype):
                continue
            # Divisibility = number of elements in one aligned vector.
            ab_divisibility = alignment_bytes * 8 // ab_dtype.width
            out_divisibility = alignment_bytes * 8 // out_dtype.width
            offsets_divisibility = offsets_alignment_bytes * 8 // cutlass.Int32.width
            operands = GroupedGemmOperandsMetadata(
                A=cutlass_api.metadata.DenseTensorAttributes(
                    dtype=ab_dtype,
                    stride=stride_A,
                    divisibility=ab_divisibility,
                ),
                B=cutlass_api.metadata.DenseTensorAttributes(
                    dtype=ab_dtype,
                    stride=stride_B,
                    divisibility=ab_divisibility,
                ),
                out=cutlass_api.metadata.DenseTensorAttributes(
                    dtype=out_dtype,
                    stride=stride_out,
                    divisibility=out_divisibility,
                ),
                offsets=cutlass_api.metadata.DenseTensorAttributes(
                    dtype=cutlass.Int32,
                    stride=(1,),
                    divisibility=offsets_divisibility,
                ),
                accumulator_type=acc_dtype,
            )
            yield operands

    @staticmethod
    def generate_kernels(
        metadata_filter: Callable[[KernelMetadata], bool],
        epilogue_args: EpilogueArguments | None = None,
        cc: int | None = None,
    ) -> list["ContiguousOffset2D3DGemmDenseKernel"]:
        """
        Returns a list of all possible configurations of ContiguousOffset2D3DGemmDenseKernel that
        adhere to constraints passed in under kwargs.
        """
        # Only SM100-family compute capabilities are supported.
        if cc is not None and cc not in [100, 101, 103]:
            return []
        design_params = {
            "use_2cta_mma": [True, False],
            "tile_shape": [
                (64, 128, 128),
                (128, 128, 128),
                (256, 128, 128),
                (256, 256, 128),
            ],
            "cluster_shape": [(M, N, 1) for M in [1, 2, 4] for N in [1, 2, 4]],
            "use_tma_store": [True],
        }
        # No fused epilogue support: bail out if epilogue arguments were given.
        if epilogue_args is not None:
            return []
        # NOTE(review): redundant -- itertools is already imported at module scope.
        from itertools import product
        # Get the list of tunable parameter names and their possible values
        param_names = list(design_params.keys())
        param_values = [design_params[name] for name in param_names]
        kernel_list = []
        for (
            operands
        ) in ContiguousOffset2D3DGemmDenseKernel._metadata_operand_combinations():
            for values in product(*param_values):
                design = Sm100DesignMetadata(**dict(zip(param_names, values)))
                kernel_name = "cutedsl.ContiguousOffset2D3DGemmDenseKernel_sm100_{layout}_A{A}_B{B}_out{out}_acc{acc}_{num_cta}cta_cluster{cluster}_tile{tile}{_tma_store}".format(
                    layout=strides_to_layout_string(
                        operands.A.stride, operands.B.stride, operands.out.stride
                    ),
                    A=operands.A.dtype,
                    B=operands.B.dtype,
                    out=operands.out.dtype,
                    acc=operands.accumulator_type,
                    num_cta=("2" if design.use_2cta_mma else "1"),
                    cluster=tuple_to_string(design.cluster_shape),
                    tile=tuple_to_string(design.tile_shape),
                    _tma_store="_tma_store" if design.use_tma_store else "",
                )
                metadata = KernelMetadata(
                    operands=operands,
                    design=design,
                    kernel_name=kernel_name,
                    kernel_class=ContiguousOffset2D3DGemmDenseKernel,
                    min_cc=100,
                    epilogue=None,
                )
                metadata_valid = ContiguousOffset2D3DGemmDenseKernel._valid_metadata(
                    metadata
                )
                if metadata_valid and metadata_filter(metadata):
                    kernel_list.append(ContiguousOffset2D3DGemmDenseKernel(metadata))
        return kernel_list

View File

@@ -0,0 +1,418 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import itertools
from collections.abc import Callable, Generator
import cutlass
import cutlass.utils as utils
import cutlass.cute as cute
from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
from cutlass.cute.runtime import make_ptr
import cutlass_api
from cutlass_api.arguments import (
EpilogueArguments,
GemmArguments,
)
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
DenseTensorAttributes,
DesignMetadata,
EpilogueMetadata,
GemmOperandsMetadata,
KernelMetadata,
Sm100DesignMetadata,
ScaledTensorAttributes,
)
from cutlass_api.library import ScaleMode, ScaleSwizzleMode
from cutlass_api.providers.cutedsl import CuTeDSLProvider
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
from cutlass_api.providers.cutedsl.utils import get_max_active_clusters
from cutlass_api.status import Status
from cutlass_api.utils import ceil_div, round_up, strides_to_layout_string, to_cuda_stream, tuple_to_string
from .implementations.sm100_dense_blockscaled_static_persistent_impl import PersistentDenseBlockScaledGemmKernelImpl
@CuTeDSLProvider.register
class PersistentDenseBlockScaledGemmKernel(CuteDslKernel):
    """Persistent, block-scaled dense GEMM kernel for SM100-class devices.

    Wraps :class:`PersistentDenseBlockScaledGemmKernelImpl`. Operands A and B
    each carry a per-block scale-factor tensor (``args.A.scale`` /
    ``args.B.scale``) whose underlying data is expected to be in 32-4-4
    swizzled layout. No epilogue fusion is supported (see
    ``_valid_epilogue_metadata``).
    """

    def __init__(self, metadata: KernelMetadata):
        """Construct the kernel implementation from compile-time metadata.

        :param metadata: kernel metadata; ``metadata.operands.A.mode`` encodes
            the scale mode, whose last entry is the scale-factor vector size.
        """
        super().__init__(metadata)
        # Last entry of the scale mode is the scale-factor vector size
        # (number of elements along K sharing one scale factor).
        self.sf_vec_size = metadata.operands.A.mode[-1]
        # Only the M/N extents of the tile and cluster shapes are passed
        # down to the implementation.
        mma_tiler_mn = (metadata.design.tile_shape[0], metadata.design.tile_shape[1])
        cluster_shape_mn = (
            metadata.design.cluster_shape[0],
            metadata.design.cluster_shape[1],
        )
        self.impl = PersistentDenseBlockScaledGemmKernelImpl(
            self.sf_vec_size,
            mma_tiler_mn,
            cluster_shape_mn,
        )

    def _supports(self, args: GemmArguments) -> Status:
        """Validate the runtime element counts of the scale-factor tensors.

        :param args: runtime GEMM arguments, including scaled A/B operands.
        :return: failing ``Status`` when SFA/SFB do not contain the number of
            elements implied by the problem shape and the scale-factor vector
            size; success ``Status`` otherwise.
        """
        M, N = args.out.shape[-2:]
        K = args.A.shape[-1]
        # Batch mode L is present only for rank-3 inputs.
        L = args.A.shape[0] if len(args.A.shape) == 3 else 1
        # To support 32-4-4 swizzling, the scale factor tensor must be padded to a multiple of 4
        expected_sf_k = round_up(ceil_div(K, self.sf_vec_size), 4)
        # Shapes of scale factor tensors are not enforced. It is expected that the
        # data underlying the passed in scale factor tensor is in 32-4-4 swizzled format.
        # Thus, we check only that the number of elements in the scale factor tensor is correct.
        expected_sfa_elements = L * M * expected_sf_k
        expected_sfb_elements = L * N * expected_sf_k
        if args.A.scale.numel() != expected_sfa_elements:
            return Status.fail(
                f"Scale factor A for tensor A of shape {args.A.shape} must have "
                f"{expected_sfa_elements} elements. Scale factor A is of shape {args.A.scale.shape} "
                f"and has {args.A.scale.numel()} elements."
            )
        if args.B.scale.numel() != expected_sfb_elements:
            return Status.fail(
                f"Scale factor B for tensor B of shape {args.B.shape} must have "
                f"{expected_sfb_elements} elements. Scale factor B is of shape {args.B.scale.shape} "
                f"and has {args.B.scale.numel()} elements."
            )
        return Status.success()

    def _construct_pointers(self, args: GemmArguments, nullptr: bool = False) -> tuple[cute.Pointer, cute.Pointer, cute.Pointer, cute.Pointer, cute.Pointer]:
        """Build cute global-memory pointers for A, B, out, SFA, and SFB.

        :param args: runtime GEMM arguments.
        :param nullptr: when True, all pointers use address 0; used at compile
            time where only dtype and assumed alignment matter.
        :return: (a_ptr, b_ptr, out_ptr, sfa_ptr, sfb_ptr)
        """
        # NOTE(review): `x.data_ptr` is accessed without parentheses here,
        # while the underlying wrapper stores `_data_ptr` — presumably this is
        # a property on the wrapped tensor type; confirm it is not a bound method.
        def ptr(x): return 0 if nullptr else x.data_ptr
        gmem = cute.AddressSpace.gmem
        a_ptr = make_ptr(args.A.element_type, ptr(args.A), gmem, assumed_align=16)
        b_ptr = make_ptr(args.B.element_type, ptr(args.B), gmem, assumed_align=16)
        out_ptr = make_ptr(args.out.element_type, ptr(args.out), gmem, assumed_align=16)
        # Scale-factor tensors assume a wider (32-byte) alignment than A/B/out.
        sfa_ptr = make_ptr(
            args.A.scale.element_type, ptr(args.A.scale), gmem, assumed_align=32
        )
        sfb_ptr = make_ptr(
            args.B.scale.element_type, ptr(args.B.scale), gmem, assumed_align=32
        )
        return a_ptr, b_ptr, out_ptr, sfa_ptr, sfb_ptr

    @staticmethod
    def _major_modes(args: GemmArguments | GemmOperandsMetadata) -> tuple[tuple[OperandMajorMode, str], tuple[OperandMajorMode, str], tuple[utils.LayoutEnum, str]]:
        """Derive A/B major modes and the output layout from operand strides.

        A stride of 1 identifies the contiguous (major) mode among the last
        two dimensions. Each result is an (enum, mode-letter) pair; the
        letter form ("k"/"m"/"n") is what the impl's layout checks consume.

        :param args: runtime arguments or compile-time operand metadata
            (both expose ``.stride`` on A, B, and out).
        :return: ((A enum, A letter), (B enum, B letter), (out enum, out letter))
        """
        # A, B, and out can be of rank 2 or 3. Extract the final two dimensions
        # to determine the major mode
        if args.A.stride[-2:].index(1) == 1:
            a_major_mode = (OperandMajorMode.K, "k")
        else:
            a_major_mode = (OperandMajorMode.MN, "m")
        if args.B.stride[-2:].index(1) == 0:
            b_major_mode = (OperandMajorMode.K, "k")
        else:
            b_major_mode = (OperandMajorMode.MN, "n")
        if args.out.stride[-2:].index(1) == 1:
            out_layout = (utils.LayoutEnum.ROW_MAJOR, "n")
        else:
            out_layout = (utils.LayoutEnum.COL_MAJOR, "m")
        return a_major_mode, b_major_mode, out_layout

    def compile(self, args: GemmArguments, cc: int | None = None) -> CompiledArtifact:
        """Compile the kernel ahead of execution.

        Compilation uses a fake stream, null device pointers, and a zeroed
        dynamic problem shape: only dtypes, layouts, and assumed alignments
        influence code generation. The concrete (m, n, k, l) shape is
        supplied later in ``_run``.

        :param args: GEMM arguments providing dtypes and layouts.
        :param cc: target compute capability; unused here — TODO confirm it
            is consumed by the base class / provider machinery.
        :return: a ``CompiledArtifact`` wrapping the compiled kernel.
        """
        stream = cutlass.cute.runtime.make_fake_stream()
        max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)
        a_ptr, b_ptr, out_ptr, sfa_ptr, sfb_ptr = self._construct_pointers(args, nullptr=True)
        fake_problem_shape = (cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0))
        # No epilogue fusion: the epilogue is the identity.
        epilogue_op = lambda x: x
        (a_major_mode, _), (b_major_mode, _), (out_layout, _) = PersistentDenseBlockScaledGemmKernel._major_modes(args)
        compiled_kernel = self.cute_compile(
            self.impl,
            a_ptr,
            b_ptr,
            sfa_ptr,
            sfb_ptr,
            out_ptr,
            (a_major_mode, b_major_mode, out_layout),
            fake_problem_shape,
            max_active_clusters,
            stream,
            epilogue_op,
        )
        return CompiledArtifact(compiled_kernel, self)

    def _run(
        self,
        args: GemmArguments,
        compiled_artifact: CompiledArtifact,
        stream,
        workspace=None,
    ) -> None:
        """Launch the compiled GEMM on the given stream.

        :param args: runtime GEMM arguments supplying tensors and shapes.
        :param compiled_artifact: artifact produced by :meth:`compile`.
        :param stream: stream-like object convertible via ``to_cuda_stream``.
        :param workspace: unused by this kernel.
        """
        stream = to_cuda_stream(stream)
        compiled_gemm = compiled_artifact.compiled_obj
        m, n = args.out.shape[-2:]
        k = args.A.shape[-1]
        # Batch mode l defaults to 1 for rank-2 inputs.
        l = args.A.shape[0] if len(args.A.shape) == 3 else 1
        a_ptr, b_ptr, out_ptr, sfa_ptr, sfb_ptr = self._construct_pointers(args)
        self.cute_run(compiled_gemm, a_ptr, b_ptr, sfa_ptr, sfb_ptr, out_ptr, (m, n, k, l), stream)

    @staticmethod
    def _valid_operands(operands: GemmOperandsMetadata, sf_vec_size: int) -> bool:
        """Return True when the operand metadata is supported by the impl.

        :param operands: compile-time operand metadata.
        :param sf_vec_size: scale-factor vector size derived from the scale mode.
        """
        if not isinstance(operands, GemmOperandsMetadata):
            return False
        # In current version, A and B tensor must have the same data type
        # i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported.
        # The same holds for scale A and scale B.
        if (
            operands.A.dtype != operands.B.dtype
            or operands.A.scale.dtype != operands.B.scale.dtype
        ):
            return False
        # Only FP32 accumulation is supported.
        if operands.accumulator_type != cutlass.Float32:
            return False
        ab_dtype = operands.A.dtype
        sf_dtype = operands.A.scale.dtype
        out_dtype = operands.out.dtype
        impl = PersistentDenseBlockScaledGemmKernelImpl
        if not impl.is_valid_dtypes_and_scale_factor_vec_size(ab_dtype, sf_dtype, sf_vec_size, out_dtype):
            return False
        (_, a_major), (_, b_major), (_, out_major) = PersistentDenseBlockScaledGemmKernel._major_modes(operands)
        if not impl.is_valid_layouts(ab_dtype, out_dtype, a_major, b_major, out_major):
            return False
        return True

    @staticmethod
    def _valid_design_metadata(design: DesignMetadata) -> bool:
        """Return True when the design (tile/cluster shape) is supported."""
        if not isinstance(design, Sm100DesignMetadata):
            return False
        mma_tiler_mn = design.tile_shape[:2]
        cluster_shape_mn = design.cluster_shape[:2]
        impl = PersistentDenseBlockScaledGemmKernelImpl
        return impl.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn)

    @staticmethod
    def _valid_epilogue_metadata(epilogue: EpilogueMetadata | None) -> bool:
        """This kernel supports no epilogue fusion; only ``None`` is valid."""
        return epilogue is None

    @staticmethod
    def _valid_metadata(metadata: KernelMetadata) -> bool:
        """Validate operands, design, and epilogue metadata as a whole."""
        scale_vec = metadata.operands.A.mode
        # Make sure scale vector is in the form (1, 1, ..., 1, sf_vec_size)
        if len(scale_vec) > 1:
            for i in range(0, len(scale_vec) - 1):
                if scale_vec[i] != 1:
                    return False
        sf_vec_size = scale_vec[-1]
        if not PersistentDenseBlockScaledGemmKernel._valid_operands(metadata.operands, sf_vec_size):
            return False
        if not PersistentDenseBlockScaledGemmKernel._valid_design_metadata(metadata.design):
            return False
        if not PersistentDenseBlockScaledGemmKernel._valid_epilogue_metadata(metadata.epilogue):
            return False
        return True

    @staticmethod
    def _metadata_operand_combinations() -> Generator[GemmOperandsMetadata, None, None]:
        """
        Generator that yields all valid ``GemmOperandsMetadata`` combinations
        based on the validation rules in ``_valid_operands``.

        Each yielded value is a single ``GemmOperandsMetadata`` (the scale
        mode — and thus the scale-factor vector size — is embedded in the
        operand attributes).
        """
        # Supported A/B data types (must be the same)
        ab_dtypes = [
            cutlass.Float8E5M2,
            cutlass.Float8E4M3FN,
        ]
        out_dtypes = [
            cutlass.Float32,
            cutlass.Float16,
            cutlass.BFloat16,
            cutlass.Float8E4M3FN,
            cutlass.Float8E5M2,
        ]
        sf_dtypes = [
            cutlass.Float8E8M0FNU,
            cutlass.Float8E4M3FN,
        ]
        scale_modes = [ScaleMode.Blockwise1x16, ScaleMode.Blockwise1x32]
        row_major_stride = (0, 0, 1)
        col_major_stride = (0, 1, 0)
        alignment_bytes = 16

        def major_str_a(major_tuple: tuple[int, int, int]) -> str:
            return "k" if major_tuple == row_major_stride else "m"

        def major_str_b(major_tuple: tuple[int, int, int]) -> str:
            return "n" if major_tuple == row_major_stride else "k"

        def major_str_out(major_tuple: tuple[int, int, int]) -> str:
            return "n" if major_tuple == row_major_stride else "m"

        Impl = PersistentDenseBlockScaledGemmKernelImpl
        for ab_dtype, sf_dtype, scale_mode, out_dtype in itertools.product(ab_dtypes, sf_dtypes, scale_modes, out_dtypes):
            # Presumably ScaleMode values are indexable and their last entry
            # is the vector size (mirrors __init__) — TODO confirm.
            sf_vec_size = scale_mode[-1]
            if not Impl.is_valid_dtypes_and_scale_factor_vec_size(ab_dtype, sf_dtype, sf_vec_size, out_dtype):
                continue
            for stride_A, stride_B, stride_out in itertools.product(
                [row_major_stride, col_major_stride], repeat=3
            ):
                a_major = major_str_a(stride_A)
                b_major = major_str_b(stride_B)
                out_major = major_str_out(stride_out)
                if not Impl.is_valid_layouts(ab_dtype, out_dtype, a_major, b_major, out_major):
                    continue
                # Divisibility is expressed in elements per 16-byte alignment.
                ab_divisibility = alignment_bytes * 8 // ab_dtype.width
                out_divisibility = alignment_bytes * 8 // out_dtype.width
                sf_divisibility = alignment_bytes * 8 // sf_dtype.width
                operands = GemmOperandsMetadata(
                    A=ScaledTensorAttributes(
                        base=DenseTensorAttributes(
                            dtype=ab_dtype,
                            stride=stride_A,
                            divisibility=ab_divisibility,
                        ),
                        scale=DenseTensorAttributes(
                            dtype=sf_dtype,
                            stride=None,
                            divisibility=sf_divisibility,
                        ),
                        mode=scale_mode,
                        swizzle=ScaleSwizzleMode.Swizzle32x4x4,
                    ),
                    B=ScaledTensorAttributes(
                        base=DenseTensorAttributes(
                            dtype=ab_dtype,
                            stride=stride_B,
                            divisibility=ab_divisibility,
                        ),
                        scale=DenseTensorAttributes(
                            dtype=sf_dtype,
                            stride=None,
                            divisibility=sf_divisibility,
                        ),
                        mode=scale_mode,
                        swizzle=ScaleSwizzleMode.Swizzle32x4x4,
                    ),
                    out=DenseTensorAttributes(
                        dtype=out_dtype,
                        stride=stride_out,
                        divisibility=out_divisibility,
                    ),
                    accumulator_type=cutlass.Float32,
                )
                yield operands

    @staticmethod
    def generate_kernels(
        metadata_filter: Callable[[KernelMetadata], bool],
        epilogue_args: EpilogueArguments | None = None,
        cc: int | None = None,
    ) -> list["PersistentDenseBlockScaledGemmKernel"]:
        """
        Returns a list of all possible configurations of PersistentDenseBlockScaledGemmKernel that
        adhere to constraints passed in under kwargs.

        :param metadata_filter: predicate applied to each candidate metadata;
            only kernels whose metadata passes are returned.
        :param epilogue_args: must be ``None``; this kernel supports no
            epilogue fusion and returns an empty list otherwise.
        :param cc: optional compute-capability filter.
        """
        # NOTE(review): 101 is accepted here although the top-level docs list
        # only 100/103 — confirm intentional.
        if cc is not None and cc not in [100, 101, 103]:
            return []
        design_params = {
            "use_2cta_mma": [True],
            "tile_shape": [
                (M, N, 256) for M in [128, 256] for N in [64, 128, 192, 256]
            ],
            "cluster_shape": [
                (M, N, 1) for M in [2, 4] for N in [1, 2, 4]
            ],
            "use_tma_store": [True],
        }
        # No epilogue fusion supported.
        if epilogue_args is not None:
            return []
        from itertools import product
        # Get the list of tunable parameter names and their possible values
        param_names = list(design_params.keys())
        param_values = [design_params[name] for name in param_names]
        kernel_list = []
        for operands in PersistentDenseBlockScaledGemmKernel._metadata_operand_combinations():
            for values in product(*param_values):
                design = Sm100DesignMetadata(**dict(zip(param_names, values)))
                kernel_name = "cutedsl.PersistentDenseBlockScaledGemmKernel_sm100_{layout}_A{A}_B{B}_out{out}_SFA{SFA}_SFB{SFB}_acc{acc}_scale{scale_mode}_swizzle{scale_swizzle}_{num_cta}cta_cluster{cluster}_tile{tile}{_tma_store}".format(
                    layout=strides_to_layout_string(
                        operands.A.stride, operands.B.stride, operands.out.stride
                    ),
                    A=operands.A.dtype,
                    B=operands.B.dtype,
                    out=operands.out.dtype,
                    SFA=operands.A.scale.dtype,
                    SFB=operands.B.scale.dtype,
                    acc=operands.accumulator_type,
                    scale_mode=operands.A.mode,
                    scale_swizzle=operands.A.swizzle,
                    num_cta=("2" if design.use_2cta_mma else "1"),
                    cluster=tuple_to_string(design.cluster_shape),
                    tile=tuple_to_string(design.tile_shape),
                    _tma_store="_tma_store" if design.use_tma_store else "",
                )
                metadata = KernelMetadata(
                    operands=operands,
                    design=design,
                    kernel_name=kernel_name,
                    kernel_class=PersistentDenseBlockScaledGemmKernel,
                    min_cc=100,
                    epilogue=None,
                )
                metadata_valid = PersistentDenseBlockScaledGemmKernel._valid_metadata(metadata)
                if metadata_valid and metadata_filter(metadata):
                    kernel_list.append(PersistentDenseBlockScaledGemmKernel(metadata))
        return kernel_list

View File

@@ -37,10 +37,10 @@ from cutlass_api.arguments import (
)
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
DenseTensorAttributes,
GemmOperandsMetadata,
KernelMetadata,
Sm100DesignMetadata,
TensorAttributes,
)
from cutlass_api.providers.cutedsl import CuTeDSLProvider
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
@@ -104,21 +104,15 @@ class PersistentDenseGemmKernel(CuteDslKernel):
epilogue_op,
)
def _supports(self, args: GemmArguments) -> Status:
if args.epilogue is not None:
return Status.fail("This kernel does not support any epilogue fusion.")
return Status.success()
def compile(self, args: GemmArguments, cc: int = None) -> CompiledArtifact:
stream = cutlass.cute.runtime.make_fake_stream()
max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)
compiled_kernel = self.cute_compile(
self.impl,
args.A,
args.B,
args.out,
args.A.tensor,
args.B.tensor,
args.out.tensor,
max_active_clusters,
stream,
self.impl.epilogue_op,
@@ -134,7 +128,9 @@ class PersistentDenseGemmKernel(CuteDslKernel):
) -> None:
stream = to_cuda_stream(stream)
compiled_gemm = compiled_artifact.compiled_obj
self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)
self.cute_run(
compiled_gemm, args.A.tensor, args.B.tensor, args.out.tensor, stream
)
@staticmethod
def _valid_operands(operands: GemmOperandsMetadata) -> bool:
@@ -279,17 +275,17 @@ class PersistentDenseGemmKernel(CuteDslKernel):
ab_divisibility = alignment_bytes * 8 // ab_dtype.width
out_divisibility = alignment_bytes * 8 // out_dtype.width
# Create TensorAttributes for A, B, and out tensors
a_attrs = TensorAttributes(
a_attrs = DenseTensorAttributes(
dtype=ab_dtype,
stride=stride_A,
divisibility=ab_divisibility,
)
b_attrs = TensorAttributes(
b_attrs = DenseTensorAttributes(
dtype=ab_dtype,
stride=stride_B,
divisibility=ab_divisibility,
)
out_attrs = TensorAttributes(
out_attrs = DenseTensorAttributes(
dtype=out_dtype,
stride=stride_out,
divisibility=out_divisibility,

View File

@@ -37,11 +37,11 @@ from cutlass_api.arguments import (
)
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
DenseTensorAttributes,
EpilogueMetadata,
GemmOperandsMetadata,
KernelMetadata,
Sm100DesignMetadata,
TensorAttributes,
)
from cutlass_api.providers.cutedsl import CuTeDSLProvider
from cutlass_api.providers.cutedsl.evt.converter import EFCConverter
@@ -284,17 +284,17 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
ab_divisibility = alignment_bytes * 8 // ab_dtype.width
out_divisibility = alignment_bytes * 8 // out_dtype.width
# Create TensorAttributes for A, B, and out tensors
a_attrs = TensorAttributes(
a_attrs = DenseTensorAttributes(
dtype=ab_dtype,
stride=stride_A,
divisibility=ab_divisibility,
)
b_attrs = TensorAttributes(
b_attrs = DenseTensorAttributes(
dtype=ab_dtype,
stride=stride_B,
divisibility=ab_divisibility,
)
out_attrs = TensorAttributes(
out_attrs = DenseTensorAttributes(
dtype=out_dtype,
stride=stride_out,
divisibility=out_divisibility,
@@ -311,7 +311,6 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
yield operands
def _supports(self, args: GemmArguments) -> Status:
if args.epilogue is not None:
fusion_metadata = EpilogueMetadata.from_args(args.epilogue)
if not self._valid_fusion(fusion_metadata):
@@ -339,8 +338,8 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
self.impl.efc.compile(epilogue_params)
compiled_gemm = self.cute_compile(
self.impl,
args.A,
args.B,
args.A.tensor,
args.B.tensor,
max_active_clusters,
stream,
self.impl.efc.jit.pack_arguments(*epilogue_params),
@@ -376,7 +375,9 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
epilogue_params = [args.out]
compiled_gemm = compiled_artifact.compiled_obj
self.cute_run(compiled_gemm, args.A, args.B, stream, *epilogue_params)
self.cute_run(
compiled_gemm, args.A.tensor, args.B.tensor, stream, *epilogue_params
)
@staticmethod
def generate_kernels(

View File

@@ -37,15 +37,16 @@ from cutlass_api.arguments import (
)
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
DenseTensorAttributes,
GemmOperandsMetadata,
KernelMetadata,
TensorAttributes,
)
from cutlass_api.providers.cutedsl import CuTeDSLProvider
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
from cutlass_api.status import Status
from cutlass_api.utils import strides_to_layout_string, to_cuda_stream, tuple_to_string
from cutlass_api.metadata import BLASDesignMetadata
from .implementations.sm80_tensorop_gemm_impl import Sm80TensorOpGemmImpl
@@ -77,20 +78,14 @@ class Sm80TensorOpGemmKernel(CuteDslKernel):
metadata.operands.accumulator_type
)
def _supports(self, args: GemmArguments) -> Status:
if args.epilogue is not None:
return Status.fail("This kernel does not support any epilogue fusion.")
return Status.success()
def compile(self, args: GemmArguments, cc: int = None) -> CompiledArtifact:
stream = cutlass.cute.runtime.make_fake_stream()
compiled_kernel = self.cute_compile(
self.impl,
args.A,
args.B,
args.out,
args.A.tensor,
args.B.tensor,
args.out.tensor,
stream,
)
return CompiledArtifact(compiled_kernel, self)
@@ -104,7 +99,9 @@ class Sm80TensorOpGemmKernel(CuteDslKernel):
) -> None:
stream = to_cuda_stream(stream)
compiled_gemm = compiled_artifact.compiled_obj
self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)
self.cute_run(
compiled_gemm, args.A.tensor, args.B.tensor, args.out.tensor, stream
)
@staticmethod
def _valid_operands(operands: GemmOperandsMetadata) -> bool:
@@ -155,17 +152,17 @@ class Sm80TensorOpGemmKernel(CuteDslKernel):
out_divisibility = alignment_bytes * 8 // out_dtype.width
# Create TensorAttributes for A, B, and out tensors
a_attrs = TensorAttributes(
a_attrs = DenseTensorAttributes(
dtype=ab_dtype,
stride=stride_A,
divisibility=ab_divisibility,
)
b_attrs = TensorAttributes(
b_attrs = DenseTensorAttributes(
dtype=ab_dtype,
stride=stride_B,
divisibility=ab_divisibility,
)
out_attrs = TensorAttributes(
out_attrs = DenseTensorAttributes(
dtype=out_dtype,
stride=stride_out,
divisibility=out_divisibility,

View File

@@ -0,0 +1,75 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from cutlass_api.arguments import RuntimeArguments
from cutlass_api.status import Status
def tensor_rank_2_or_3(args: RuntimeArguments) -> Status:
    """
    Checks that the arguments are valid for a tensor of rank 2 or 3.

    :param args: runtime arguments exposing A, B, and out tensors.
    :return: failing Status naming the first offending tensor, or success.
    """
    # Fix: the original reported A's *shape* but B's/out's *rank*
    # (len(shape)), yielding confusing messages like "got 4". All three
    # checks now report the actual shape, and the duplicated branches are
    # collapsed into one loop.
    checks = (
        ("A", args.A, "(L=1, M, K)"),
        ("B", args.B, "(L=1, K, N)"),
        ("out", args.out, "(L=1, M, N)"),
    )
    for name, tensor, modes in checks:
        if len(tensor.shape) < 2 or len(tensor.shape) > 3:
            return Status.fail(
                f"{name} must be a tensor of rank 2 or 3 {modes}, got {tensor.shape}"
            )
    return Status.success()
def tensor_reduction_mode_matches(args: RuntimeArguments) -> Status:
    """Check that A and B agree on the reduction (K) mode extent."""
    k_of_a = args.A.shape[-1]
    k_of_b = args.B.shape[-2]
    if k_of_a == k_of_b:
        return Status.success()
    return Status.fail(
        f"A's K dimension ({k_of_a}) must be equal to B's "
        f"K dimension ({k_of_b}). "
        f"A shape (L, M, K): {args.A.shape}, B shape (L, K, N): {args.B.shape}"
    )
def tensor_output_shape_matches(args: RuntimeArguments) -> Status:
    """Check that out's M/N extents match A's M extent and B's N extent."""
    out_m, out_n = args.out.shape[-2], args.out.shape[-1]
    a_m = args.A.shape[-2]
    b_n = args.B.shape[-1]
    if out_m != a_m:
        return Status.fail(
            f"out's M dimension ({out_m}) must be equal to A's "
            f"M dimension ({a_m}). "
            f"A shape (L, M, K): {args.A.shape}, out shape (L, M, N): {args.out.shape}"
        )
    if out_n != b_n:
        return Status.fail(
            f"out's N dimension ({out_n}) must be equal to B's "
            f"N dimension ({b_n}). "
            f"B shape (L, K, N): {args.B.shape}, out shape (L, M, N): {args.out.shape}"
        )
    return Status.success()

View File

@@ -47,6 +47,36 @@ if TYPE_CHECKING:
import torch
def ceil_div(a: int, b: int) -> int:
    """Return ``a / b`` rounded up, using integer arithmetic only.

    :param a: dividend.
    :param b: divisor.
    :return: the smallest integer >= a / b.
    """
    padded = a + b - 1
    return padded // b
def round_up(a: int, b: int) -> int:
    """Return the smallest multiple of ``b`` that is >= ``a``.

    :param a: value to round.
    :param b: rounding granularity.
    :return: a rounded up to a multiple of b.
    """
    num_blocks = (a + b - 1) // b
    return num_blocks * b
def is_numpy_available() -> bool:
"""Check if numpy is available."""
return importlib.util.find_spec("numpy") is not None
@@ -173,6 +203,7 @@ def cutlass_type_from_torch_type(dtype) -> type[cutlass.Numeric]:
torch.float8_e5m2: cutlass.Float8E5M2,
torch.float8_e4m3fn: cutlass.Float8E4M3FN,
torch.float8_e4m3fnuz: cutlass.Float8E4M3B11FNUZ,
torch.float8_e8m0fnu: cutlass.Float8E8M0FNU,
}
try:
@@ -236,52 +267,6 @@ def to_cuda_stream(
raise ValueError(f"Unsupported stream type: {type(stream)}")
def add_batch_mode(
tensor: cute.Tensor | torch.Tensor,
) -> cute.Tensor | torch.Tensor:
"""
Adds a batch mode to the tensor.
If the tensor is a torch.Tensor and has rank 2,
it will be unsqueezed along the first dimension.
:param tensor: The tensor to add batch mode to.
:type tensor: Union[cute.Tensor, "torch.Tensor"]
:return: The tensor with batch mode added.
:rtype: Union[cute.Tensor, "torch.Tensor"]
"""
if is_torch_tensor(tensor):
if tensor.dim() == 2:
return tensor.unsqueeze(0)
elif tensor.dim() < 2 or tensor.dim() > 3:
raise ValueError(f"Expected 2-3 dimensions, got {tensor.dim()}")
return tensor
return tensor
def permute_batch_mode(
tensor: cute.Tensor | torch.Tensor,
) -> cute.Tensor | torch.Tensor:
"""
Permute the batch mode of the tensor.
If the tensor is a torch.Tensor and has rank 3,
it will be permuted along the first dimension.
:param tensor: The tensor to permute.
:type tensor: Union[cute.Tensor, "torch.Tensor"]
:return: The tensor with batch mode permuted.
:rtype: Union[cute.Tensor, "torch.Tensor"]
"""
if is_torch_tensor(tensor):
if tensor.dim() != 3:
raise ValueError(f"Expected 3 dimensions, got {tensor.dim()}")
return tensor.permute([1, 2, 0])
else:
raise ValueError(f"Unsupported type: {type(tensor)}")
def leading_dim(tensor) -> int:
"""
Get the leading dimension of a tensor. This is the first mode with
@@ -378,7 +363,7 @@ class TensorWrapper:
is enabled.
"""
def __init__(self, tensor: Any):
def __init__(self, tensor: Any, alignment_bytes: int = 16):
if isinstance(tensor, cute.Tensor):
# Regardless of whether TVM-FFI is enabled, if the tensor passed in is a cute.Tensor,
# it can be used as the runtime tensor and compile time tensor.
@@ -399,7 +384,9 @@ class TensorWrapper:
stride_order = get_stride_rank(self._stride)
leading_dim_idx = stride_order.index(0)
shape = [cute.SymInt() for _ in range(rank)]
shape[leading_dim_idx] = cute.SymInt(divisibility=16 * 8 // dtype.width)
shape[leading_dim_idx] = cute.SymInt(
divisibility=alignment_bytes * 8 // dtype.width
)
self._shape = tuple(self.runtime_tensor.shape)
self._data_ptr = self.runtime_tensor.data_ptr()
else:
@@ -411,7 +398,7 @@ class TensorWrapper:
dtype,
shape,
stride_order=stride_order,
assumed_align=16, # bytes
assumed_align=alignment_bytes,
)
else:
# TVM-FFI is disabled and the tensor passed in is not a cute.Tensor,
@@ -425,12 +412,12 @@ class TensorWrapper:
self.runtime_tensor = (
from_dlpack(
tensor,
assumed_align=16, # bytes
assumed_align=alignment_bytes,
)
.mark_layout_dynamic(leading_dim(tensor))
.mark_compact_shape_dynamic(
mode=leading_dim(tensor),
divisibility=16 * 8 // dtype.width,
divisibility=alignment_bytes * 8 // dtype.width,
stride_order=stride_order,
)
)
@@ -459,6 +446,12 @@ class TensorWrapper:
def data_ptr(self) -> int:
return self._data_ptr
def numel(self) -> int:
num = self._shape[0]
for i in range(1, len(self._shape)):
num *= self._shape[i]
return num
def strides_to_layout_string(*strides: list[tuple[int, ...]]) -> str:
"""

View File

@@ -5,7 +5,7 @@
"id": "3dd45ef2",
"metadata": {},
"source": [
"# Basic GEMM using CUTLASS Python API"
"# Basic GEMM using CUTLASS API"
]
},
{

View File

@@ -229,7 +229,7 @@
" self, args: GemmArguments, cc: int = None\n",
") -> cutlass_api.artifact.CompiledArtifact:\n",
" stream = cutlass.cute.runtime.make_fake_stream()\n",
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
" compiled_gemm = self.cute_compile(self.impl, args.A.tensor, args.B.tensor, args.out.tensor, stream)\n",
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
]
},
@@ -260,7 +260,7 @@
"):\n",
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
" compiled_gemm = artifact.compiled_obj\n",
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
" self.cute_run(compiled_gemm, args.A.tensor, args.B.tensor, args.out.tensor, stream)"
]
},
{
@@ -343,13 +343,13 @@
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
" for stride_A, stride_B, stride_out in stride_combos:\n",
" # Create TensorAttributes for A, B, and out tensors\n",
" a_attrs = cutlass_api.metadata.TensorAttributes(\n",
" a_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
" cutlass.Float64, stride_A, divisibility\n",
" )\n",
" b_attrs = cutlass_api.metadata.TensorAttributes(\n",
" b_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
" cutlass.Float64, stride_B, divisibility\n",
" )\n",
" out_attrs = cutlass_api.metadata.TensorAttributes(\n",
" out_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
" cutlass.Float64, stride_out, divisibility\n",
" )\n",
" layout_str = cutlass_api.utils.strides_to_layout_string(\n",

View File

@@ -0,0 +1,240 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "91d43c2b",
"metadata": {},
"source": [
"# Grouped GEMM with contiguous tensors via the CUTLASS API\n",
"\n",
"Note: this notebook requires a GPU with compute capability 100:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f671f602",
"metadata": {},
"outputs": [],
"source": [
"import cutlass_api\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({100})):\n",
" print(\n",
" f\"This notebook requires a GPU with compute capability 100.\\n{status.error}\"\n",
" )\n",
" import sys\n",
" sys.exit(0)"
]
},
{
"cell_type": "markdown",
"id": "bc4adf7d",
"metadata": {},
"source": [
"This notebook shows how to use the CUTLASS API to discover, compile, and execute\n",
"kernels supporting contiguous offset grouped GEMMs.\n",
"\n",
"In a \"contiguous offset\" grouped GEMM, `G` different problems are executed\n",
"in which problems differ only in the `M` mode. Their problem sizes are thus\n",
"represented as:\n",
"\n",
"```text\n",
"M0 x N x K\n",
"M1 x N x K\n",
"M2 x N x K\n",
"...\n",
"M(G-1) x N x K\n",
"```\n",
"\n",
"The grouped GEMM is referred to as \"contiguous\" because operands for different\n",
"problems in the group are contained within contiguous tensors.\n",
"\n",
"Rather than having `G` different tensors for each of operands `A` and `B`, tensors\n",
"for different problems in the group are packed together:\n",
"* `A` is of shape `(TotalM, K)`, where `TotalM` is the sum of all `M` modes for problems in the group.\n",
"The `A` operands for each problem in the group are stacked along the `M` mode to form this input. More on this below.\n",
"* `B` is of shape `(G, K, N)`, where `B[i, :, :]` represents the GEMM `B` operand for the `i`th problem in the group.\n",
"\n",
"For example, with `G=3` (three problems in the group), with `M` modes of M0, M1, and M2,\n",
"respectively, the tensor `A` would be laid out as follows:\n",
"\n",
"```text\n",
"\n",
" +----------------------------------+ ^ \n",
" | | | | \n",
" | A0 | M0 | \n",
" | | | | \n",
" |- - - - - - - - - - - -| | \n",
" | | | |\n",
" | | | TotalM \n",
" | A1 | M1 |\n",
" | | | |\n",
" | | | | \n",
" |- - - - - - - - - - - -| | \n",
" | A2 | M2 | \n",
" +----------------------------------+ v \n",
"```\n",
"\n",
"The extents of individual `A` operands packed within the overall contiguous offset `A` tensor\n",
"are provided by an auxiliary `offsets` vector of shape `(G,)`. `offsets[i]` indicates the ending\n",
"M coordinate (exclusive) for the `i`th `A` operand.\n",
"\n",
"Thus, for the example above, `offsets = [M0, M0 + M1, M0 + M1 + M2]`.\n",
"\n",
"The output of the operation is of shape `(TotalM, N)`. The `i`th output occupies `out[start:end, :]`,\n",
"where `start` and `end` are `offsets[i-1]` and `offsets[i]`, respectively (unless `i=0`, in which case\n",
"`start` is 0).\n",
"\n",
"The reference code below shows the computation of this kernel."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6185f60a",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def reference_contiguous_offset_grouped_gemm(A, B, offsets, out_dtype):\n",
" G, K, N = B.shape\n",
" TotalM = A.shape[0]\n",
"\n",
" out = torch.empty((TotalM, N), dtype=out_dtype, device=A.device)\n",
"\n",
" start = 0\n",
" for i in range(G):\n",
" end = offsets[i]\n",
" out[start:end, :] = A[start:end, :] @ B[i, :, :]\n",
" start = end\n",
"\n",
" return out"
]
},
{
"cell_type": "markdown",
"id": "d0bf2f91",
"metadata": {},
"source": [
"## Contiguous offset grouped GEMM in PyTorch"
]
},
{
"cell_type": "markdown",
"id": "4308a6a2",
"metadata": {},
"source": [
"The same operation is performed by `torch`'s `torch._grouped_mm` (torch < 2.10)\n",
"and `torch.nn.functional.grouped_mm` (torch >= 2.10)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "043906af",
"metadata": {},
"outputs": [],
"source": [
"TotalM = 8192\n",
"G = 12\n",
"K = 1024\n",
"N = 2048\n",
"\n",
"offsets = torch.arange(TotalM // G, TotalM, TotalM // G, device=\"cuda\", dtype=torch.int32)\n",
"offsets[-1] = TotalM\n",
"\n",
"A = torch.randn(TotalM, K, device=\"cuda\", dtype=torch.bfloat16)\n",
"B = torch.randn(G, N, K, device=\"cuda\", dtype=torch.bfloat16).permute(0, 2, 1)\n",
"\n",
"out_torch = torch._grouped_mm(A, B, offsets, out_dtype=torch.bfloat16)\n",
"reference = reference_contiguous_offset_grouped_gemm(A, B, offsets, out_dtype=torch.bfloat16)\n",
"\n",
"torch.testing.assert_close(out_torch, reference)"
]
},
{
"cell_type": "markdown",
"id": "0d0e9479",
"metadata": {},
"source": [
"## Contiguous offset grouped GEMM in CUTLASS API\n",
"\n",
"CUTLASS API exposes this contiguous offset grouped GEMM via `GroupedGemmArguments`,\n",
"which are constructed similarly to `GemmArguments`, but take in an `offsets`\n",
"tensor as well:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ff8d3ef1",
"metadata": {},
"outputs": [],
"source": [
"out = torch.empty((TotalM, N), device=\"cuda\", dtype=torch.bfloat16)\n",
"\n",
"args = cutlass_api.arguments.GroupedGemmArguments(\n",
" A,\n",
" B,\n",
" out,\n",
" accumulator_type=torch.float32,\n",
" offsets=offsets,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0dc6d1cb",
"metadata": {},
"source": [
"One can then use the same APIs for finding, compiling, and executing a\n",
"kernel supporting this operation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "80213e1e",
"metadata": {},
"outputs": [],
"source": [
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"\n",
"assert kernels, \"No kernels found\"\n",
"\n",
"# Select the first kernel found for simplicity\n",
"kernel = kernels[0]\n",
"\n",
"compiled_kernel = kernel.compile(args)\n",
"\n",
"# Execute the kernel\n",
"kernel.run(args, compiled_artifact=compiled_kernel)\n",
"\n",
"torch.testing.assert_close(out, reference)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -12,7 +12,7 @@ requires-python = ">=3.12"
dependencies = [
"nvidia-cutlass-dsl",
"apache-tvm-ffi", # Required for optimal performance. To opt-out: uninstall or set cutlass_api.config.GlobalOptions().use_tvm_ffi=False
"apache-tvm-ffi", # Required for best performance. To opt-out: uninstall or set cutlass_api.config.GlobalOptions().use_tvm_ffi=False
]
[project.optional-dependencies]

View File

@@ -32,6 +32,33 @@ import pytest
import cutlass_api
from cutlass_api.config import GlobalOptions
# Global variable to store the test level, accessible at collection time
_test_level = None
def get_test_level():
    """Get the test level set via --level CLI option.

    Returns the module-global ``_test_level``, which is populated by
    ``pytest_configure``; ``None`` if configuration has not run yet.
    """
    return _test_level
def pytest_addoption(parser):
    """Register the ``--level`` CLI option selecting which test level runs."""
    level_spec = dict(
        action="store",
        default="L0",
        type=str,
        choices=["L0", "L1", "L2"],
        help="Test level to run",
    )
    parser.addoption("--level", **level_spec)
def pytest_configure(config):
    """Capture the ``--level`` option into the module-global ``_test_level``.

    Runs once at pytest startup so the level is available at collection time
    via ``get_test_level``.
    """
    global _test_level
    _test_level = config.getoption("--level", "L0")
#
# Before each test, save the GlobalOptions dict
# After each test, restore the GlobalOptions dict

View File

@@ -0,0 +1,332 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import functools
import os
import pytest
import random
import torch
import cutlass
import cutlass_api
from cutlass_api.utils import is_device_cc_supported
torch.manual_seed(2025)
random.seed(2025)
def create_offsets(
    problems_in_group: int,
    expect_m: int,
    force_empty_problems: bool,
    device: str = "cuda",
) -> tuple[int, list[int], torch.Tensor]:
    """
    Utility to create offsets for the contiguous offset tensor.

    :param problems_in_group: Number of problems to create offsets for.
    :type problems_in_group: int
    :param expect_m: Expected number of rows in each group.
    :type expect_m: int
    :param force_empty_problems: Whether to force some problems in the group to be empty.
    :type force_empty_problems: bool
    :param device: Device on which to allocate the offset tensor (default "cuda").
    :type device: str
    :return: Total number of rows, list of group sizes, and offset mapping tensor
        (inclusive prefix sums of the group sizes).
    :rtype: tuple[int, list[int], torch.Tensor]
    """
    valid_m = 0
    problem_m_list = []
    for i in range(problems_in_group):
        # Randomize each group's row count around the expected value.
        problem_m = random.randint(int(expect_m * 0.7), int(expect_m * 1.3))
        valid_m += problem_m
        # handle the case that valid_m == 0: force the last problem to be
        # non-empty so the total row count is never zero
        if (i == problems_in_group - 1) and (valid_m == 0):
            problem_m = 128
            valid_m += problem_m
        problem_m_list.append(problem_m)
    if force_empty_problems:
        # Zero out most of the groups (all but at most two) to exercise
        # empty-problem handling in the kernels.
        problems_to_zero = random.sample(
            range(problems_in_group), max(problems_in_group - 2, 1)
        )
        for problem_idx in problems_to_zero:
            valid_m -= problem_m_list[problem_idx]
            problem_m_list[problem_idx] = 0
    end_mapping = torch.cumsum(torch.tensor(problem_m_list, dtype=torch.int32), dim=0)
    # Fix: move the existing tensor with .to() instead of re-wrapping it in
    # torch.tensor(), which copies an existing tensor and raises a UserWarning.
    offsets = end_mapping.to(device=device, dtype=torch.int32)
    return valid_m, problem_m_list, offsets
def args_and_ref_from_metadata(
    metadata: cutlass_api.metadata.KernelMetadata,
    mnkl: tuple[int, int, int, int],
    force_empty_problems: bool,
):
    """
    Build grouped-GEMM arguments whose operand layouts match ``metadata``,
    plus a torch reference result.

    :param metadata: Kernel metadata describing operand dtypes and strides.
    :param mnkl: (m, n, k, l) problem shape; ``l`` is the number of problems
        in the group and ``m`` the expected row count per problem.
    :param force_empty_problems: Whether some problems should have zero rows.
    :return: (args, reference, out) — the arguments object, the torch
        reference output, and the output tensor the kernel writes into.
    """
    m, n, k, l = mnkl
    valid_m, problem_m_list, offsets = create_offsets(l, m, force_empty_problems)
    # Create A (1, valid_m, k).
    # ``stride[-2:].index(1) == 1`` means the unit stride is in the last mode
    # (row-major); otherwise allocate the transpose and view it back so the
    # stride pattern matches the metadata.
    if metadata.operands.A.stride[-2:].index(1) == 1:
        A = torch.randint(-1, 2, (valid_m, k), device="cuda").to(
            cutlass.torch.dtype(metadata.operands.A.dtype)
        )
    else:
        A = (
            torch.randint(-1, 2, (k, valid_m), device="cuda")
            .to(cutlass.torch.dtype(metadata.operands.A.dtype))
            .T
        )
    # Create B (l, k, n), again honoring the metadata's stride order.
    if metadata.operands.B.stride[-2:].index(1) == 1:
        B = torch.randint(-1, 2, (l, k, n), device="cuda").to(
            cutlass.torch.dtype(metadata.operands.B.dtype)
        )
    else:
        B = (
            torch.randint(-1, 2, (l, n, k), device="cuda")
            .to(cutlass.torch.dtype(metadata.operands.B.dtype))
            .permute(0, 2, 1)
        )
    # Create out (1, valid_m, n)
    if metadata.operands.out.stride[-2:].index(1) == 1:
        out = torch.zeros(
            (valid_m, n),
            device="cuda",
            dtype=cutlass.torch.dtype(metadata.operands.out.dtype),
        )
    else:
        out = torch.zeros(
            (n, valid_m),
            device="cuda",
            dtype=cutlass.torch.dtype(metadata.operands.out.dtype),
        ).T
    args = cutlass_api.arguments.GroupedGemmArguments(
        A=A,
        B=B,
        out=out,
        accumulator_type=cutlass.torch.dtype(metadata.operands.accumulator_type),
        offsets=offsets,
    )
    #
    # Compute reference
    #
    out_type = cutlass.torch.dtype(metadata.operands.out.dtype)
    # Compute reference in F32 because torch does not support GEMMs
    # for all FP8 types
    if hasattr(torch.nn.functional, "grouped_mm"):
        reference = torch.nn.functional.grouped_mm(
            A.float(), B.float(), offs=offsets
        ).to(out_type)
    else:
        # Older torch releases only expose the private ``_grouped_mm``.
        reference = torch._grouped_mm(A.float(), B.float(), offsets).to(out_type)
    return args, reference, out
def kernels_for_class(kernel_class):
    """
    Return all registered kernels of ``kernel_class`` for CC 100, downsampled
    to at most 10 kernels when running at test level L0.

    :param kernel_class: Kernel class to filter discovery by.
    :return: Non-empty list of kernels.
    """
    def metadata_filter(metadata):
        return metadata.kernel_class == kernel_class
    kernels = cutlass_api.get_kernels(
        args=None, metadata_filter=metadata_filter, cc=100
    )
    assert len(kernels) > 0
    # Toggle the number of kernels to return based on the test level.
    # conftest may be unavailable when this module is imported outside pytest;
    # fall back to the cheapest level in that case.
    try:
        from conftest import get_test_level
        test_level = get_test_level()
    except ImportError:
        test_level = "L0"
    if test_level == "L0":
        # Fix: bound the sample size — random.sample raises ValueError when
        # the population has fewer than 10 kernels.
        return random.sample(kernels, min(10, len(kernels)))
    else:
        return kernels
@pytest.mark.parametrize(
    "mnkl, force_empty_problems",
    [
        ((13, 16, 32, 3), False),
        ((253, 32, 32, 7), False),
        ((256, 256, 512, 1), False),
        ((256, 256, 512, 10), False),
        ((8192, 8192, 8192, 4), False),
        ((8192, 8192, 8192, 4), True),
        ((1024, 4096, 4096, 15), False),
        ((256, 16384, 128, 20), False),
    ],
)
@pytest.mark.parametrize(
    "kernel",
    kernels_for_class(
        cutlass_api.providers.cutedsl.gemm.sm100_contiguous_offset_2d3d_dense_gemm.ContiguousOffset2D3DGemmDenseKernel
    ),
)
@pytest.mark.skipif(
    not is_device_cc_supported({100})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
    reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_contiguous_offset_gemm_dense(mnkl, force_empty_problems, kernel):
    """
    Run each contiguous-offset dense grouped-GEMM kernel over a range of
    problem shapes and compare against the torch grouped-mm reference.
    """
    m, n, k, l = mnkl
    args, reference, out = args_and_ref_from_metadata(
        kernel.metadata, mnkl, force_empty_problems
    )
    # Fail (rather than silently skip) if the kernel rejects arguments that
    # were built specifically to match its metadata.
    if not (status := kernel.supports(args)):
        pytest.fail(
            f"Kernel {kernel.metadata.kernel_name} does not support the given arguments: {status}"
        )
    kernel.run(args)
    torch.testing.assert_close(
        out, reference, msg=f"Kernel {kernel.metadata.kernel_name} failed"
    )
def kernel_metadata_filter(metadata):
    """Keep only contiguous-offset 2D/3D dense grouped-GEMM kernels."""
    target_class = (
        cutlass_api.providers.cutedsl.gemm.sm100_contiguous_offset_2d3d_dense_gemm.ContiguousOffset2D3DGemmDenseKernel
    )
    return metadata.kernel_class == target_class
def test_incorrect_ab_contiguous():
    """
    Contiguous offset GEMMs currently require A to be in contiguous offset form.
    Test that no kernels are found when only B is in contiguous offset form.
    """
    problem_count, m, n, k = 12, 8192, 128, 512
    # A carries the group mode here (the wrong operand for contiguous-offset
    # form); B does not.
    operand_a = torch.empty((problem_count, m, k), device="cuda", dtype=torch.float16)
    operand_b = torch.empty((1, n, k), device="cuda", dtype=torch.float16).permute(
        0, 2, 1
    )
    result = torch.empty((1, m, n), device="cuda", dtype=torch.float32)
    group_offsets = torch.empty((problem_count,), device="cuda", dtype=torch.int32)
    gemm_args = cutlass_api.arguments.GroupedGemmArguments(
        A=operand_a,
        B=operand_b,
        out=result,
        accumulator_type=torch.float32,
        offsets=group_offsets,
    )
    found = cutlass_api.get_kernels(
        gemm_args, metadata_filter=kernel_metadata_filter, cc=100
    )
    assert len(found) == 0
def test_incorrect_offset_length():
    """
    Offset tensors are required to have `problem_count` elements.
    Test that no kernels are found when this is violated.
    """
    problem_count, m, n, k = 12, 8192, 128, 512
    lhs = torch.empty((1, m, k), device="cuda", dtype=torch.float16)
    rhs = torch.empty((problem_count, n, k), device="cuda", dtype=torch.float16).permute(
        0, 2, 1
    )
    dst = torch.empty((1, m, n), device="cuda", dtype=torch.float32)
    # One element too many: offsets must have exactly `problem_count` entries.
    bad_offsets = torch.empty((problem_count + 1,), device="cuda", dtype=torch.int32)
    gemm_args = cutlass_api.arguments.GroupedGemmArguments(
        A=lhs,
        B=rhs,
        out=dst,
        accumulator_type=torch.float32,
        offsets=bad_offsets,
    )
    found = cutlass_api.get_kernels(
        gemm_args, metadata_filter=kernel_metadata_filter, cc=100
    )
    assert len(found) == 0
def test_incorrect_offsets_shape():
    """
    Offset maps are expected to be rank 1. Test that no kernels are found when this is not the case.
    """
    problem_count, m, n, k = 12, 8192, 128, 512
    lhs = torch.empty((1, m, k), device="cuda", dtype=torch.float16)
    rhs = torch.empty((problem_count, n, k), device="cuda", dtype=torch.float16).permute(
        0, 2, 1
    )
    dst = torch.empty((1, m, n), device="cuda", dtype=torch.float32)
    rank1_offsets = torch.empty((problem_count,), device="cuda", dtype=torch.int32)

    def make_args(offset_tensor):
        # Same operands both times; only the offset tensor differs.
        return cutlass_api.arguments.GroupedGemmArguments(
            A=lhs,
            B=rhs,
            out=dst,
            accumulator_type=torch.float32,
            offsets=offset_tensor,
        )

    # Rank-1 offsets are the correct representation: kernels should be found.
    good = cutlass_api.get_kernels(
        make_args(rank1_offsets), metadata_filter=kernel_metadata_filter, cc=100
    )
    assert len(good) > 0
    # Reshaped to rank 2: no kernels should be found.
    bad = cutlass_api.get_kernels(
        make_args(rank1_offsets.view(3, 4)),
        metadata_filter=kernel_metadata_filter,
        cc=100,
    )
    assert len(bad) == 0

View File

@@ -0,0 +1,725 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections.abc import Callable
from importlib.util import find_spec
import os
import pytest
import torch
import cutlass
import cutlass_api
from cutlass_api.arguments import EpilogueArguments, ScaledTensor
from cutlass_api.config import GlobalOptions
from cutlass_api.library import ScaleMode, ScaleSwizzleMode
from cutlass_api.metadata import KernelMetadata, Sm100DesignMetadata
from cutlass_api.utils import ceil_div, is_device_cc_supported, round_up
torch.manual_seed(2025)
def prep_k(K: int, scale_size: int):
    """
    Round the scale-factor K extent up to the granularity required by
    32-4-4 swizzling.

    :param K: The value to prepare.
    :type K: int
    :param scale_size: The scale size.
    :type scale_size: int
    :return: ceil(K / scale_size) rounded up to a multiple of 4.
    :rtype: int
    """
    scale_k = ceil_div(K, scale_size)
    return round_up(scale_k, 4)
def reference_scaled_mm(
    A: torch.Tensor,
    B: torch.Tensor,
    scale_A: torch.Tensor,
    scale_B: torch.Tensor,
    out_dtype: torch.dtype,
    transform_sf: Callable[[torch.Tensor], torch.Tensor] = lambda x: x
):
    """
    Computes a reference scaled mm operation using ``torch._scaled_mm``.

    ``torch._scaled_mm`` does not support batch mode, so when the inputs are
    3D this function computes each problem in the batch separately.

    :param A: The A tensor
    :type A: torch.Tensor
    :param B: The B tensor
    :type B: torch.Tensor
    :param scale_A: The scale factor tensor for operand A
    :type scale_A: torch.Tensor
    :param scale_B: The scale factor tensor for operand B
    :type scale_B: torch.Tensor
    :param out_dtype: The output dtype
    :type out_dtype: torch.dtype
    :param transform_sf: A function applied to each per-batch scale factor
        slice before the scaled mm (e.g. flattening); identity by default.
    :type transform_sf: Callable
    :return: The reference scaled mm result
    """
    # 2D inputs map directly onto a single torch._scaled_mm call.
    if len(A.shape) == 2:
        return torch._scaled_mm(
            A, B, scale_a=scale_A, scale_b=scale_B, out_dtype=out_dtype
        )
    batch, rows, cols = A.shape[0], A.shape[1], B.shape[2]
    result = torch.empty((batch, rows, cols), device=A.device, dtype=out_dtype)
    for batch_idx in range(batch):
        result[batch_idx, :, :] = torch._scaled_mm(
            A[batch_idx, :, :],
            B[batch_idx, :, :],
            scale_a=transform_sf(scale_A[batch_idx, :, :]),
            scale_b=transform_sf(scale_B[batch_idx, :, :]),
            out_dtype=out_dtype,
        )
    return result
@pytest.mark.parametrize(
    "M, N, K, L",
    [
        (256, 512, 1024, 1),
        (1024, 640, 512, 2),
        (256, 512, 1024 + 496, 1),  # Test where K is not divisible by scale_size (32)
    ],
)
@pytest.mark.parametrize(
    "ab_dtype, c_dtype, accumulator_type, scale_dtype",
    [
        (torch.float8_e4m3fn, torch.float32, torch.float32, torch.float8_e8m0fnu),
        (torch.float8_e5m2, torch.float32, torch.float32, torch.float8_e8m0fnu),
    ],
)
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    reason="Requires compute capability 100 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_gemm_sm100(
    M: int,
    N: int,
    K: int,
    L: int,
    ab_dtype: torch.dtype,
    c_dtype: torch.dtype,
    accumulator_type: torch.dtype,
    scale_dtype: torch.dtype,
    fixture_enable_tvm_ffi,
):
    """
    End-to-end MXFP8 GEMM test with 3D (batched) A, B, and out tensors,
    checked against a ``torch._scaled_mm``-based reference where torch
    supports the input dtype.
    """
    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(ab_dtype)
    D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
    # Transpose B because torch._scaled_mm expects B to be column-major
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(ab_dtype).transpose(1, 2)
    scale_size = 32
    # Scale-factor K extents are padded via prep_k to the 32-4-4 swizzle granularity.
    SFA = torch.rand((L, M, prep_k(K, scale_size),), device="cuda").to(scale_dtype)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(scale_dtype)
    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=accumulator_type,
    )
    kernels = cutlass_api.get_kernels(args, cc=100)
    assert len(kernels) > 0
    # Any discovered kernel should support and run these arguments;
    # use the first one for simplicity.
    kernel = kernels[0]
    assert kernel.supports(args)
    compiled_artifact = kernel.compile(args)
    kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)
    # torch._scaled_mm does not support f8e5m2 * f8e5m2 currently.
    # Simply skip reference check in that case (but test that a CUTLASS API kernel
    # is found and runs)
    if ab_dtype != torch.float8_e5m2:
        reference = reference_scaled_mm(A, B, SFA, SFB, c_dtype)
        torch.testing.assert_close(D, reference)
@pytest.mark.parametrize(
    "M, N, K",
    [
        (256, 512, 1024),
        (1024, 640, 512),
    ],
)
@pytest.mark.parametrize(
    "ab_dtype, c_dtype, accumulator_type, scale_dtype",
    [
        (torch.float8_e4m3fn, torch.float32, torch.float32, torch.float8_e8m0fnu),
    ],
)
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    reason="Requires compute capability 100 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_gemm_sm100_2d(
    M: int,
    N: int,
    K: int,
    ab_dtype: torch.dtype,
    c_dtype: torch.dtype,
    accumulator_type: torch.dtype,
    scale_dtype: torch.dtype,
    fixture_enable_tvm_ffi,
):
    """
    Tests valid MXFP8 GEMM cases in which A, B, and out are 2D tensors.
    """
    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (M, K), device="cuda").to(ab_dtype)
    D = torch.empty((M, N), device="cuda", dtype=c_dtype)
    # Transpose B because torch._scaled_mm expects B to be column-major
    B = torch.randint(-1, 2, (N, K), device="cuda").to(ab_dtype).transpose(0, 1)
    scale_size = 32
    # Scale-factor K extents padded via prep_k to the 32-4-4 swizzle granularity.
    SFA = torch.rand((M, prep_k(K, scale_size),), device="cuda").to(scale_dtype)
    SFB = torch.rand((prep_k(K, scale_size), N), device="cuda").to(scale_dtype)
    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=accumulator_type,
    )
    kernels = cutlass_api.get_kernels(args, cc=100)
    assert len(kernels) > 0
    # Use the first discovered kernel; all should support these arguments.
    kernel = kernels[0]
    assert kernel.supports(args)
    compiled_artifact = kernel.compile(args)
    kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)
    reference = reference_scaled_mm(A, B, SFA, SFB, c_dtype)
    torch.testing.assert_close(D, reference)
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    reason="Requires compute capability 100 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_1d_scale_factors(fixture_enable_tvm_ffi):
    """
    Tests for valid MXFP8 GEMM cases in which A, B, and out are 3D tensors, and the scale factors are 1D tensors.
    """
    M, N, K, L = 256, 512, 1024, 1
    scale_size = 32
    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).transpose(1, 2)
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
    SFA = torch.rand((L, M, prep_k(K, scale_size),), device="cuda").to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(torch.float8_e8m0fnu)
    args = cutlass_api.arguments.GemmArguments(
        # Pass in SFA and SFB in flattened form (.view(-1))
        A=ScaledTensor(
            A,
            SFA.view(-1),
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB.view(-1),
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    # Only test kernels for which we know there will be an error if the scale factors are invalid
    def metadata_filter(metadata: KernelMetadata):
        return metadata.kernel_class == cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent.PersistentDenseBlockScaledGemmKernel
    kernels = cutlass_api.get_kernels(args, cc=100, metadata_filter=metadata_filter)
    assert len(kernels) > 0
    kernel = kernels[0]
    assert kernel.supports(args)
    compiled_artifact = kernel.compile(args)
    kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)
    def transform_sf(x: torch.Tensor) -> torch.Tensor:
        """Flattens scale factor tensors after indexing into the batch dimension"""
        return x.view(-1)
    # The reference also consumes flattened scale factors so it matches the
    # form passed to the kernel.
    reference = reference_scaled_mm(A, B, SFA, SFB, torch.float32, transform_sf)
    torch.testing.assert_close(D, reference)
@pytest.mark.parametrize(
    "discovery_scale_mode, runtime_scale_mode",
    [
        ((1, 1, 32), (1, 1, 32)),
        ((1, 32), (1, 1, 32)),
        (ScaleMode.Blockwise1x32, ScaleMode.Blockwise1x32),
        (ScaleMode.Blockwise1x32, (1, 32)),
    ],
)
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    reason="Requires compute capability 100 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_tuple_scale_mode(
    discovery_scale_mode: tuple[int, ...],
    runtime_scale_mode: tuple[int, ...],
    fixture_enable_tvm_ffi,
):
    """
    Tests for valid MXFP8 GEMM cases in which A, B, and out are 3D tensors, and
    the scale mode is specified as a tuple.

    Discovery and runtime may use different (but equivalent) spellings of the
    same scale mode; a kernel compiled with one spelling must accept the other.
    """
    M, N, K, L = 256, 512, 1024, 1
    scale_size = 32
    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = (
        torch.randint(-1, 2, (L, N, K), device="cuda")
        .to(torch.float8_e4m3fn)
        .transpose(1, 2)
    )
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
    SFA = torch.rand(
        (
            L,
            M,
            prep_k(K, scale_size),
        ),
        device="cuda",
    ).to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(
        torch.float8_e8m0fnu
    )
    # Arguments used for kernel discovery/compilation.
    discovery_args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            discovery_scale_mode,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            discovery_scale_mode,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    # Only test kernels for which we know there will be an error if the scale factors are invalid
    def metadata_filter(metadata: KernelMetadata):
        return (
            metadata.kernel_class
            == cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent.PersistentDenseBlockScaledGemmKernel
        )
    kernels = cutlass_api.get_kernels(
        discovery_args, cc=100, metadata_filter=metadata_filter
    )
    assert len(kernels) > 0
    kernel = kernels[0]
    compiled_artifact = kernel.compile(discovery_args)
    # Rebuild arguments with the (possibly differently spelled) runtime scale
    # mode and run them against the artifact compiled from discovery_args.
    runtime_args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            runtime_scale_mode,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            runtime_scale_mode,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    assert kernel.supports(runtime_args)
    kernel.run(
        runtime_args, compiled_artifact=compiled_artifact, assume_supported_args=True
    )
    reference = reference_scaled_mm(A, B, SFA, SFB, torch.float32)
    torch.testing.assert_close(D, reference)
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    reason="Requires compute capability 100 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_incompatible_mode_or_swizzle(fixture_enable_tvm_ffi):
    """
    Tests for correct flagging of incompatible scale modes or swizzles:
    - Using no swizzle mode for MXFP8
    - Using an incompatible scale mode for MXFP8
    """
    M, N, K, L = 256, 512, 1024, 1
    scale_size = 32
    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).transpose(1, 2)
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
    SFA = torch.rand((L, M, prep_k(K, scale_size),), device="cuda").to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(torch.float8_e8m0fnu)
    args_bad = cutlass_api.arguments.GemmArguments(
        # Use incompatible swizzle mode for SFA
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.SwizzleNone,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    # Only test kernels for which the swizzle mode is incompatible
    def metadata_filter(metadata: KernelMetadata):
        return metadata.kernel_class == cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent.PersistentDenseBlockScaledGemmKernel
    kernels = cutlass_api.get_kernels(args_bad, cc=100, metadata_filter=metadata_filter)
    assert len(kernels) == 0
    # Find kernel using compatible swizzle mode, but pass in arguments at runtime with
    # incompatible swizzle mode
    args_good = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    kernels = cutlass_api.get_kernels(args_good, cc=100, metadata_filter=metadata_filter)
    assert len(kernels) > 0
    kernel = kernels[0]
    # Fix: use pytest.raises rather than try/except + assert False; it is the
    # idiomatic way to assert that run() rejects unsupported arguments and
    # avoids the unused exception binding.
    with pytest.raises(ValueError):
        kernel.run(args_bad)
    # Pass in arguments with an incompatible scale mode
    args_bad_scale = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        # Bad scale mode: SFB had scale mode Blockwise1x16
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x16,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    with pytest.raises(ValueError):
        kernel.run(args_bad_scale)
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    reason="Requires compute capability 100 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_missing_scale_factors(fixture_enable_tvm_ffi):
    """
    Tests for correct flagging of a missing scale factor (both must be supplied)
    """
    M, N, K, L = 256, 512, 1024, 1
    scale_size = 32
    # fp8 operands; only A is wrapped with a scale factor, B is passed bare.
    lhs = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    rhs = (
        torch.randint(-1, 2, (L, N, K), device="cuda")
        .to(torch.float8_e4m3fn)
        .transpose(1, 2)
    )
    dst = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
    lhs_scales = torch.rand((L, M, prep_k(K, scale_size)), device="cuda").to(
        torch.float8_e8m0fnu
    )
    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            lhs,
            lhs_scales,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=rhs,
        out=dst,
        accumulator_type=torch.float32,
    )

    # Restrict discovery to kernels known to reject missing scale factors.
    def metadata_filter(metadata: KernelMetadata):
        return metadata.kernel_class == cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent.PersistentDenseBlockScaledGemmKernel

    found = cutlass_api.get_kernels(args, cc=100, metadata_filter=metadata_filter)
    assert len(found) == 0
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    reason="Requires compute capability 100 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_invalid_scale_factors(fixture_enable_tvm_ffi):
    """
    Tests for correct flagging of invalid scale factors:
    - Too large for the A tensor
    - Too large for the B tensor
    """
    M, N, K, L = 256, 512, 1024, 1
    scale_size = 32
    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).transpose(1, 2)
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
    # Add 32 elements to the K mode of SFA. This makes it too large for the A tensor
    SFA = torch.rand((L, M, prep_k(K, scale_size) + 32,), device="cuda").to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(torch.float8_e8m0fnu)
    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    # Only test kernels for which we know there will be an error if the scale factors are invalid
    def metadata_filter(metadata: KernelMetadata):
        return metadata.kernel_class == cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent.PersistentDenseBlockScaledGemmKernel
    kernels = cutlass_api.get_kernels(args, cc=100, metadata_filter=metadata_filter)
    assert len(kernels) == 0
    # Add 32 elements to the K mode of SFB. This makes it too large for the B tensor
    SFA = torch.rand((L, M, prep_k(K, scale_size),), device="cuda").to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size) + 32, N), device="cuda").to(torch.float8_e8m0fnu)
    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    # The oversized SFB should likewise cause discovery to reject all kernels.
    kernels = cutlass_api.get_kernels(args, cc=100, metadata_filter=metadata_filter)
    assert len(kernels) == 0
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    reason="Requires compute capability 100 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_gemm_sm100_invalid_epilogue(fixture_enable_tvm_ffi):
    """
    Tests for correct flagging of invalid epilogues: currently, no MXFP8 kernels support
    epilogue fusions.
    """
    M, N, K, L = 256, 512, 1024, 1
    scale_size = 32
    # fp8 operands plus MXFP8 scale factors.
    lhs = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    rhs = (
        torch.randint(-1, 2, (L, N, K), device="cuda")
        .to(torch.float8_e4m3fn)
        .transpose(1, 2)
    )
    dst = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
    lhs_scales = torch.rand((L, M, prep_k(K, scale_size)), device="cuda").to(
        torch.float8_e8m0fnu
    )
    rhs_scales = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(
        torch.float8_e8m0fnu
    )
    aux = torch.randint(-1, 2, (L, M, N), device="cuda").to(torch.float32)

    # Simple linear-combination epilogue; no MXFP8 kernel supports fusions.
    def epilogue(accum, alpha, beta, C):
        D = alpha * accum + beta * C
        return D

    epi_args = EpilogueArguments(epilogue, alpha=1.0, beta=1.0, C=aux, D=dst)
    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            lhs,
            lhs_scales,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            rhs,
            rhs_scales,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=dst,
        accumulator_type=torch.float32,
        epilogue=epi_args,
    )
    found = cutlass_api.get_kernels(args, cc=100)
    assert len(found) == 0
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    reason="Requires compute capability 100 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_gemm_sm100_design_metadata(fixture_enable_tvm_ffi):
    """
    Tests for correct filtering of kernels based on design metadata specifications
    (e.g., tile shape, cluster shape, etc.).
    """
    M, N, K, L = 256, 512, 1024, 1
    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).transpose(1, 2)
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
    scale_size = 32
    SFA = torch.rand((L, M, prep_k(K, scale_size),), device="cuda").to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(torch.float8_e8m0fnu)
    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    # Keep only SM100 kernels whose MN tile shape is exactly (256, 128).
    def design_filter(metadata: KernelMetadata):
        if not isinstance(metadata.design, Sm100DesignMetadata):
            return False
        return metadata.design.tile_shape[:2] == (256, 128)
    kernels = cutlass_api.get_kernels(args, cc=100, metadata_filter=design_filter)
    assert len(kernels) > 0
    # Every returned kernel must itself satisfy the design filter.
    for kernel in kernels:
        assert design_filter(kernel.metadata)
    kernel = kernels[0]
    assert kernel.supports(args)
    compiled_artifact = kernel.compile(args)
    kernel.run(args, compiled_artifact=compiled_artifact)
    reference = reference_scaled_mm(A, B, SFA, SFB, torch.float32)
    torch.testing.assert_close(D, reference)

View File

@@ -0,0 +1,69 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import functools
import os
import pytest
import random
import torch
import cutlass_api
# Fix both RNG seeds at import time so tensor initialization and any
# randomized parameter choices in this test module are reproducible run-to-run.
torch.manual_seed(2025)
random.seed(2025)
def test_incorrect_offset_length():
    """Verify that a mis-sized offsets tensor yields zero matching kernels.

    The grouped-GEMM contract requires the offsets tensor to hold exactly
    `problem_count` entries. Here we deliberately supply one extra entry and
    expect `get_kernels` to find no kernel that supports the arguments.
    """
    num_groups = 12
    total_m, dim_n, dim_k = 8192, 128, 512

    lhs = torch.empty((1, total_m, dim_k), device="cuda", dtype=torch.float16)
    # B is allocated (G, N, K) then permuted so its logical shape is (G, K, N).
    rhs = torch.empty(
        (num_groups, dim_n, dim_k), device="cuda", dtype=torch.float16
    ).permute(0, 2, 1)
    result = torch.empty((1, total_m, dim_n), device="cuda", dtype=torch.float32)

    # Deliberately wrong: one element more than `num_groups`.
    bad_offsets = torch.empty((num_groups + 1,), device="cuda", dtype=torch.int32)

    gemm_args = cutlass_api.arguments.GroupedGemmArguments(
        A=lhs,
        B=rhs,
        out=result,
        accumulator_type=torch.float32,
        offsets=bad_offsets,
    )

    matches = cutlass_api.get_kernels(gemm_args, cc=100)
    assert not matches

View File

@@ -37,11 +37,12 @@ import cutlass_api
@pytest.mark.parametrize(
"notebook_name, supported_ccs",
[
("000_gemm.ipynb", [80, 89, 90,100, 103]),
("000_gemm.ipynb", [80, 89, 90, 100, 103]),
("001_gemm_with_fused_epilogue.ipynb", [100, 103]),
("002_bring_your_own_kernel.ipynb", [80, 89, 90, 100, 103, 120, 121]),
("003_host_latency_best_practices.ipynb", [80, 89, 90, 100, 103]),
("004_fake_tensors.ipynb", [80, 89, 90, 100, 103]),
("005_grouped_gemm_contiguous_offset.ipynb", [100]),
],
)
def test_notebooks(notebook_name, supported_ccs):

View File

@@ -37,7 +37,7 @@ from cutlass_api.arguments import ElementwiseArguments
from cutlass_api.config import GlobalOptions
from cutlass_api.metadata import (
ElementwiseOperandsMetadata,
TensorAttributes,
DenseTensorAttributes,
)
@@ -49,7 +49,9 @@ class NoopKernelForTesting(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):
def compile(self, args: ElementwiseArguments):
stream = cute.runtime.make_fake_stream()
return self.cute_compile(self.impl, args.A, args.B, args.out, stream)
return self.cute_compile(
self.impl, args.A.tensor, args.B.tensor, args.out.tensor, stream
)
def _run(
self,
@@ -58,10 +60,12 @@ class NoopKernelForTesting(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):
stream,
workspace=None,
):
self.cute_run(compiled_artifact, args.A, args.B, args.out, stream)
self.cute_run(
compiled_artifact, args.A.tensor, args.B.tensor, args.out.tensor, stream
)
def generate_kernels(_ignored_filter, _ignored_epilogue_args, _ignored_cc):
attrs = TensorAttributes(
attrs = DenseTensorAttributes(
stride=(0, 1),
dtype=cutlass.Float16,
divisibility=8,