2026-01-06 updates

This commit is contained in:
jkosaian
2026-01-06 04:25:33 -08:00
parent dfcb55de16
commit 7c09485e25
77 changed files with 2563 additions and 444 deletions

View File

@@ -69,7 +69,7 @@ See CUTLASS's [Compatibility section](https://github.com/NVIDIA/cutlass?tab=read
### Current support
* Dense GEMM: `out = A @ B`
- Compute capabilities: 100, 103
- Compute capabilities: 100, 103 (WIP to expand to more)
- Input precisions (A and B must be of same type): F16, BF16, TF32, INT8
- Output precisions: F32, F16, BF16, INT32
- Epilogue operations:
@@ -78,6 +78,7 @@ See CUTLASS's [Compatibility section](https://github.com/NVIDIA/cutlass?tab=read
- Auxiliary load of scalar
- Tensor-tensor elementwise or tensor-scalar addition, multiplication, subtraction, division
- Elementwise tensor exponent, relu, sigmoid, tanh
- Note: Partial support exists on CC 80/89/90 (limited dtypes/tilings coverage)
* Planned additions
* Block-scaled GEMMs

View File

@@ -60,9 +60,16 @@ class PerformanceControls:
class EpilogueArguments:
def __init__(self, epilogue_fn: Callable | str | None = None, **kwargs):
def __init__(
self,
epilogue_fn: Callable | str,
**kwargs,
):
"""
Encapsulation of the epilogue function and its arguments needed to
Describes a user-defined epilogue that is performed on top of the operation
described by the primary `RuntimeArguments`.
It encapsulates an epilogue function and its arguments needed to
determine the functional operation of an epilogue pattern.
To support flexible definition of epilogues, `EpilogueArguments` is
@@ -120,7 +127,10 @@ class EpilogueArguments:
```
A user would need to construct epilogue arguments as follows:
```python
epi_args = EpilogueArguments(my_epi, alpha=..., C=..., beta=..., D=..., F=...)
epi_args = EpilogueArguments(
my_epi,
alpha=..., C=..., beta=..., D=..., F=...
)
```
:param epilogue_fn: The epilogue function to be traced.
@@ -263,13 +273,13 @@ class GemmArguments(RuntimeArguments):
N: Number of columns in B and out
:param A: Input tensor A of shape (L, M, K) or (M, K)
:type A: TensorWrapper
:type A: TensorLike
:param B: Input tensor B of shape (L, K, N) or (K, N)
:type B: TensorWrapper
:type B: TensorLike
:param out: Output tensor C of shape (L, M, N) or (M, N)
:type out: TensorWrapper
:type out: TensorLike
:param accumulator_type: Data type of the accumulator
:type accumulator_type: cutlass.Numeric
:type accumulator_type: NumericLike
"""
A: TensorLike

View File

@@ -81,3 +81,17 @@ class GlobalOptions:
"TVM FFI is not installed, please install it via `pip install apache-tvm-ffi`."
)
self._options["use_tvm_ffi"] = value
def save(self, out: dict) -> None:
"""
Save the current options to a dictionary.
"""
for key, value in self._options.items():
out[key] = value
def restore(self, inp: dict) -> None:
"""
Restore the options from a dictionary.
"""
for key, value in inp.items():
self._options[key] = value

View File

@@ -1,4 +1,3 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +27,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ast
import textwrap

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Helpers for common activation functions

View File

@@ -1,4 +1,3 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +27,6 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.backend.sm80_emitter import Sm80Emitter
import cutlass_api.fusion.backend.sm80_nodes as sm80_nodes

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base class for Epilogue Visitor Emitter

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Emitter for Sm100 Epilogue Visitor

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.ir import AuxLoadImpl, AuxStoreImpl

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Emitter for Sm80 Epilogue Visitor

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.library import (
DataTypeTag,

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Emitter for Sm90 Epilogue Visitor

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.library import (

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Epilogue Visitor interface for compiling, and running visitor-based epilogue.

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Collection of builtin functions used for host reference in EVT

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.frontend.python_ast import (
PythonASTFrontend,

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base class for Python EVT Frontend
@@ -122,6 +122,9 @@ class EVTFrontendBase:
# Parse the input
self.parse(*args, **kwargs)
if not self.dag_ir.has_node("D"):
raise RuntimeError("Output node D is not found in the epilogue function.")
# Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
if self.cc >= 90:
if self.dag_ir.out_degree("D") != 0:

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Python AST frontend that parses input into DAG IR

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.ir.compute_nodes import ComputeNode, ComputeImpl
from cutlass_api.fusion.ir.dag_ir import DAGIR

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ctypes

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Python registration for compute nodes in EVT

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
DAG IR used by Python EVT

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Layout algebras

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Layout manipulation nodes and implementations

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Load nodes and implementations

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base & visitor classes of DAGIR Nodes

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Store node and implementations

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
High-level class for tensor

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Copies of enum types from cutlass_library that are used in the fusion frontend.

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.passes.graph_drawer import EVTGraphDrawer
from cutlass_api.fusion.passes.pass_argument_type import PassGetArgumentType

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Construct the epilogue visitor argument type

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Fix the element_output of producer of D.

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Infer the underlying implement of each node.

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Eliminate layout manipulation nodes

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Pass manager for DAG IR.

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
No op elimination node

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Preprocess the reduction nodes.

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Shape and type propagation pass

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Compute the shared memory size in bytes

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for passes

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from .int_tuple import *
from .layout import *

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Functions for manipulating IntTuples

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Definition of CuTe Layouts and functions to manipulate them

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Methods for layout swizzling

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from abc import ABC

View File

@@ -36,6 +36,7 @@ from cutlass_api.arguments import EpilogueArguments, RuntimeArguments
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import KernelMetadata
from cutlass_api.status import Status
from cutlass_api.utils import device_cc
class Kernel(ABC):
@@ -43,6 +44,16 @@ class Kernel(ABC):
Base class for all kernels to be implemented in providers
"""
def __init__(self, metadata: KernelMetadata):
self._metadata = metadata
@property
def metadata(self) -> KernelMetadata:
"""
The read-only metadata for the kernel.
"""
return self._metadata
@final
def supports(self, args: RuntimeArguments) -> Status:
"""

View File

@@ -90,7 +90,6 @@ class Manifest:
return False
return True
epilogue_args = None if args is None else args.epilogue
kernels = [
k

View File

@@ -68,50 +68,74 @@ def _convert_stride(shape: tuple[int, ...], stride: tuple[int, ...]) -> tuple[in
return new_stride
def _get_max_pow2_alignment(
shape: tuple[int, ...], stride: tuple[int, ...], dtype: cutlass.Numeric
) -> int:
def _is_tuple_aligned(tup: tuple[int, ...], divisibility: int, contiguous_dim: int) -> bool:
"""
Get the maximum power of 2 alignment for a given data type
Check if all elements of the shape/stride tuple are divisible by a given
divisibility, except along the contiguous dimension.
"""
return all(
t % divisibility == 0 for dim, t in enumerate(tup) if dim != contiguous_dim
)
def _get_max_pow2_divisibility(shape: tuple[int, ...], stride: tuple[int, ...]) -> int:
"""
Get the maximum power of 2 divisibility met by the given shape and stride.
This is the largest power of 2 that divides the number of elements in the major mode (with stride 1).
:param shape: the shape of the tensor
:type shape: tuple[int, ...]
:param stride: the stride of the tensor
:type stride: tuple[int, ...]
:param dtype: the data type
:type dtype: cutlass.Numeric
:return: the maximum power of 2 alignment
:return: the maximum power of 2 divisibility by which the shape is divisible
:rtype: int
"""
if 1 not in stride:
return 1
major_mode_idx = stride.index(1)
num_major_elements = shape[major_mode_idx]
for alignment in [128, 64, 32, 16, 8, 4, 2]:
num_contiguous_elements = alignment * 8 // dtype.width
if num_major_elements % num_contiguous_elements == 0:
return alignment
return 1
best_divisibility = 1
while num_major_elements % (best_divisibility * 2) == 0:
best_divisibility *= 2
return best_divisibility
@dataclass
class TensorAttributes:
"""
Description of a single tensor. This includes the data type, stride, and alignment.
Description of a single tensor. This includes the data type, stride, and divisibility.
:param dtype: The data type of the tensor.
:type dtype: cutlass.Numeric
:param stride: The stride of the tensor.
:type stride: tuple[int, ...]
:param alignment: The alignment of the tensor
:type alignment: int
:param divisibility: The divisibility requirement on a tensor's stride & shape elements
:type divisibility: int
:param ptr_alignment_bytes: The alignment of the tensor's data pointer, in bytes.
By default, it matches the number of bytes in stride/shape alignment.
:type ptr_alignment_bytes: int
"""
dtype: cutlass.Numeric # F32, F16, etc.
stride: tuple[int, ...]
alignment: int
divisibility: int
ptr_alignment_bytes: int
def __init__(
self,
dtype: cutlass.Numeric,
stride: tuple[int, ...],
divisibility: int,
ptr_alignment_bytes: int | None = None,
):
self.dtype = dtype
self.stride = stride
self.divisibility = divisibility
self.ptr_alignment_bytes = ptr_alignment_bytes or (
(divisibility * dtype.width) // 8
)
def supports(self, operand: TensorWrapper | Self) -> Status:
"""
@@ -159,22 +183,37 @@ class TensorAttributes:
# When setting stride from args, any modes of stride 1 and shape 1
# are changed to have stride 0. Thus, there will only be one mode
# with stride 1.
if not all_zeros and normalized_operand_stride[expected_stride.index(1)] != 1:
contiguous_dim = expected_stride.index(1)
if not all_zeros and normalized_operand_stride[contiguous_dim] != 1:
return Status.fail(
f"Expected stride[{expected_stride.index(1)}] to be 1, got {normalized_operand_stride[expected_stride.index(1)]} (strides: {normalized_operand_stride})"
f"Expected stride[{contiguous_dim}] to be 1, got "
f"{normalized_operand_stride[contiguous_dim]} "
f"(strides: {normalized_operand_stride})"
)
# Alignment of operand should be divisible by this metadata's alignment
# Check that divisibility constraints are met
if isinstance(operand, TensorWrapper):
operand_alignment = _get_max_pow2_alignment(
operand.shape, normalized_operand_stride, operand.element_type
)
if not _is_tuple_aligned(
normalized_operand_stride, self.divisibility, contiguous_dim
):
return Status.fail(
f"Expected operand stride to be divisible by {self.divisibility} for"
f"all non-contiguous modes, got {normalized_operand_stride}"
)
else:
operand_alignment = operand.alignment
# When comparing another TensorAttribute, ensure its divisibility is a subset
if operand.divisibility % self.divisibility != 0:
return Status.fail(
f"Expected operand divisibility {operand.divisibility} to be divisible by {self.divisibility}"
)
if operand_alignment % self.alignment != 0:
# Check data ptr alignment, if available
if (
isinstance(operand, TensorWrapper)
and operand.data_ptr % self.ptr_alignment_bytes != 0
):
return Status.fail(
f"Expected operand alignment {operand_alignment} (strides: {normalized_operand_stride}) to be a multiple of {self.alignment}"
f"Expected data pointer to be {self.ptr_alignment_bytes}B-aligned."
)
return Status.success()
@@ -191,11 +230,9 @@ class TensorAttributes:
:rtype: TensorAttributes
"""
stride = _convert_stride(tensor.shape, tensor.stride)
max_alignment = _get_max_pow2_alignment(
tensor.shape, stride, tensor.element_type
)
max_divisibility = _get_max_pow2_divisibility(tensor.shape, stride)
return TensorAttributes(
dtype=tensor.element_type, stride=stride, alignment=max_alignment
dtype=tensor.element_type, stride=stride, divisibility=max_divisibility
)

View File

@@ -26,9 +26,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections.abc import Callable
import logging
from typing import Type
from collections.abc import Callable
from cutlass_api.arguments import EpilogueArguments
from cutlass_api.kernel import Kernel
@@ -64,6 +63,7 @@ if available:
epilogue_args: EpilogueArguments = None,
cc: int = None,
) -> list[Kernel]:
kernels_for_provider = []
for kernel_cls in cls._kernel_classes:
kernels_for_provider.extend(

View File

@@ -32,6 +32,9 @@ from typing import ClassVar
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
from cutlass_api.arguments import ElementwiseArguments, EpilogueArguments
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
@@ -43,9 +46,6 @@ from cutlass_api.providers.cutedsl import CuTeDSLProvider
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
from cutlass_api.utils import to_cuda_stream
import cutlass
import cutlass.cute as cute
"""
An Elementwise Addition kernel using CuTe DSL:
out = A + B
@@ -60,7 +60,7 @@ class ElementwiseAddKernel(CuteDslKernel):
]
def __init__(self, metadata: KernelMetadata):
self.metadata = metadata
super().__init__(metadata)
self.impl = ElementwiseAddKernelImpl()
def compile(self, args: ElementwiseArguments, cc: int = None) -> CompiledArtifact:
@@ -99,7 +99,7 @@ class ElementwiseAddKernel(CuteDslKernel):
kernel_list = []
for dtype in ElementwiseAddKernel._supported_dtypes:
alignment = 128 // dtype.width
divisibility = 128 // dtype.width
for stride_A, stride_B, stride_out in product(
stride_names.keys(), repeat=3
):
@@ -112,17 +112,20 @@ class ElementwiseAddKernel(CuteDslKernel):
operands = ElementwiseOperandsMetadata(
A=TensorAttributes(
dtype=dtype, stride=stride_A, alignment=alignment
dtype=dtype, stride=stride_A, divisibility=divisibility
),
B=TensorAttributes(
dtype=dtype, stride=stride_B, alignment=alignment
dtype=dtype, stride=stride_B, divisibility=divisibility
),
out=TensorAttributes(
dtype=dtype, stride=stride_out, alignment=alignment
dtype=dtype, stride=stride_out, divisibility=divisibility
),
)
metadata = KernelMetadata(
operands=operands, kernel_name=kernel_name, kernel_class=ElementwiseAddKernel, min_cc=min_cc
operands=operands,
kernel_name=kernel_name,
kernel_class=ElementwiseAddKernel,
min_cc=min_cc,
)
if metadata_filter(metadata):
kernel_list.append(ElementwiseAddKernel(metadata))

View File

@@ -1,4 +1,4 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
@@ -28,7 +28,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from collections.abc import Callable

View File

@@ -31,3 +31,4 @@
# ruff: noqa: F401
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent_efc
import cutlass_api.providers.cutedsl.gemm.sm80_tensorop_gemm

View File

@@ -27,18 +27,15 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections.abc import Callable
from typing import Union
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
from cutlass.cute.runtime import from_dlpack
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
import cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
"""
A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
@@ -1222,7 +1219,7 @@ class PersistentDenseGemmKernelImpl:
tAcc: cute.Tensor,
gC_mnl: cute.Tensor,
epi_tile: cute.Tile,
use_2cta_instrs: Union[cutlass.Boolean, bool],
use_2cta_instrs: cutlass.Boolean | bool,
) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
"""
Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).

View File

@@ -0,0 +1,814 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import math
from typing import Tuple, Type
import torch
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutlass.cute.runtime import from_dlpack
"""
A high-performance batched dense GEMM example utilizing SM80-style TensorCore instructions
- Matrix A is LxMxK, L is batch dimension, A can be row-major("t") or column-major("n")
- Matrix B is LxKxN, L is batch dimension, B can be row-major("t") or column-major("n")
- Matrix C is LxMxN, L is batch dimension, C can be row-major("t") or column-major("n")
- Internally, this is treated as an MxKxL, NxKxL, MxNxL CuTe Tensors
This GEMM kernel supports the following features:
- Utilizes SM80's tensor cores for matrix multiply-accumulate (MMA) operations
- Threadblock rasterization to improve data re-use
- Supports multi-stage pipeline to overlap computation and memory access
Constraints:
* Supported input and output data types: fp16, bf16
* Supported accumulator data types: f32
* Atom layout's MNK shape is set so that tile shape can be divided by MMA instruction shape
* The contiguous dimension of A/B/C tensors must be at least 16-byte aligned.
"""
class Sm80TensorOpGemmImpl:
def __init__(
self,
a_dtype: Type[cutlass.Numeric],
b_dtype: Type[cutlass.Numeric],
c_dtype: Type[cutlass.Numeric],
acc_dtype: Type[cutlass.Numeric],
):
self.ab_dtype = a_dtype
self.c_dtype = c_dtype
self.acc_dtype = acc_dtype
self.cta_tiler = (128, 128, 32)
self.num_stages = 3
self.atom_layout_mnk = (2, 2, 1)
atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk
self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32
self.bM, self.bN, self.bK = self.cta_tiler
self.mma_inst_shape = (16, 8, 16)
mmaM, mmaN, mmaK = self.mma_inst_shape
assert a_dtype == b_dtype, "This example does not support different A, B dtypes"
assert self.bM % (atom_lay_M * mmaM) == 0, (
"bM must be divisible by MMA instruction"
)
assert self.bN % (atom_lay_N * mmaN) == 0, (
"bN must be divisible by MMA instruction"
)
assert atom_lay_K == 1, "this example does not support atom layout K > 1"
assert self.bK % mmaK == 0, "bK must be divisible by MMA instruction"
assert self.num_stages >= 3, "num_stages must be greater than or equal to 3"
@cute.jit
def __call__(
self,
a: cute.Tensor,
b: cute.Tensor,
c: cute.Tensor,
stream,
epilogue_op: cutlass.Constexpr = lambda x: x,
):
def add_batch_mode(tensor: cute.Tensor) -> cute.Tensor:
return cute.make_tensor(
tensor.iterator,
cute.prepend(tensor.layout, cute.make_layout(1), up_to_rank=3),
)
a = add_batch_mode(a)
b = add_batch_mode(b)
c = add_batch_mode(c)
# Permute tensor modes from torch to cute convention
# A: (L, M, K) -> (M, K, L)
a = cute.make_tensor(a.iterator, cute.select(a.layout, [1, 2, 0]))
# B: (L, K, N) -> (N, K, L)
b = cute.make_tensor(b.iterator, cute.select(b.layout, [2, 1, 0]))
# C: (L, M, N) -> (M, N, L)
c = cute.make_tensor(c.iterator, cute.select(c.layout, [1, 2, 0]))
# The grid divides the problems's M, N, and L dimensions by the
# respective modes of the tile shape (bM, bN, 1). The K dimension is
# handled within a block via a multistage process.
self.a_major_mode = utils.LayoutEnum.from_tensor(a)
self.b_major_mode = utils.LayoutEnum.from_tensor(b)
self.c_major_mode = utils.LayoutEnum.from_tensor(c)
# ///////////////////////////////////////////////////////////////////////////////
# Shared memory layout:
# ///////////////////////////////////////////////////////////////////////////////
# Creates a layout with the size required for the provided tile
# size and num stages (stages are used for K dimension) that is also
# sectioned into 64x8 or 8x32 layout atoms. The swizzle is set so that
# the atom for the shared memory -> register copy does not encounter
# bank conflicts
# assume the input is 16B align
ab_copy_bits = 128
sA_layout = self._make_smem_layout_AB(
a.element_type,
self.a_major_mode,
ab_copy_bits,
(self.cta_tiler[0], self.cta_tiler[2], self.num_stages),
)
sB_layout = self._make_smem_layout_AB(
b.element_type,
self.b_major_mode,
ab_copy_bits,
(self.cta_tiler[1], self.cta_tiler[2], self.num_stages),
)
# Creates a similar layout but without num_stages or layout atoms
sC_layout = self._make_smem_layout_C(
c.element_type,
self.c_major_mode,
ab_copy_bits,
(self.cta_tiler[0], self.cta_tiler[1]),
)
# Shared memory allocated for operations with A, B will be
# overwritten for operations on C. This is to improve performance
# by reducing the size of shared memory requested by each block
smem_size = max(
cute.size_in_bytes(c.element_type, sC_layout),
cute.size_in_bytes(a.element_type, sA_layout)
+ cute.size_in_bytes(b.element_type, sB_layout),
)
# ///////////////////////////////////////////////////////////////////////////////
# Tiled copy:
# The majorness of tA/tB/tC follows the majorness of gA/gB/gC,
# enabling merged accesses to global memory for faster data
# transfer between global and shared memory.
# ///////////////////////////////////////////////////////////////////////////////
# Create a copy atom for a global to shared memory asynchronous copy
atom_async_copy = cute.make_copy_atom(
cute.nvgpu.cpasync.CopyG2SOp(
cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL
),
a.element_type,
num_bits_per_copy=ab_copy_bits,
)
# Create thread layouts for tiled copy from the copy atom where the
# thread layout simply follows the leading dimension of the tensor
tiled_copy_A = self._make_gmem_tiled_copy_AB(
atom_async_copy, a.element_type, self.a_major_mode, ab_copy_bits
)
tiled_copy_B = self._make_gmem_tiled_copy_AB(
atom_async_copy, b.element_type, self.b_major_mode, ab_copy_bits
)
# Creates a synchronous copy atom and thread layouts for the epilogue
c_copy_bits = 128
atom_sync_copy = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
c.element_type,
num_bits_per_copy=c_copy_bits,
)
tiled_copy_C = self._make_gmem_tiled_copy_C(
atom_sync_copy, c.element_type, self.c_major_mode, c_copy_bits
)
# ///////////////////////////////////////////////////////////////////////////////
# Tiled MMA
# ///////////////////////////////////////////////////////////////////////////////
# Creates a mma atom with 16x8x16 shape for MNK
op = cute.nvgpu.warp.MmaF16BF16Op(
self.ab_dtype, self.acc_dtype, self.mma_inst_shape
)
permutation_mnk = (
self.atom_layout_mnk[0] * self.mma_inst_shape[0],
# if atom layout's N-mode is 1, to leverage the largest coalesced
# shared memory -> register copy, set the tiled mma's N mode to 16
self.atom_layout_mnk[1] * self.mma_inst_shape[1] * 2,
self.atom_layout_mnk[2] * self.mma_inst_shape[2],
)
# Created a tiled mma that tiles the atom according to specified layout.
# For a 2x2x1 atom layout, the mma atom is duplicated 4 times, twice
# across M and twice across N
tC = cute.make_layout(self.atom_layout_mnk)
tiled_mma = cute.make_tiled_mma(
op,
tC,
permutation_mnk=permutation_mnk,
)
# grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, l)
grid_dim = cute.ceil_div(c.shape, (self.bM, self.bN, 1))
# Add threadblock rasterization to improve re-use of data
raster_factor = 1
grid_dim_n = cute.size(grid_dim[1])
# Thresholds picked so that it doesn't cause too many no-op CTAs
if grid_dim_n > 5:
raster_factor = 8
elif grid_dim_n > 2:
raster_factor = 4
elif grid_dim_n > 1:
raster_factor = 2
rasterization_remap_grid_dim = (
cute.size(grid_dim[0]) * raster_factor,
(cute.size(grid_dim[1]) + raster_factor - 1) // raster_factor,
cute.size(grid_dim[2]),
)
self.kernel(
a,
b,
c,
sA_layout,
sB_layout,
sC_layout,
tiled_copy_A,
tiled_copy_B,
tiled_copy_C,
tiled_mma,
raster_factor,
epilogue_op,
).launch(
grid=rasterization_remap_grid_dim,
block=[self.num_threads, 1, 1],
smem=smem_size,
)
@cute.kernel
def kernel(
self,
mA: cute.Tensor,
mB: cute.Tensor,
mC: cute.Tensor,
sA_layout: cute.ComposedLayout,
sB_layout: cute.ComposedLayout,
sC_layout: cute.ComposedLayout,
tiled_copy_A: cute.TiledCopy,
tiled_copy_B: cute.TiledCopy,
tiled_copy_C: cute.TiledCopy,
tiled_mma: cute.TiledMma,
rasterization_factor: cutlass.Int32,
epilogue_op: cutlass.Constexpr = lambda x: x,
):
# Thread index, block index
tidx, _, _ = cute.arch.thread_idx()
bidx, bidy, bidz = cute.arch.block_idx()
grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
offset_tile_x, offset_tile_y = self.raster_tile(
bidx, bidy, rasterization_factor
)
# Early exit if CTA is out of range
if grid_dim[0] <= offset_tile_x or grid_dim[1] <= offset_tile_y:
pass
else:
tiler_coord = (offset_tile_x, offset_tile_y, None)
# ///////////////////////////////////////////////////////////////////////////////
# Get the appropriate tiles for this thread block.
# gA: (BLK_M, BLK_N, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N)
# ///////////////////////////////////////////////////////////////////////////////
gA = cute.local_tile(
mA[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, None, 1),
)
gB = cute.local_tile(
mB[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(None, 1, 1),
)
gC = cute.local_tile(
mC[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, 1, None),
)
# By default, if the tensor k mode does not divide into the tile k
# size, then last tiles in the k dimension are irregular.
# Instead, make the first tiles irregular when k is irregular.
# This allows us to handle the irregular tile first to avoid
# checking for this condition within the mainloop.
# residual_k is a negative number indicating the amount needed to
# shift the pointer by in dimension k
residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size(
gA, mode=[2]
)
# move the pointer of gA/gB in the `-k` direction
gA = cute.domain_offset((0, residual_k, 0), gA)
gB = cute.domain_offset((0, residual_k, 0), gB)
# input is 16B aligned
gA = cute.make_tensor(gA.iterator.align(16), gA.layout)
gB = cute.make_tensor(gB.iterator.align(16), gB.layout)
# Construct identity layout for sA and sB (mirrors global tensors,
# used for predication only)
mcA = cute.make_identity_tensor(mA.layout.shape)
mcB = cute.make_identity_tensor(mB.layout.shape)
cA = cute.local_tile(
mcA[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, None, 1),
)
cB = cute.local_tile(
mcB[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(None, 1, 1),
)
cA = cute.domain_offset((0, residual_k, 0), cA)
cB = cute.domain_offset((0, residual_k, 0), cB)
# ///////////////////////////////////////////////////////////////////////////////
# Create shared memory buffers and get the appropriate fragments for this thread.
# sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE)
# tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k)
# tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE)
# ///////////////////////////////////////////////////////////////////////////////
# Shared memory buffer
smem = cutlass.utils.SmemAllocator()
sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
sC = cute.make_tensor(
cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout
)
thr_copy_A = tiled_copy_A.get_slice(tidx)
thr_copy_B = tiled_copy_B.get_slice(tidx)
thr_copy_C = tiled_copy_C.get_slice(tidx)
tAgA = thr_copy_A.partition_S(gA)
tAsA = thr_copy_A.partition_D(sA)
tBgB = thr_copy_B.partition_S(gB)
tBsB = thr_copy_B.partition_D(sB)
tCsC_epilogue = thr_copy_C.partition_S(sC)
tCgC_epilogue = thr_copy_C.partition_D(gC)
# Repeat the partitioning with identity layouts
tAcA = thr_copy_A.partition_S(cA)
tBcB = thr_copy_B.partition_S(cB)
# ///////////////////////////////////////////////////////////////////////////////
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
# of tile_shape
# ///////////////////////////////////////////////////////////////////////////////
# For predication over the tensors A (M/K), B (N/K), and (in the
# epilogue) C (M/N), we will compute it in a fashion similar to an
# outer product. The predication along one of the dimensions is
# evaluated and stored in a predication tensor. Then, the
# predication for the remaining dimension is handled later via an
# if/else branch at the copy.
# For A and B, predication booleans along M/N are stored in a
# predication tensor and along K is handled via a if/else branch.
# Allocate predicate tensors for M and N. Predication is checked
# at the granularity of a copy atom, so the predicate tensor does not
# need separate booleans for individual elements within a copy
# atom (for example, the elements of tAgA.shape[0][0].)
tApA = cute.make_rmem_tensor(
cute.make_layout(
(
tAgA.shape[0][1],
cute.size(tAgA, mode=[1]),
cute.size(tAgA, mode=[2]),
),
stride=(cute.size(tAgA, mode=[1]), 1, 0),
),
cutlass.Boolean,
)
tBpB = cute.make_rmem_tensor(
cute.make_layout(
(
tBsB.shape[0][1],
cute.size(tBsB, mode=[1]),
cute.size(tBsB, mode=[2]),
),
stride=(cute.size(tBsB, mode=[1]), 1, 0),
),
cutlass.Boolean,
)
# Set predicates for M/N bounds
for rest_v in range(tApA.shape[0]):
for m in range(tApA.shape[1]):
tApA[rest_v, m, 0] = cute.elem_less(
tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
)
for rest_v in range(tBpB.shape[0]):
for n in range(tBpB.shape[1]):
tBpB[rest_v, n, 0] = cute.elem_less(
tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
)
# ///////////////////////////////////////////////////////////////////////////////
# Prefetch Prologue
# ///////////////////////////////////////////////////////////////////////////////
# Clear the smem tiles to account for predicated off loads
tAsA.fill(0)
tBsB.fill(0)
cute.arch.sync_threads()
# Start async loads for the first k-tile. Here we take care of the k residue
# via if/else check along the k dimension. Because we shifted the identity tensor
# by the residue_k and because the identity tensor is a coord tensor, the
# values of any identity tensor element that is poison is less than -1
num_smem_stages = cute.size(tAsA, mode=[3])
k_tile_count = cute.size(tAgA, mode=[3])
k_tile_index = cutlass.Int32(0)
for k in range(tApA.shape[2]):
if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]):
cute.copy(
tiled_copy_A,
tAgA[None, None, k, k_tile_index],
tAsA[None, None, k, 0],
pred=tApA[None, None, k],
)
for k in range(tBpB.shape[2]):
if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]):
cute.copy(
tiled_copy_B,
tBgB[None, None, k, k_tile_index],
tBsB[None, None, k, 0],
pred=tBpB[None, None, k],
)
k_tile_index = k_tile_index + 1
cute.arch.cp_async_commit_group()
# Start async loads for rest of the k-tiles
for k_tile in range(1, num_smem_stages - 1):
if k_tile == k_tile_count:
tApA.fill(0)
tBpB.fill(0)
cute.copy(
tiled_copy_A,
tAgA[None, None, None, k_tile_index],
tAsA[None, None, None, k_tile],
pred=tApA,
)
cute.copy(
tiled_copy_B,
tBgB[None, None, None, k_tile_index],
tBsB[None, None, None, k_tile],
pred=tBpB,
)
k_tile_index = k_tile_index + 1
cute.arch.cp_async_commit_group()
# ///////////////////////////////////////////////////////////////////////////////
# Tile MMA compute thread partitions and allocate accumulators
# ///////////////////////////////////////////////////////////////////////////////
thr_mma = tiled_mma.get_slice(tidx)
tCsA = thr_mma.partition_A(sA)
tCsB = thr_mma.partition_B(sB)
tCsC = thr_mma.partition_C(sC)
tCgC = thr_mma.partition_C(gC)
tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
tCrC = tiled_mma.make_fragment_C(tCgC)
# Clear the accumulator
tCrC.fill(0.0)
# ///////////////////////////////////////////////////////////////////////////////
# Copy Atom A/B retiling
# ///////////////////////////////////////////////////////////////////////////////
# Create the copy atoms for the copy from shared memory to register
atom_copy_s2r_A = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(
self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
),
mA.element_type,
)
atom_copy_s2r_B = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(
self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
),
mB.element_type,
)
# Creates the tiled copy so that it matches the thread-value layout
# expected by the tiled mma
tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma)
tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma)
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA)
tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB)
# Current pipe index in smem to read from / write to
smem_pipe_read = 0
smem_pipe_write = num_smem_stages - 1
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
# ///////////////////////////////////////////////////////////////////////////////
# PREFETCH register pipeline
# ///////////////////////////////////////////////////////////////////////////////
num_k_block = cute.size(tCrA, mode=[2])
if num_k_block > 1:
# Wait until our first prefetched tile is loaded in
cute.arch.cp_async_wait_group(num_smem_stages - 2)
cute.arch.sync_threads()
# Prefetch the first k-block rmem from the first k-tile
cute.copy(
tiled_copy_s2r_A,
tCsA_p[None, None, 0],
tCrA_copy_view[None, None, 0],
)
cute.copy(
tiled_copy_s2r_B,
tCsB_p[None, None, 0],
tCrB_copy_view[None, None, 0],
)
# ///////////////////////////////////////////////////////////////////////////////
# Mainloop
# 1. Shared memory pipeline (gmem -> smem):
# The default smem pipeline depth is 3, meaning that for shared
# memory buffers, we allocate three times the size described by the
# CTA tiler. We prefetch 2 of these buffers before entering the main
# loop. Considering only the transfer from global memory to shared
# memory, the general structure of the mainloop is:
# (1) copy k-tile from gmem to smem;
# (2) perform gemm computation on k-tile;
# (3) wait for the next copy to finish.
# The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command
# waits for the number of unfinished 'copy' to be <= 1. The advantage
# of this approach is that it allows for simultaneous production
# (i.e., step (1)) and consumption (i.e., step (2)) of smem.
# A common misconception is to prefetch N buffers and rewrite
# the pipeline logic to wait on N-1 pending copies. The disadvantage
# of this approach is that it requires fully consuming a buffer in
# order to open an empty buffer for the next copy.
# 2. Register pipeline (smem -> register):
# Similarly, the register pipeline produces i+1, consumes i, and
# produces i+2... Notably, i and i+1 do not use the same register,
# eliminating dependencies on the same register for better parallelism.
# 3. Combining the smem and register pipelines results in the mainloop.
# ///////////////////////////////////////////////////////////////////////////////
for k_tile in range(k_tile_count):
for k_block in cutlass.range(num_k_block, unroll_full=True):
if k_block == num_k_block - 1:
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
cute.arch.cp_async_wait_group(num_smem_stages - 2)
cute.arch.sync_threads()
# Load A, B from shared memory to registers for k_block + 1
k_block_next = (k_block + 1) % num_k_block # static
cute.copy(
tiled_copy_s2r_A,
tCsA_p[None, None, k_block_next],
tCrA_copy_view[None, None, k_block_next],
)
cute.copy(
tiled_copy_s2r_B,
tCsB_p[None, None, k_block_next],
tCrB_copy_view[None, None, k_block_next],
)
# Fetch next A: To better interleave global memory access and compute
# instructions, we intentionally use the sequence: copy A, perform GEMM,
# then copy B.
if k_block == 0:
if k_tile + num_smem_stages - 1 < k_tile_count:
cute.copy(
tiled_copy_A,
tAgA[None, None, None, k_tile_index],
tAsA[None, None, None, smem_pipe_write],
pred=tApA,
)
# Thread-level register gemm for k_block
cute.gemm(
tiled_mma,
tCrC,
tCrA[None, None, k_block],
tCrB[None, None, k_block],
tCrC,
)
# Fetch next B and update smem pipeline read/write
if k_block == 0:
if k_tile + num_smem_stages - 1 < k_tile_count:
cute.copy(
tiled_copy_B,
tBgB[None, None, None, k_tile_index],
tBsB[None, None, None, smem_pipe_write],
pred=tBpB,
)
k_tile_index = k_tile_index + 1
cute.arch.cp_async_commit_group()
smem_pipe_write = smem_pipe_read
smem_pipe_read = smem_pipe_read + 1
if smem_pipe_read == num_smem_stages:
smem_pipe_read = 0
# Sync before epilogue
cute.arch.cp_async_wait_group(0)
cute.arch.sync_threads()
# ///////////////////////////////////////////////////////////////////////////////
# Epilogue with fusion
# ///////////////////////////////////////////////////////////////////////////////
tCrD = cute.make_fragment_like(tCrC, self.c_dtype)
tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype)
# Copy results of D back to shared memory
cute.autovec_copy(tCrD, tCsC)
# Create coord tensor for C
ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
mcC = cute.make_identity_tensor(
(
cute.size(ceilM) * self.cta_tiler[0],
cute.size(ceilN) * self.cta_tiler[1],
1,
)
)
cC = cute.local_tile(
mcC[None, None, bidz],
tiler=self.cta_tiler,
coord=tiler_coord,
proj=(1, 1, None),
)
tCcC = thr_copy_C.partition_S(cC)
tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue)
# Wait for all writes to shared memory to finish before starting copies
# using the new layouts
cute.arch.sync_threads()
cute.autovec_copy(tCsC_epilogue, tCrC_epilogue)
# Create predication tensor for m
tCpC = cute.make_rmem_tensor(
cute.make_layout(
(
tCgC_epilogue.shape[0][1],
cute.size(tCgC_epilogue, mode=[1]),
cute.size(tCgC_epilogue, mode=[2]),
),
stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0),
),
cutlass.Boolean,
)
for rest_v in range(tCpC.shape[0]):
for m in range(tCpC.shape[1]):
tCpC[rest_v, m, 0] = cute.elem_less(
tCcC[(0, rest_v), m, 0][0], mC.shape[0]
)
# Copy to global memory using better vectorization
for rest_v in range(tCpC.shape[0]):
for n in range(tCpC.shape[2]):
if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]):
cute.copy(
tiled_copy_C,
tCrC_epilogue[None, None, n],
tCgC_epilogue[None, None, n],
pred=tCpC[None, None, n],
)
return
def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler):
major_mode_size = (
smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
)
major_mode_size = 64 if major_mode_size >= 64 else major_mode_size
swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
swizzle_bits = min(swizzle_bits, 3)
layout_atom_outer = (
cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
if major_mode == utils.LayoutEnum.ROW_MAJOR
else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
)
layout_atom = cute.make_composed_layout(
cute.make_swizzle(swizzle_bits, 3, 3),
0,
layout_atom_outer,
)
layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2))
return layout
def _make_smem_layout_C(self, dtype, major_mode, copy_bits, smem_tiler):
major_mode_size = (
smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
)
swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
swizzle_bits = min(swizzle_bits, 3)
layout_atom_outer = (
cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
if major_mode == utils.LayoutEnum.ROW_MAJOR
else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
)
layout_atom = cute.make_composed_layout(
cute.make_swizzle(swizzle_bits, 3, 4),
0,
layout_atom_outer,
)
# Due to the thread layout of the mma, remove swizzle in C to
# prevent shared memory fragments owned by an single thread from
# holding swizzles
if major_mode == utils.LayoutEnum.COL_MAJOR:
layout_atom = cute.make_composed_layout(
cute.make_swizzle(0, 3, 4), 0, layout_atom_outer
)
layout = cute.tile_to_shape(
layout_atom,
smem_tiler,
(0, 1),
)
return layout
def _make_gmem_tiled_copy_AB(self, atom_copy, dtype, major_mode, copy_bits):
    """Build a tiled global-memory copy for an A/B operand.

    Thread and value layouts are arranged so each thread moves
    ``copy_bits`` worth of contiguous elements along the major mode.
    """
    elems_per_copy = copy_bits // dtype.width
    if major_mode == utils.LayoutEnum.ROW_MAJOR:
        # Threads are contiguous along K (the contiguous mode).
        k_threads = cute.size(self.bK) // elems_per_copy
        tv_threads = cute.make_layout(
            (self.num_threads // k_threads, k_threads),
            stride=(k_threads, 1),
        )
        tv_values = cute.make_layout((1, elems_per_copy))
    else:
        # Threads are contiguous along M for column-major operands.
        m_threads = cute.size(self.bM) // elems_per_copy
        tv_threads = cute.make_layout(
            (m_threads, self.num_threads // m_threads),
            stride=(1, m_threads),
        )
        tv_values = cute.make_layout((elems_per_copy, 1))
    return cute.make_tiled_copy_tv(atom_copy, tv_threads, tv_values)
def _make_gmem_tiled_copy_C(self, atom_copy, dtype, major_mode, copy_bits):
    """Build a tiled global-memory copy for the output (C) tile.

    Same scheme as `_make_gmem_tiled_copy_AB`, but the contiguous mode of
    a row-major C tile is N rather than K.
    """
    elems_per_copy = copy_bits // dtype.width
    if major_mode == utils.LayoutEnum.ROW_MAJOR:
        # Threads are contiguous along N (the contiguous mode).
        n_threads = cute.size(self.bN) // elems_per_copy
        tv_threads = cute.make_layout(
            (self.num_threads // n_threads, n_threads),
            stride=(n_threads, 1),
        )
        tv_values = cute.make_layout((1, elems_per_copy))
    else:
        # Threads are contiguous along M for column-major output.
        m_threads = cute.size(self.bM) // elems_per_copy
        tv_threads = cute.make_layout(
            (m_threads, self.num_threads // m_threads),
            stride=(1, m_threads),
        )
        tv_values = cute.make_layout((elems_per_copy, 1))
    return cute.make_tiled_copy_tv(atom_copy, tv_threads, tv_values)
def raster_tile(self, i, j, f):
    """Remap tile coordinate ``(i, j)`` for rasterization with factor ``f``.

    Groups ``f`` consecutive ``i`` values into a single output row and
    spreads them across adjacent output columns, which improves L2 reuse
    when walking output tiles.
    """
    row, offset = divmod(i, f)
    return (row, offset + j * f)

View File

@@ -60,7 +60,6 @@ class PersistentDenseGemmKernel(CuteDslKernel):
- i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
:note: Supported A/B data types:
- TFloat32
- Float16/BFloat16
- Int8/Uint8
- Float8E4M3FN/Float8E5M2
@@ -85,7 +84,7 @@ class PersistentDenseGemmKernel(CuteDslKernel):
"""
def __init__(self, metadata: KernelMetadata):
self.metadata = metadata
super().__init__(metadata)
def epilogue_op(x):
return x
@@ -106,6 +105,7 @@ class PersistentDenseGemmKernel(CuteDslKernel):
)
def _supports(self, args: GemmArguments) -> Status:
if args.epilogue is not None:
return Status.fail("This kernel does not support any epilogue fusion.")
return Status.success()
@@ -149,12 +149,10 @@ class PersistentDenseGemmKernel(CuteDslKernel):
abtype = operands.A.dtype
# Supported A/B data types:
# - TFloat32
# - Float16/BFloat16
# - Int8/Uint8
# - Float8E4M3FN/Float8E5M2
if abtype not in [
cutlass.Float32,
cutlass.Float16,
cutlass.BFloat16,
cutlass.Int8,
@@ -197,7 +195,11 @@ class PersistentDenseGemmKernel(CuteDslKernel):
operands.out.dtype == cutlass.Float16
or operands.out.dtype == cutlass.BFloat16
):
if operands.accumulator_type not in [cutlass.Float16, cutlass.BFloat16]:
if operands.accumulator_type not in [
cutlass.Float32,
cutlass.Float16,
cutlass.BFloat16,
]:
return False
elif operands.out.dtype == cutlass.Int8 or operands.out.dtype == cutlass.Uint8:
if operands.accumulator_type not in [cutlass.Int32]:
@@ -221,7 +223,6 @@ class PersistentDenseGemmKernel(CuteDslKernel):
"""
# Supported A/B data types (must be the same)
ab_dtypes = [
cutlass.Float32,
cutlass.Float16,
cutlass.BFloat16,
cutlass.Int8,
@@ -232,15 +233,13 @@ class PersistentDenseGemmKernel(CuteDslKernel):
row_major_stride = (0, 0, 1)
col_major_stride = (0, 1, 0)
alignment = 16
alignment_bytes = 16
for ab_dtype in ab_dtypes:
# Determine valid accumulator types for this A/B dtype
valid_acc_dtypes = []
if (
ab_dtype.is_float
): # Float32, Float16, BFloat16, Float8E4M3FN, Float8E5M2
if ab_dtype.is_float:
valid_acc_dtypes.append(cutlass.Float32)
if ab_dtype in [
cutlass.Float16,
@@ -277,15 +276,23 @@ class PersistentDenseGemmKernel(CuteDslKernel):
for stride_A, stride_B, stride_out in itertools.product(
[row_major_stride, col_major_stride], repeat=3
):
ab_divisibility = alignment_bytes * 8 // ab_dtype.width
out_divisibility = alignment_bytes * 8 // out_dtype.width
# Create TensorAttributes for A, B, and out tensors
a_attrs = TensorAttributes(
dtype=ab_dtype, stride=stride_A, alignment=alignment
dtype=ab_dtype,
stride=stride_A,
divisibility=ab_divisibility,
)
b_attrs = TensorAttributes(
dtype=ab_dtype, stride=stride_B, alignment=alignment
dtype=ab_dtype,
stride=stride_B,
divisibility=ab_divisibility,
)
out_attrs = TensorAttributes(
dtype=out_dtype, stride=stride_out, alignment=alignment
dtype=out_dtype,
stride=stride_out,
divisibility=out_divisibility,
)
# Create and yield the GemmOperandsMetadata
@@ -316,26 +323,26 @@ class PersistentDenseGemmKernel(CuteDslKernel):
if cluster_size_m * cluster_size_n > 16:
return False
tile = design.tile_shape
# Constraints based on whether 2CTA instructions are used
if design.use_2cta_mma is not None:
if design.use_2cta_mma:
if cluster_size_m % 2 != 0:
return False
if design.tile_shape is not None and design.tile_shape[0] not in [
if tile is not None and tile[0] not in [
128,
256,
]:
return False
else:
if design.tile_shape is not None and design.tile_shape[0] not in [
if tile is not None and tile[0] not in [
64,
128,
]:
return False
if design.tile_shape is not None and design.tile_shape[1] not in range(
32, 256, 32
):
if tile is not None and tile[1] not in range(32, 257, 32):
return False
if metadata.epilogue is not None:
@@ -407,8 +414,6 @@ class PersistentDenseGemmKernel(CuteDslKernel):
if PersistentDenseGemmKernel._valid_metadata(
metadata
) and metadata_filter(metadata):
kernel_list.append(
PersistentDenseGemmKernel(metadata)
)
kernel_list.append(PersistentDenseGemmKernel(metadata))
return kernel_list

View File

@@ -128,7 +128,7 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
}
def __init__(self, metadata: KernelMetadata):
self.metadata = metadata
super().__init__(metadata)
if metadata.epilogue is not None:
epilogue_op = EFCConverter.convert(
@@ -249,7 +249,7 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
"""
row_major_stride = (0, 0, 1)
col_major_stride = (0, 1, 0)
alignment = 16
alignment_bytes = 16
for (
ab_dtype,
@@ -281,15 +281,23 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
for stride_A, stride_B, stride_out in itertools.product(
[row_major_stride, col_major_stride], repeat=3
):
ab_divisibility = alignment_bytes * 8 // ab_dtype.width
out_divisibility = alignment_bytes * 8 // out_dtype.width
# Create TensorAttributes for A, B, and out tensors
a_attrs = TensorAttributes(
dtype=ab_dtype, stride=stride_A, alignment=alignment
dtype=ab_dtype,
stride=stride_A,
divisibility=ab_divisibility,
)
b_attrs = TensorAttributes(
dtype=ab_dtype, stride=stride_B, alignment=alignment
dtype=ab_dtype,
stride=stride_B,
divisibility=ab_divisibility,
)
out_attrs = TensorAttributes(
dtype=out_dtype, stride=stride_out, alignment=alignment
dtype=out_dtype,
stride=stride_out,
divisibility=out_divisibility,
)
# Create and yield the GemmOperandsMetadata
@@ -303,10 +311,13 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
yield operands
def _supports(self, args: GemmArguments) -> Status:
if args.epilogue is not None:
fusion_metadata = EpilogueMetadata.from_args(args.epilogue)
if not self._valid_fusion(fusion_metadata):
return Status.fail("Provided epilogue fusion is not supported by this kernel")
return Status.fail(
"Provided epilogue fusion is not supported by this kernel"
)
return Status.success()
@@ -436,8 +447,6 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
if PersistentDenseGemmEFCKernel._valid_metadata(
metadata
) and metadata_filter(metadata):
kernel_list.append(
PersistentDenseGemmEFCKernel(metadata)
)
kernel_list.append(PersistentDenseGemmEFCKernel(metadata))
return kernel_list

View File

@@ -0,0 +1,273 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import itertools
from collections.abc import Callable, Generator
import cutlass
from cutlass_api.arguments import (
EpilogueArguments,
GemmArguments,
)
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
GemmOperandsMetadata,
KernelMetadata,
TensorAttributes,
)
from cutlass_api.providers.cutedsl import CuTeDSLProvider
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
from cutlass_api.status import Status
from cutlass_api.utils import strides_to_layout_string, to_cuda_stream, tuple_to_string
from cutlass_api.metadata import BLASDesignMetadata
from .implementations.sm80_tensorop_gemm_impl import Sm80TensorOpGemmImpl
@CuTeDSLProvider.register
class Sm80TensorOpGemmKernel(CuteDslKernel):
    """This class implements batched matrix multiplication (C = A @ B)

    :note: In current version, A and B tensor must have the same data type
    :note: Supported A/B data types:
        - Float16/BFloat16
    :note: Supported accumulator data types:
        - Float32 (for all floating point A/B data types)
    :note: Supported C data types:
        - Float16/BFloat16 (same as A, B)
        - NOTE(review): `_valid_operands` additionally accepts Float32 out,
          but `_metadata_operand_combinations` never emits it — confirm
          which is intended.
    :note: Constraints:
        - MMA tiler M must be 64/128
        - MMA tiler N must be 64/128 as generated here
          (`_valid_metadata` would also accept N=32)
    """

    def __init__(self, metadata: KernelMetadata):
        super().__init__(metadata)
        # Bind the CuTe DSL implementation to the dtypes fixed by this
        # kernel's metadata; shapes/strides remain runtime arguments.
        self.impl = Sm80TensorOpGemmImpl(
            metadata.operands.A.dtype,
            metadata.operands.B.dtype,
            metadata.operands.out.dtype,
            metadata.operands.accumulator_type
        )

    def _supports(self, args: GemmArguments) -> Status:
        # This kernel has no epilogue path at all; reject any fusion request.
        if args.epilogue is not None:
            return Status.fail("This kernel does not support any epilogue fusion.")
        return Status.success()

    def compile(self, args: GemmArguments, cc: int = None) -> CompiledArtifact:
        """JIT-compile the kernel for the given arguments.

        A fake stream is used at compile time; the real stream is supplied
        in `_run`.
        """
        stream = cutlass.cute.runtime.make_fake_stream()
        compiled_kernel = self.cute_compile(
            self.impl,
            args.A,
            args.B,
            args.out,
            stream,
        )
        return CompiledArtifact(compiled_kernel, self)

    def _run(
        self,
        args: GemmArguments,
        compiled_artifact: CompiledArtifact,
        stream,
        workspace=None,  # unused: this kernel needs no workspace buffer
    ) -> None:
        """Launch the precompiled kernel on the given stream."""
        stream = to_cuda_stream(stream)
        compiled_gemm = compiled_artifact.compiled_obj
        self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)

    @staticmethod
    def _valid_operands(operands: GemmOperandsMetadata) -> bool:
        """Return True iff `operands` describes a dtype combination this
        kernel can execute."""
        if not isinstance(operands, GemmOperandsMetadata):
            return False
        # In current version, A and B tensor must have the same data type
        if operands.A.dtype != operands.B.dtype:
            return False
        # Supported A/B data types:
        if operands.A.dtype not in [cutlass.Float16, cutlass.BFloat16]:
            return False
        # Supported accumulator data types:
        if operands.accumulator_type not in [cutlass.Float32]:
            return False
        # Supported out data types:
        if operands.out.dtype not in [cutlass.Float32, cutlass.Float16, cutlass.BFloat16]:
            return False
        return True

    @staticmethod
    def _metadata_operand_combinations() -> Generator[GemmOperandsMetadata, None, None]:
        """
        Generator that yields all valid GemmOperandsMetadata combinations
        based on the validation rules in _valid_operands.
        """
        # Supported A/B data types (must be the same)
        ab_dtypes = [cutlass.Float16, cutlass.BFloat16]
        valid_acc_dtypes = [cutlass.Float32]
        valid_out_dtypes = ab_dtypes
        # Torch conventions (L, M, K) and (L, K, N)
        row_major_stride = (0, 0, 1)
        col_major_stride = (0, 1, 0)
        alignment_bytes = 16
        for ab_dtype in ab_dtypes:
            for out_dtype in valid_out_dtypes:
                for acc_dtype in valid_acc_dtypes:
                    for stride_A, stride_B, stride_out in itertools.product(
                        [row_major_stride, col_major_stride], repeat=3
                    ):
                        # Express the 16-byte alignment as an element-count
                        # divisibility for each dtype's bit width.
                        ab_divisibility = alignment_bytes * 8 // ab_dtype.width
                        out_divisibility = alignment_bytes * 8 // out_dtype.width
                        # Create TensorAttributes for A, B, and out tensors
                        a_attrs = TensorAttributes(
                            dtype=ab_dtype,
                            stride=stride_A,
                            divisibility=ab_divisibility,
                        )
                        b_attrs = TensorAttributes(
                            dtype=ab_dtype,
                            stride=stride_B,
                            divisibility=ab_divisibility,
                        )
                        out_attrs = TensorAttributes(
                            dtype=out_dtype,
                            stride=stride_out,
                            divisibility=out_divisibility,
                        )
                        # Create and yield the GemmOperandsMetadata
                        operands = GemmOperandsMetadata(
                            A=a_attrs,
                            B=b_attrs,
                            out=out_attrs,
                            accumulator_type=acc_dtype,
                        )
                        yield operands

    @staticmethod
    def _valid_metadata(metadata: KernelMetadata) -> bool:
        """Return True iff `metadata` (operands + design) is executable by
        this kernel."""
        if not Sm80TensorOpGemmKernel._valid_operands(metadata.operands):
            return False
        design = metadata.design
        if not isinstance(design, BLASDesignMetadata):
            return False
        if design.tile_shape is None:
            return False
        # MMA tiler N must be 32/64/128
        if design.tile_shape[1] not in [32, 64, 128]:
            return False
        # MMA tiler M must be 64/128
        if design.tile_shape[0] not in [64, 128]:
            return False
        # No epilogue support (see _supports).
        if metadata.epilogue is not None:
            return False
        return True

    @staticmethod
    def generate_kernels(
        metadata_filter: Callable[[KernelMetadata], bool],
        epilogue_args: EpilogueArguments = None,
        cc: int = None,
    ) -> list["Sm80TensorOpGemmKernel"]:
        """
        Returns a list of all possible configurations of Sm80TensorOpGemmKernel that
        adhere to constraints passed in under kwargs.
        """
        min_cc = 80
        if cc is not None and cc < min_cc:
            return []
        design_params = {
            "tile_shape": [
                (M, N, 32) for M in [64, 128] for N in [64, 128]
            ],
            # SM80 kernels do not currently use cluster_shape for tuning; fix it to a
            # single valid value to satisfy the BLASDesignMetadata interface.
            "cluster_shape": [(1, 1, 1)],
        }
        # No epilogue support: an epilogue request yields no kernels.
        if epilogue_args is not None:
            return []
        from itertools import product

        # Get the list of tunable parameter names and their possible values
        param_names = list(design_params.keys())
        param_values = [design_params[name] for name in param_names]
        kernel_list = []
        # Cross every operand combination with every design-parameter
        # combination, keeping only metadata that is both valid and accepted
        # by the caller-provided filter.
        for operands in Sm80TensorOpGemmKernel._metadata_operand_combinations():
            for values in product(*param_values):
                design = BLASDesignMetadata(**dict(zip(param_names, values)))
                kernel_name = "cutedsl.Sm80TensorOpGemmKernel_{layout}_A{A}_B{B}_out{out}_acc{acc}_tile{tile}".format(
                    layout=strides_to_layout_string(
                        operands.A.stride, operands.B.stride, operands.out.stride
                    ),
                    A=operands.A.dtype,
                    B=operands.B.dtype,
                    out=operands.out.dtype,
                    acc=operands.accumulator_type,
                    tile=tuple_to_string(design.tile_shape),
                )
                metadata = KernelMetadata(
                    operands=operands,
                    design=design,
                    kernel_name=kernel_name,
                    kernel_class=Sm80TensorOpGemmKernel,
                    min_cc=min_cc,
                    epilogue=None,
                )
                if Sm80TensorOpGemmKernel._valid_metadata(
                    metadata
                ) and metadata_filter(metadata):
                    kernel_list.append(
                        Sm80TensorOpGemmKernel(metadata)
                    )
        return kernel_list

View File

@@ -26,12 +26,12 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import cutlass.cute as cute
from cutlass_api.config import GlobalOptions
from cutlass_api.kernel import Kernel
from cutlass_api.utils import TensorWrapper
import cutlass.cute as cute
class CuteDslKernel(Kernel):
"""

View File

@@ -58,4 +58,3 @@ class Status:
"""Raise the stored exception if this status represents a failure."""
if self.error is not None:
raise self.error

View File

@@ -308,21 +308,54 @@ def leading_dim(tensor) -> int:
def get_stride_order(stride: tuple[int, ...]) -> tuple[int, ...]:
"""
Returns the order of the stride of a tensor. For a stride of rank N,
the dimension with the smallest stride will have stride order 0 and the
dimension with the largest stride will have stride order N-1.
Returns the order of the modes of a stride. The returned tuple contains
indices of modes in `stride`, sorted from outermost to innermost.
For example, for a stride (1024, 1, 512), the stride order is (0, 2, 1).
:param stride: The stride of the tensor.
:type stride: tuple[int, ...]
:return: The order of the stride of the tensor.
:return: The order of the modes of the stride.
:rtype: tuple[int, ...]
"""
# The code below performs an == reverse argsort on the stride:
# indices = range(len(stride))
# Sort indices using comparison between stride[i] and stride[j] when
# sorting indices i and j. Sort in descending order.
# Example: For a stride (1024, 1, 512), the reverse argsort is (0, 2, 1).
return tuple(sorted(range(len(stride)), key=stride.__getitem__, reverse=True))
def get_stride_rank(stride: tuple[int, ...]) -> tuple[int, ...]:
"""
Returns the rank of the each mode in the stride of a tensor. For a stride of rank N,
the mode with the smallest stride will have stride rank 0 and the
mode with the largest stride will have stride rank N-1.
For example, for a stride (1024, 1, 512), the stride rank is (2, 0, 1).
:param stride: The stride of the tensor.
:type stride: tuple[int, ...]
:return: The rank of the each mode in the stride of the tensor.
:rtype: tuple[int, ...]
"""
# The code below performs an argsort on the stride:
# indices = range(len(stride))
# Sort indices using comparison between stride[i] and stride[j] when
# sorting indices i and j.
return tuple(sorted(range(len(stride)), key=stride.__getitem__))
# Example: For a stride (1024, 1, 512), the argsorted is (1, 2, 0).
argsorted = tuple(sorted(range(len(stride)), key=stride.__getitem__))
# Set the stride order of each mode to the index of the mode in the
# argsorted tuple.
# Example: For a stride (1024, 1, 512), the argsorted is (1, 2, 0),
# so the stride rank is (2, 0, 1).
res = [-1] * len(stride)
for i, idx in enumerate(argsorted):
res[idx] = i
return tuple(res)
class TensorWrapper:
@@ -353,6 +386,7 @@ class TensorWrapper:
self.compile_time_tensor = tensor
self._shape = tensor.shape
self._stride = tensor.stride
self._data_ptr = tensor.iterator._pointer
elif GlobalOptions().use_tvm_ffi:
# If TVM-FFI is enabled, runtime tensor is set simply as the tensor passed in, but
# we must make a fake tensor for compilation.
@@ -362,28 +396,48 @@ class TensorWrapper:
rank = self.runtime_tensor.dim()
self._stride = self.runtime_tensor.stride()
stride_order = get_stride_order(self._stride)
stride_order = get_stride_rank(self._stride)
leading_dim_idx = stride_order.index(0)
shape = [cute.SymInt() for _ in range(rank)]
shape[stride_order.index(0)] = cute.SymInt(divisibility=16)
shape[leading_dim_idx] = cute.SymInt(divisibility=16 * 8 // dtype.width)
self._shape = tuple(self.runtime_tensor.shape)
self._data_ptr = self.runtime_tensor.data_ptr()
else:
raise ValueError(
f"Unsupported tensor type: {type(self.runtime_tensor)}"
)
self.compile_time_tensor = cute.runtime.make_fake_compact_tensor(
dtype, shape, stride_order=stride_order, assumed_align=16
dtype,
shape,
stride_order=stride_order,
assumed_align=16, # bytes
)
else:
# TVM-FFI is disabled and the tensor passed in is not a cute.Tensor,
# We must convert it to a cute.Tensor
self.runtime_tensor = from_dlpack(
tensor, assumed_align=16
).mark_layout_dynamic(leading_dim(tensor))
if is_torch_tensor(tensor):
dtype = to_cutlass_type(tensor.dtype)
stride = tensor.stride()
else:
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
stride_order = get_stride_order(stride)
self.runtime_tensor = (
from_dlpack(
tensor,
assumed_align=16, # bytes
)
.mark_layout_dynamic(leading_dim(tensor))
.mark_compact_shape_dynamic(
mode=leading_dim(tensor),
divisibility=16 * 8 // dtype.width,
stride_order=stride_order,
)
)
self._shape = self.runtime_tensor.shape
self._stride = self.runtime_tensor.stride
self._data_ptr = self.runtime_tensor.iterator._pointer
# Since the runtime tensor is now a cute.Tensor, we can use it at
# compile time as well
@@ -401,6 +455,10 @@ class TensorWrapper:
def stride(self) -> tuple[int, ...]:
return self._stride
@property
def data_ptr(self) -> int:
return self._data_ptr
def strides_to_layout_string(*strides: list[tuple[int, ...]]) -> str:
"""

View File

@@ -31,9 +31,9 @@
"\n",
"import cutlass_api\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 90, 100, 103})):\n",
" print(\n",
" f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\"\n",
" f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\"\n",
" )\n",
" import sys\n",
"\n",
@@ -67,7 +67,7 @@
"source": [
"M, N, K, L = 128, 256, 64, 2\n",
"ab_type = torch.float16\n",
"out_type = torch.float32\n",
"out_type = torch.float16\n",
"acc_type = torch.float32\n",
"\n",
"A = torch.randint(-1, 2, (L, M, K), device=\"cuda\", dtype=ab_type)\n",
@@ -118,9 +118,9 @@
"metadata": {},
"outputs": [],
"source": [
"kernels = cutlass_api.get_kernels(args)\n",
"cc = cutlass_api.utils.device_cc()\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"assert kernels, \"No kernels found for the given arguments!\"\n",
"\n",
"kernel = kernels[0]"
]
},
@@ -148,28 +148,6 @@
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "4d7ad85b",
"metadata": {},
"source": [
"One can also explicitly compile the kernel and pass this in to `kernel.run` to avoid\n",
"JIT compilation on future invocations. Additional details related to this will be\n",
"described below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06f9f844",
"metadata": {},
"outputs": [],
"source": [
"artifact = kernel.compile(args)\n",
"kernel.run(args, compiled_artifact=artifact)\n",
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "630e9e4b",
@@ -179,7 +157,7 @@
"\n",
"---\n",
"\n",
"### Understanding the core interfaces"
"## Understanding the core interfaces"
]
},
{
@@ -187,7 +165,7 @@
"id": "2d8b8e94",
"metadata": {},
"source": [
"#### 1. `RuntimeArguments` / `GemmArguments`\n",
"### 1. `RuntimeArguments` / `GemmArguments`\n",
"\n",
"`RuntimeArguments` describe the operation a user wants to perform, and all the runtime operands or other runtime parameters needed for it. \n",
"This includes primary runtime operands to the operation, as well as any custom epilogue fusions and runtime performance knobs.\n",
@@ -220,7 +198,7 @@
"id": "e7eda0dd",
"metadata": {},
"source": [
"#### 2. Kernel Discovery\n",
"### 2. Kernel Discovery\n",
"\n",
"There are several kernels available in CUTLASS DSLs that are registered with, and discoverable via, the CUTLASS API.\n",
"\n",
@@ -254,6 +232,11 @@
"kernels = cutlass_api.get_kernels(args)\n",
"print(f\"Of these, {len(kernels)} support the given arguments.\")\n",
"\n",
"# we can limit the search to kernels supporting given args + current device compute capability\n",
"cc = cutlass_api.utils.device_cc()\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"print(f\"Of these, {len(kernels)} support the given arguments.\")\n",
"\n",
"kernel = kernels[0]\n",
"print(f\"Picked kernel with name: {kernel.metadata.kernel_name}\")"
]
@@ -263,7 +246,7 @@
"id": "252a4d38",
"metadata": {},
"source": [
"#### 3. `Kernel` execution"
"### 3. `Kernel` execution"
]
},
{
@@ -286,7 +269,8 @@
"id": "e8945aa6",
"metadata": {},
"source": [
"* `kernel.supports(args)` checks if the kernel supports the given `args`\n",
"#### Verify that the kernel supports the given `args`\n",
"`kernel.supports(args)` checks if the kernel supports the given `args`\n",
" * this is relevant if the kernel was not picked just for these `args`"
]
},
@@ -338,7 +322,9 @@
"id": "c2db8f20",
"metadata": {},
"source": [
"* `kernel.compile(args)` compiles the kernel, and returns a `CompiledArtifact`\n",
"#### JIT compiling the kernel\n",
"\n",
"`kernel.compile(args)` compiles the kernel, and returns a `CompiledArtifact`\n",
"\n",
"This compiled artifact is a lightweight wrapper over the result of compiling a kernel (e.g., via `cute.compile()`).\n",
"\n",
@@ -361,7 +347,8 @@
"id": "4dfb8d51",
"metadata": {},
"source": [
"* `kernel.run(args)` launches the compiled kernel function. This example uses:\n",
"#### Launching the compiled kernel function\n",
"`kernel.run(args)` launches the compiled kernel function. The next example uses:\n",
" * the precompiled artifact\n",
" * a custom stream to launch to\n",
" * bypasses the supports check already performed above (`assume_supported_args=True`)."
@@ -386,6 +373,24 @@
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "7956813d",
"metadata": {},
"source": [
"Passing in a precompiled kernel is critical to achieving good performance because it avoids\n",
"JIT compiling the kernel on each invocation. JIT compilation always occurs when a precompiled\n",
"kernel is not provided in the call to `kernel.run()`."
]
},
{
"cell_type": "markdown",
"id": "c228495a",
"metadata": {},
"source": [
"#### Workspace Buffers"
]
},
{
"cell_type": "markdown",
"id": "f67eeb8f",
@@ -416,7 +421,7 @@
"id": "baffaf12",
"metadata": {},
"source": [
"### Advanced: Filtering on Metadata"
"## Advanced: Filtering on Metadata"
]
},
{

View File

@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "bb450878",
"metadata": {},
"outputs": [],
@@ -20,7 +20,9 @@
"import cutlass_api\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
" print(\n",
" f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\"\n",
" )\n",
" import sys\n",
"\n",
" sys.exit(0)"
@@ -42,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "e6d77d53",
"metadata": {},
"outputs": [],
@@ -56,13 +58,15 @@
"B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float16)\n",
"C = torch.randn(L, M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"\n",
"def my_epilogue(accum, C, alpha, beta, extra_scalar):\n",
" Aux = (alpha * accum) + (beta * C)\n",
" D = extra_scalar * Aux\n",
" return D, Aux\n",
"\n",
"\n",
"alpha, beta, extra_scalar = 1.0, 2.0, 0.5\n",
"D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)\n"
"D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)"
]
},
{
@@ -77,18 +81,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "f079d9d6",
"metadata": {},
"outputs": [],
"source": [
"import cutlass_api\n",
"from cutlass_api.arguments import GemmArguments, EpilogueArguments\n",
"from cutlass_api.arguments import EpilogueArguments, GemmArguments\n",
"\n",
"# Allocate buffers for D and Aux\n",
"D_, Aux_ = [torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16) for _ in range(2)]\n",
"D_, Aux_ = [\n",
" torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16) for _ in range(2)\n",
"]\n",
"\n",
"epi_args = EpilogueArguments(my_epilogue, C=C, alpha=alpha, beta=beta, extra_scalar=extra_scalar, D=D_, Aux=Aux_)\n"
"epi_args = EpilogueArguments(\n",
" my_epilogue,\n",
" C=C,\n",
" alpha=alpha,\n",
" beta=beta,\n",
" extra_scalar=extra_scalar,\n",
" D=D_,\n",
" Aux=Aux_,\n",
")"
]
},
{
@@ -101,14 +115,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "60215c4e",
"metadata": {},
"outputs": [],
"source": [
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args)\n",
"assert len(kernels) > 0\n"
"cc = cutlass_api.utils.device_cc()\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"assert len(kernels) > 0"
]
},
{
@@ -122,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "150f3296",
"metadata": {},
"outputs": [],
@@ -130,7 +145,7 @@
"kernels[0].run(args)\n",
"\n",
"torch.testing.assert_close(D, D_)\n",
"torch.testing.assert_close(Aux, Aux_)\n"
"torch.testing.assert_close(Aux, Aux_)"
]
},
{
@@ -289,17 +304,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "171ac178",
"metadata": {},
"outputs": [],
"source": [
"from cutlass_api.fusion.activation import relu\n",
"\n",
"\n",
"def relu_aux_store(accum, alpha, C):\n",
" F = (accum * alpha) + (C * 2.0) # Constant beta\n",
" D = relu(F)\n",
" return D, F\n",
" F = (accum * alpha) + (C * 2.0) # Constant beta\n",
" D = relu(F)\n",
" return D, F\n",
"\n",
"\n",
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"alpha = 3.0\n",
@@ -308,14 +325,14 @@
"\n",
"epi_args = EpilogueArguments(relu_aux_store, alpha=alpha, C=C, D=D, F=F)\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
"D_ref, F_ref = relu_aux_store(A @ B, alpha, C)\n",
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n"
"torch.testing.assert_close(F, F_ref)"
]
},
{
@@ -328,15 +345,16 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "62c2b49b",
"metadata": {},
"outputs": [],
"source": [
"def relu_scale_return_acc(accum, alpha, beta, C, scale):\n",
" F = relu((accum * alpha) + (C * beta))\n",
" D = F * scale\n",
" return D, F, accum\n",
" F = relu((accum * alpha) + (C * beta))\n",
" D = F * scale\n",
" return D, F, accum\n",
"\n",
"\n",
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"alpha = 1.0\n",
@@ -346,9 +364,18 @@
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"accum = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float32)\n",
"\n",
"epi_args = EpilogueArguments(relu_scale_return_acc, alpha=alpha, beta=beta, C=C, scale=scale, D=D, F=F, accum=accum)\n",
"epi_args = EpilogueArguments(\n",
" relu_scale_return_acc,\n",
" alpha=alpha,\n",
" beta=beta,\n",
" C=C,\n",
" scale=scale,\n",
" D=D,\n",
" F=F,\n",
" accum=accum,\n",
")\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
@@ -356,7 +383,7 @@
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n",
"torch.testing.assert_close(accum, accum_ref.to(accum.dtype))\n"
"torch.testing.assert_close(accum, accum_ref.to(accum.dtype))"
]
},
{
@@ -370,7 +397,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "5987bf44",
"metadata": {},
"outputs": [],
@@ -385,7 +412,7 @@
"\n",
"epi_args = EpilogueArguments(epi_str, alpha=alpha, beta=beta, C=C, D=D, F=F)\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
@@ -393,7 +420,7 @@
"D_ref = torch.relu(F_ref)\n",
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n"
"torch.testing.assert_close(F, F_ref)"
]
},
{
@@ -407,54 +434,90 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "1e3d0c89",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accum must be an input to the epilogue function\n"
]
}
],
"source": [
"####################################################\n",
"# Epilogues must take in an accumulator\n",
"####################################################\n",
"def fail_missing_accum(alpha, beta, C):\n",
" D = (C * beta)\n",
" return D\n",
" D = C * beta\n",
" return D\n",
"\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_missing_accum, alpha=alpha, beta=beta, C=C, D=D)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
" epi_args = EpilogueArguments(\n",
" fail_missing_accum,\n",
" alpha=alpha,\n",
" beta=beta,\n",
" C=C,\n",
" D=D,\n",
" )\n",
" args = GemmArguments(\n",
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
" )\n",
"except Exception as e:\n",
" # \"accum must be an input to the epilogue function\"\n",
" print(e)\n"
" # \"accum must be an input to the epilogue function\"\n",
" print(e)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"id": "48a359f7",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output node D is not found in the epilogue function\n"
]
}
],
"source": [
"####################################################\n",
"# Epilogues must return an output named D\n",
"####################################################\n",
"def fail_missing_D(accum, alpha, beta, C):\n",
" F = (accum * alpha) + (C * beta)\n",
" return F\n",
" F = (accum * alpha) + (C * beta)\n",
" return F\n",
"\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
" epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)\n",
" args = GemmArguments(\n",
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
" )\n",
"except Exception as e:\n",
" # \"On SM90 or higher, D is expected to be a output node with 0 users to enable smem reuse between C and D, but got []\"\n",
" print(e)\n"
" # \"Output node D is not found in the epilogue function\n",
" print(e)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"id": "49d9ee94",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable 'tmp' cannot be defined twice.\n"
]
}
],
"source": [
"####################################################\n",
"# Epilogues must use single-static assignment (SSA)\n",
@@ -466,51 +529,81 @@
" D = tmp / 4.0\n",
" return D, tmp\n",
"\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
" epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)\n",
" args = GemmArguments(\n",
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
" )\n",
"except Exception as e:\n",
" # \"Variable 'tmp' cannot be defined twice.\"\n",
" print(e)\n"
" # \"Variable 'tmp' cannot be defined twice.\"\n",
" print(e)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "871bb727",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Argument D is not provided in the kwargs of the EpilogueArguments constructor\n",
"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\n"
]
}
],
"source": [
"####################################################\n",
"# Must provide all operands and outputs to\n",
"# EpilogueArguments\n",
"####################################################\n",
"def my_epi(accum, alpha, beta, C):\n",
" F = (accum * alpha) + (C * beta)\n",
" D = relu(F)\n",
" return D\n",
" F = (accum * alpha) + (C * beta)\n",
" D = relu(F)\n",
" return D\n",
"\n",
"try:\n",
" # Missing D\n",
" epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
" # Missing D\n",
" epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)\n",
" args = GemmArguments(\n",
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
" )\n",
"except Exception as e:\n",
" # \"Argument D is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)\n",
" # \"Argument D is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)\n",
"\n",
"try:\n",
" # Missing alpha\n",
" epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
" # Missing alpha\n",
" epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)\n",
" args = GemmArguments(\n",
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
" )\n",
"except Exception as e:\n",
" # \"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)\n"
" # \"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,

View File

@@ -24,12 +24,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "5a64b0be",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"from collections.abc import Callable\n",
"\n",
"import cuda.bindings.driver as cuda\n",
"\n",
@@ -110,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "1a2da869",
"metadata": {},
"outputs": [],
@@ -148,18 +148,34 @@
"class F64GemmKernel(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):\n",
" # Empty versions of interface methods. These will be implemented later, interspersed\n",
" # with notebook markdown. Normally, one would define them inline with the class definition.\n",
" def __init__(self, metadata: KernelMetadata): pass\n",
" def __init__(self, metadata: KernelMetadata):\n",
" pass\n",
"\n",
" def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None): pass\n",
" def _run(\n",
" self,\n",
" args: GemmArguments,\n",
" artifact: cutlass_api.artifact.CompiledArtifact,\n",
" stream,\n",
" workspace=None,\n",
" ):\n",
" pass\n",
"\n",
" def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact: pass\n",
" def compile(\n",
" self, args: GemmArguments, cc: int = None\n",
" ) -> cutlass_api.artifact.CompiledArtifact:\n",
" pass\n",
"\n",
" @staticmethod\n",
" def generate_kernels(metadata_filter, epilogue_args=None, cc=None) -> list[\"F64GemmKernel\"]: pass\n",
" def generate_kernels(\n",
" metadata_filter, epilogue_args=None, cc=None\n",
" ) -> list[\"F64GemmKernel\"]:\n",
" pass\n",
"\n",
" def _supports(self, args: GemmArguments) -> Status: pass\n",
" def _supports(self, args: GemmArguments) -> Status:\n",
" pass\n",
"\n",
" def get_workspace_size(self, args: GemmArguments) -> int: pass"
" def get_workspace_size(self, args: GemmArguments) -> int:\n",
" pass"
]
},
{
@@ -175,13 +191,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "785d1882",
"metadata": {},
"outputs": [],
"source": [
"def __init__(self, metadata: KernelMetadata):\n",
" self.metadata = metadata\n",
" # Using Python-2-style super() because we're defining this method outside of the class definition.\n",
" super(F64GemmKernel, self).__init__(metadata)\n",
" cta_tile_shape_mn = metadata.design.tile_shape[:2]\n",
" self.impl = F64GemmKernelImplementation(cta_tile_shape_mn)"
]
@@ -203,12 +220,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "63b4a129",
"metadata": {},
"outputs": [],
"source": [
"def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact:\n",
"def compile(\n",
" self, args: GemmArguments, cc: int = None\n",
") -> cutlass_api.artifact.CompiledArtifact:\n",
" stream = cutlass.cute.runtime.make_fake_stream()\n",
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
@@ -227,12 +246,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "2ae7c009",
"metadata": {},
"outputs": [],
"source": [
"def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None):\n",
"def _run(\n",
" self,\n",
" args: GemmArguments,\n",
" artifact: cutlass_api.artifact.CompiledArtifact,\n",
" stream,\n",
" workspace=None,\n",
"):\n",
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
" compiled_gemm = artifact.compiled_obj\n",
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
@@ -249,7 +274,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "968906ea",
"metadata": {},
"outputs": [],
@@ -286,7 +311,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "47dc2f20",
"metadata": {},
"outputs": [],
@@ -297,7 +322,6 @@
" epilogue_args: cutlass_api.arguments.EpilogueArguments = None,\n",
" cc: int = None,\n",
") -> list[\"F64GemmKernel\"]:\n",
"\n",
" # The tile shapes this kernel supports/exposes\n",
" supported_tile_shapes = [(32, 32, 1), (16, 16, 1)]\n",
"\n",
@@ -306,10 +330,12 @@
"\n",
" row_major_stride = (0, 0, 1)\n",
" col_major_stride = (0, 1, 0)\n",
" stride_combos = list(itertools.product([row_major_stride, col_major_stride], repeat=3))\n",
" alignment = 1\n",
" stride_combos = list(\n",
" itertools.product([row_major_stride, col_major_stride], repeat=3)\n",
" )\n",
" divisibility = 1\n",
"\n",
" def stride_name(stride): \n",
" def stride_name(stride):\n",
" return \"T\" if stride == row_major_stride else \"N\"\n",
"\n",
" kernels = []\n",
@@ -317,10 +343,18 @@
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
" for stride_A, stride_B, stride_out in stride_combos:\n",
" # Create TensorAttributes for A, B, and out tensors\n",
" a_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_A, alignment)\n",
" b_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_B, alignment)\n",
" out_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_out, alignment)\n",
" layout_str = cutlass_api.utils.strides_to_layout_string(stride_A, stride_B, stride_out)\n",
" a_attrs = cutlass_api.metadata.TensorAttributes(\n",
" cutlass.Float64, stride_A, divisibility\n",
" )\n",
" b_attrs = cutlass_api.metadata.TensorAttributes(\n",
" cutlass.Float64, stride_B, divisibility\n",
" )\n",
" out_attrs = cutlass_api.metadata.TensorAttributes(\n",
" cutlass.Float64, stride_out, divisibility\n",
" )\n",
" layout_str = cutlass_api.utils.strides_to_layout_string(\n",
" stride_A, stride_B, stride_out\n",
" )\n",
"\n",
" name = f\"F64GemmKernel_tile{tile_shape[0]}x{tile_shape[1]}_{layout_str}\"\n",
"\n",
@@ -355,24 +389,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "54067d47",
"metadata": {},
"outputs": [],
"source": [
"def _supports(self, args: GemmArguments) -> Status:\n",
" if not (\n",
" len(args.A.shape) == 3 and # A should be (L, M, K)\n",
" len(args.B.shape) == 3 and # B should be (L, K, N)\n",
" len(args.out.shape) == 3 # out should be (L, M, N)\n",
" len(args.A.shape) == 3 # A should be (L, M, K)\n",
" and len(args.B.shape) == 3 # B should be (L, K, N)\n",
" and len(args.out.shape) == 3 # out should be (L, M, N)\n",
" ):\n",
" return Status.fail(\"All operands must be rank 3.\")\n",
" return Status.success()\n"
" return Status.success()"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"id": "edaf2cba",
"metadata": {},
"outputs": [],
@@ -404,7 +438,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"id": "cec5431d",
"metadata": {},
"outputs": [],
@@ -420,9 +454,11 @@
"\n",
"args = GemmArguments(A, B, out, accumulator_type=torch.float64)\n",
"\n",
"\n",
"def is_f64gemm_kernel(metadata):\n",
" return metadata.kernel_class == F64GemmKernel\n",
"\n",
"\n",
"kernels = cutlass_api.get_kernels(args, metadata_filter=is_f64gemm_kernel)"
]
},
@@ -437,10 +473,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "cdb92b5e",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"F64GemmKernel_tile32x32_ttt\n",
"F64GemmKernel_tile16x16_ttt\n"
]
}
],
"source": [
"print(kernels[0].metadata.kernel_name)\n",
"print(kernels[1].metadata.kernel_name)"
@@ -456,7 +501,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"id": "f5486244",
"metadata": {},
"outputs": [],
@@ -478,17 +523,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"id": "917c74e3",
"metadata": {},
"outputs": [],
"source": [
"def my_filter(metadata):\n",
" return (\n",
" is_f64gemm_kernel(metadata) and\n",
" isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata) and\n",
" metadata.design.tile_shape[0] == 256\n",
" is_f64gemm_kernel(metadata)\n",
" and isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata)\n",
" and metadata.design.tile_shape[0] == 256\n",
" )\n",
"\n",
"\n",
"kernels_ctam256 = cutlass_api.get_kernels(args, metadata_filter=my_filter)\n",
"\n",
"# No kernels should be found\n",
@@ -539,8 +586,22 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,

View File

@@ -17,12 +17,15 @@
"This notebook focuses on the latter: techniques to minimize any overheads incurred from the CUTLASS API and underlying\n",
"DSL runtimes.\n",
"\n",
"This notebook does not discuss techniques for improving device-side performance. A future notebook may cover this topic."
"This notebook does not discuss techniques for improving device-side performance. A future notebook may cover this topic.\n",
"\n",
"**Note**: Latency measurements can vary from system to system. You may see different results on your system than shown\n",
"in the pre-populated fields of this notebook."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "e3ca9e40",
"metadata": {},
"outputs": [],
@@ -34,13 +37,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "efaac09c",
"metadata": {},
"outputs": [],
"source": [
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 89, 90, 100, 103})):\n",
" print(\n",
" f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\"\n",
" )\n",
" import sys\n",
"\n",
" sys.exit(0)"
@@ -61,7 +66,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "b8c44947",
"metadata": {},
"outputs": [],
@@ -76,19 +81,34 @@
"# We use different operands in each iteration. Though not particularly relevant for\n",
"# host latency, this is a best practice when benchmarking GPU kernels to avoid\n",
"# unrealistic caching effects.\n",
"As = [torch.randint(-1, 2, (M, K), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"Bs = [torch.randint(-1, 2, (K, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"outs = [torch.empty((M, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"As = [\n",
" torch.randint(-1, 2, (M, K), device=\"cuda\", dtype=torch.float16)\n",
" for _ in range(total_iterations)\n",
"]\n",
"Bs = [\n",
" torch.randint(-1, 2, (K, N), device=\"cuda\", dtype=torch.float16)\n",
" for _ in range(total_iterations)\n",
"]\n",
"outs = [\n",
" torch.empty((M, N), device=\"cuda\", dtype=torch.float16)\n",
" for _ in range(total_iterations)\n",
"]\n",
"\n",
"# Construct arguments outside of the benchmarking loop. We will later also consider\n",
"# cases in which they are constructed inside the benchmarking loop.\n",
"args = [cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16) for i in range(total_iterations)]\n",
"args = [\n",
" cutlass_api.arguments.GemmArguments(\n",
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float32\n",
" )\n",
" for i in range(total_iterations)\n",
"]\n",
"\n",
"references = [(As[i] @ Bs[i]).to(outs[i].dtype) for i in range(total_iterations)]\n",
"\n",
"kernels = cutlass_api.get_kernels(args[0], cc=100)\n",
"\n",
"cc = cutlass_api.utils.device_cc()\n",
"kernels = cutlass_api.get_kernels(args[0], cc=cc)\n",
"assert len(kernels) > 0\n",
"\n",
"kernel = kernels[0]"
]
},
@@ -102,14 +122,18 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "2472eafa",
"metadata": {},
"outputs": [],
"source": [
"def benchmark(label, code, warmup_it=warmup_iterations, profiling_it=profiling_iterations):\n",
"def benchmark(\n",
" label, code, warmup_it=warmup_iterations, profiling_it=profiling_iterations\n",
"):\n",
" total_it = warmup_it + profiling_it\n",
" assert total_it <= total_iterations, f\"Benchmark-local iteration count must be less than or equal to total iterations: {total_it} > {total_iterations}\"\n",
" assert total_it <= total_iterations, (\n",
" f\"Benchmark-local iteration count must be less than or equal to total iterations: {total_it} > {total_iterations}\"\n",
" )\n",
" # warmup\n",
" rets = [None] * total_it\n",
" for i in range(warmup_it):\n",
@@ -151,7 +175,7 @@
"### Compile once, run many times\n",
"The `kernel.run` method takes in an optional `compiled_artifact` argument of type\n",
"`cutlass_api.artifact.CompiledArtifact`. When this argument is set, the kernel\n",
"will directly use the precompiled function within `compiled_argument`. When\n",
"will directly use the precompiled function within `compiled_artifact`. When\n",
"it is not set, the call to `kernel.run` will JIT compile the kernel on each\n",
"invocation.\n",
"\n",
@@ -160,25 +184,29 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "6de11f56",
"metadata": {},
"outputs": [],
"source": [
"stream = torch.cuda.current_stream()\n",
"\n",
"\n",
"def no_compiled_artifact(i: int):\n",
" return kernel.run(args[i], stream=stream)\n",
"\n",
"\n",
"# Compile the kernel once, reuse for each iterations\n",
"compiled_artifact = kernel.compile(args[0])\n",
"\n",
"\n",
"def with_compiled_artifact(i: int):\n",
" return kernel.run(args[i], stream=stream, compiled_artifact=compiled_artifact)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "350c9bd6",
"metadata": {},
"outputs": [
@@ -192,8 +220,12 @@
}
],
"source": [
"time_no_artifact, _ = benchmark(f\"Without compiled artifact\", no_compiled_artifact, warmup_it=2, profiling_it=5)\n",
"time_w_artifact, _ = benchmark(f\"With compiled artifact\", with_compiled_artifact, warmup_it=2, profiling_it=5)"
"time_no_artifact, _ = benchmark(\n",
" f\"Without compiled artifact\", no_compiled_artifact, warmup_it=2, profiling_it=5\n",
")\n",
"time_w_artifact, _ = benchmark(\n",
" f\"With compiled artifact\", with_compiled_artifact, warmup_it=2, profiling_it=5\n",
")"
]
},
{
@@ -215,21 +247,32 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "5b93dfae",
"metadata": {},
"outputs": [],
"source": [
"def with_supports_check(i: int):\n",
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=False)\n",
" return kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" stream=stream,\n",
" assume_supported_args=False,\n",
" )\n",
"\n",
"\n",
"def without_supports_check(i: int):\n",
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=True)"
" return kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" stream=stream,\n",
" assume_supported_args=True,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "b282f437",
"metadata": {},
"outputs": [
@@ -244,7 +287,7 @@
}
],
"source": [
"time_w_supports, _ = benchmark(\"With supports check\", with_supports_check)\n",
"time_w_supports, _ = benchmark(\"With supports check\", with_supports_check)\n",
"time_wo_supports, _ = benchmark(\"Bypass supports check\", without_supports_check)\n",
"print(f\"Speedup with skip supports: {time_w_supports / time_wo_supports:.2f}x\")"
]
@@ -262,14 +305,16 @@
"id": "656d5e2c",
"metadata": {},
"source": [
"CUTLASS API supports [CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/) usage with PyTorch as usual.\n",
"[CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/) allow a sequence of GPU operations to be defined as a dependency graph and then launched as a single unit, significantly reducing CPU launch overhead and enabling whole-graph optimizations.\n",
"\n",
"CUTLASS API supports CUDA Graphs usage with PyTorch as usual.\n",
"\n",
"The kernel compilation must happen outside the CUDA graph. Then, we create a graph using usual PyTorch idioms to launch a kernel several times on the graph's stream."
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "e614509f",
"metadata": {},
"outputs": [],
@@ -279,6 +324,10 @@
"# Create a CUDA Graph to run our compiled kernel N times\n",
"g = torch.cuda.CUDAGraph()\n",
"with torch.cuda.graph(g):\n",
"\n",
" ### NOTE! Kernel compilation must happen outside the graph\n",
" ### kernel.compile(args)\n",
"\n",
" # Run N iterations of our compiled kernel on the current stream\n",
" for i in range(num_launches):\n",
" kernel.run(\n",
@@ -286,10 +335,7 @@
" compiled_artifact=compiled_artifact,\n",
" stream=torch.cuda.current_stream(),\n",
" assume_supported_args=True,\n",
" )\n",
"\n",
"# Zero the output so we don't refcheck stale results\n",
"_ = outs[0].zero_()"
" )"
]
},
{
@@ -297,8 +343,12 @@
"id": "8fc69c6e",
"metadata": {},
"source": [
"Once captured, we can replay the graph. This will only replay the kernel launches placed on the CUDA stream.\n",
"Any other prepratory work on the host and arguments passed in from python are cached during the capture."
"This records/captures all the kernel launches to the CUDA Stream associated with the graph `g`, without actually launching them.\n",
"Once captured, we can replay the graph.\n",
"\n",
"Note that graph replay will only replay the kernel launches placed on the graph's stream\n",
"* During graph capture, we must be careful to capture to the correct stream (`torch.cuda.current_stream()` under the graph context)\n",
"* Any other preparatory work on the host and arguments passed in from Python are cached during the capture. Changing them would require re-capturing the graph"
]
},
{
@@ -324,7 +374,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "45d4e739",
"metadata": {},
"outputs": [
@@ -348,12 +398,23 @@
" assume_supported_args=True,\n",
" )\n",
"\n",
"\n",
"def with_cuda_graph(x: int):\n",
" g.replay()\n",
"\n",
"\n",
"time_wo_cuda_graph, _ = benchmark(f\"{num_launches} launches without CUDA Graph\", without_cuda_graph, warmup_it=0, profiling_it=1)\n",
"time_w_cuda_graph, _ = benchmark(f\"{num_launches} launches with CUDA Graph\", with_cuda_graph, warmup_it=0, profiling_it=1)\n",
"time_wo_cuda_graph, _ = benchmark(\n",
" f\"{num_launches} launches without CUDA Graph\",\n",
" without_cuda_graph,\n",
" warmup_it=0,\n",
" profiling_it=1,\n",
")\n",
"time_w_cuda_graph, _ = benchmark(\n",
" f\"{num_launches} launches with CUDA Graph\",\n",
" with_cuda_graph,\n",
" warmup_it=0,\n",
" profiling_it=1,\n",
")\n",
"\n",
"print(f\"Speedup with CUDA Graph: {time_wo_cuda_graph / time_w_cuda_graph:.2f}x\")"
]
@@ -371,8 +432,8 @@
"id": "ee7f9fd2",
"metadata": {},
"source": [
"When applicable, CUTLASS API uses [Apache TVM FFI](https://tvm.apache.org/ffi/) under the hood for invoking compiled DSL kernels from Python.\n",
"Apache TVM FFI is an open ABI and FFI for machine learning systems.\n",
"[Apache TVM FFI](https://tvm.apache.org/ffi/) is an open ABI and FFI for machine learning systems.\n",
"When available, CUTLASS API uses Apache TVM-FFI under the hood as its interface for invoking compiled DSL kernels from Python.\n",
"\n",
"TVM FFI is enabled by default in CUTLASS API, and is recommended for best performance."
]
@@ -413,7 +474,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"id": "e8f56be3",
"metadata": {},
"outputs": [
@@ -428,10 +489,15 @@
}
],
"source": [
"original_use_tvm_ffi = cutlass_api.config.GlobalOptions().use_tvm_ffi\n",
"\n",
"cutlass_api.config.GlobalOptions().use_tvm_ffi = True\n",
"\n",
"\n",
"def run_iteration(i):\n",
" args = cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
" args = cutlass_api.arguments.GemmArguments(\n",
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16\n",
" )\n",
" return kernel.run(\n",
" args,\n",
" compiled_artifact=compiled_artifact,\n",
@@ -439,18 +505,35 @@
" assume_supported_args=True,\n",
" )\n",
"\n",
"\n",
"def create_arguments(i: int):\n",
" return cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
" return cutlass_api.arguments.GemmArguments(\n",
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16\n",
" )\n",
"\n",
"\n",
"args_creation_on, args = benchmark(\"[TVM-FFI ON ] Create args\", create_arguments)\n",
"compilation_on, compiled = benchmark(\"[TVM-FFI ON ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
"compilation_on, compiled = benchmark(\n",
" \"[TVM-FFI ON ] Compile kernel\",\n",
" lambda i: kernel.compile(args[i]),\n",
" warmup_it=2,\n",
" profiling_it=5,\n",
")\n",
"compiled_artifact = compiled[0]\n",
"run_on, _ = benchmark(\"[TVM-FFI ON ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
"run_on, _ = benchmark(\n",
" \"[TVM-FFI ON ] Run kernel\",\n",
" lambda i: kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" assume_supported_args=True,\n",
" stream=stream,\n",
" ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "5a4c2db4",
"metadata": {},
"outputs": [
@@ -467,9 +550,25 @@
"source": [
"cutlass_api.config.GlobalOptions().use_tvm_ffi = False\n",
"args_creation_off, args = benchmark(\"[TVM-FFI OFF ] Create args\", create_arguments)\n",
"compilation_off, compiled = benchmark(\"[TVM-FFI OFF ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
"compilation_off, compiled = benchmark(\n",
" \"[TVM-FFI OFF ] Compile kernel\",\n",
" lambda i: kernel.compile(args[i]),\n",
" warmup_it=2,\n",
" profiling_it=5,\n",
")\n",
"compiled_artifact = compiled[0]\n",
"run_off, _ = benchmark(\"[TVM-FFI OFF ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
"run_off, _ = benchmark(\n",
" \"[TVM-FFI OFF ] Run kernel\",\n",
" lambda i: kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" assume_supported_args=True,\n",
" stream=stream,\n",
" ),\n",
")\n",
"\n",
"# Restore original setting\n",
"cutlass_api.config.GlobalOptions().use_tvm_ffi = original_use_tvm_ffi"
]
},
{

View File

@@ -0,0 +1,166 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "4620d513",
"metadata": {},
"source": [
"# Using fake tensors with the CUTLASS API\n",
"Fake tensors (e.g., [torch's FakeTensor](https://docs.pytorch.org/docs/2.8/torch.compiler_fake_tensor.html))\n",
"are useful for describing the properties of a tensor without actually allocating backing data.\n",
"\n",
"This example shows how fake tensors can be used within the CUTLASS API\n",
"for discovering and compiling a GEMM kernel."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d231b32e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import cutlass_api\n",
"\n",
"torch.manual_seed(2025)\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 89, 90, 100, 103})):\n",
" print(f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\")\n",
" import sys\n",
" sys.exit(0)"
]
},
{
"cell_type": "markdown",
"id": "f7af2d90",
"metadata": {},
"source": [
"We first set up operands `A`, `B`, and `out` in torch's `FakeTensorMode`.\n",
"These will have all the properties needed for CUTLASS API to construct\n",
"the internal representations of tensors used for discovering and compiling\n",
"kernels."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9426b66f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FakeTensor(..., device='cuda:0', size=(128, 512), dtype=torch.float16)\n",
"FakeTensor(..., device='cuda:0', size=(512, 256), dtype=torch.float16)\n",
"FakeTensor(..., device='cuda:0', size=(128, 256), dtype=torch.float16)\n"
]
}
],
"source": [
"M, N, K = 128, 256, 512\n",
"\n",
"with torch._subclasses.fake_tensor.FakeTensorMode():\n",
" A_fake = torch.randn(M, K, device=\"cuda\", dtype=torch.float16)\n",
" B_fake = torch.randn(K, N, device=\"cuda\", dtype=torch.float16)\n",
" out_fake = torch.empty(M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"print(A_fake)\n",
"print(B_fake)\n",
"print(out_fake)"
]
},
{
"cell_type": "markdown",
"id": "4f540c78",
"metadata": {},
"source": [
"We can now use these fake tensors to create `GemmArguments`, and use\n",
"these to discover and compile a compatible kernel. Note that the same APIs are\n",
"used in creating `GemmArguments` as would be used if using\n",
"\"real\" tensors."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e32b700d",
"metadata": {},
"outputs": [],
"source": [
"args_fake = cutlass_api.arguments.GemmArguments(\n",
" A_fake, B_fake, out_fake, accumulator_type=torch.float32)\n",
"\n",
"cc = cutlass_api.utils.device_cc()\n",
"kernels = cutlass_api.get_kernels(args_fake, cc=cc)\n",
"assert len(kernels) > 0\n",
"\n",
"kernel = kernels[0]\n",
"compiled_artifact = kernel.compile(args_fake)"
]
},
{
"cell_type": "markdown",
"id": "07fff511",
"metadata": {},
"source": [
"The `kernel` and `compiled_artifact` discovered using fake tensors\n",
"above can now used for running the kernel using real tensors."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b3034bf1",
"metadata": {},
"outputs": [],
"source": [
"# Create real tensors\n",
"A_real = torch.randn(M, K, device=\"cuda\", dtype=torch.float16)\n",
"B_real = torch.randn(K, N, device=\"cuda\", dtype=torch.float16)\n",
"out_real = torch.empty(M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"args_real = cutlass_api.arguments.GemmArguments(\n",
" A_real, B_real, out_real, accumulator_type=torch.float32)\n",
"\n",
"# Run the kernel using the compiled_artifact from resulting\n",
"# from compiling with fake tensors.\n",
"kernel.run(args_real, compiled_artifact)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "09871eca",
"metadata": {},
"outputs": [],
"source": [
"ref = A_real @ B_real\n",
"torch.testing.assert_close(out_real, ref)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -22,6 +22,8 @@ torch = [
]
test = [
"jupyter",
"nbconvert",
"nbformat",
"pytest",
"cutlass_api[torch]",
]

View File

@@ -0,0 +1,72 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
import cutlass_api
from cutlass_api.config import GlobalOptions
#
# Before each test, save the GlobalOptions dict
# After each test, restore the GlobalOptions dict
#
global_options = {}
def setup_function():
    """Snapshot the GlobalOptions state before each test in this module."""
    # `global_options` is only mutated in place here, so no `global` declaration
    # is needed; GlobalOptions().save fills the shared dict.
    GlobalOptions().save(global_options)
def teardown_function():
    """Restore the GlobalOptions state captured by setup_function."""
    # Mirror of setup_function: read the shared snapshot back in place.
    GlobalOptions().restore(global_options)
#
# Fixtures for toggling the TVM FFI global option and forcing its enablement or disablement
#
@pytest.fixture(
    params=[True, False],
    # NOTE: the second id had a needless f-string prefix (no placeholders); plain
    # literals are used for both ids.
    ids=["use_tvm_ffi=True", "use_tvm_ffi=False"],
    autouse=False,
)
def fixture_toggle_tvm_ffi(request):
    """Run the requesting test twice: once with TVM FFI enabled, once disabled.

    The option is set on the global singleton; the module-level
    setup/teardown functions snapshot and restore GlobalOptions around
    each test, so no explicit cleanup is needed here.
    """
    GlobalOptions().use_tvm_ffi = request.param
@pytest.fixture(autouse=False)
def fixture_enable_tvm_ffi(request):
    """Force TVM FFI on for the requesting test (restored by teardown)."""
    options = GlobalOptions()
    options.use_tvm_ffi = True
@pytest.fixture(autouse=False)
def fixture_disable_tvm_ffi(request):
    """Force TVM FFI off for the requesting test (restored by teardown)."""
    options = GlobalOptions()
    options.use_tvm_ffi = False

View File

@@ -68,6 +68,7 @@ def test_gemm_sm100(
c_dtype: torch.dtype,
accumulator_type: torch.dtype,
n_iterations: int,
fixture_toggle_tvm_ffi,
):
A = torch.randint(-1, 2, (M, K), device="cuda").to(ab_dtype)
B = torch.randint(-1, 2, (K, N), device="cuda").to(ab_dtype)

View File

@@ -32,6 +32,7 @@ import torch
import cutlass_api
from cutlass_api.utils import device_cc
from cutlass_api.config import GlobalOptions
@pytest.mark.parametrize(
@@ -48,11 +49,7 @@ from cutlass_api.utils import device_cc
torch.float16,
],
)
def test_elementwise_add(
M: int,
N: int,
dtype: torch.dtype,
):
def test_elementwise_add(M: int, N: int, dtype: torch.dtype, fixture_toggle_tvm_ffi):
A = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
B = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
D = torch.empty((M, N), device="cuda", dtype=dtype)

View File

@@ -38,7 +38,6 @@ import torch
import cutlass
import cutlass_api
from cutlass_api.config import GlobalOptions
from cutlass_api.utils import is_device_cc_supported
torch.manual_seed(2025)
@@ -62,8 +61,17 @@ logger = logging.getLogger(__name__)
],
)
@pytest.mark.parametrize(
"use_tvm_ffi",
[True, False],
"a_major, b_major, c_major",
[
("k", "k", "n"),
("k", "k", "m"),
("k", "n", "m"),
("k", "n", "n"),
("m", "k", "n"),
("m", "k", "m"),
("m", "n", "m"),
("m", "n", "n"),
],
)
@pytest.mark.skipif(
not is_device_cc_supported({100})
@@ -71,20 +79,29 @@ logger = logging.getLogger(__name__)
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_sm100(
M: int,
N: int,
K: int,
L: int,
M: int, N: int, K: int, L: int,
ab_dtype: torch.dtype,
c_dtype: torch.dtype,
accumulator_type: torch.dtype,
use_tvm_ffi: bool,
a_major: str,
b_major: str,
c_major: str,
fixture_toggle_tvm_ffi,
):
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
if a_major == "k":
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
else:
A = torch.randint(-1, 2, (L, K, M), device="cuda", dtype=ab_dtype).permute(0, 2, 1)
GlobalOptions().use_tvm_ffi = use_tvm_ffi
if b_major == "n":
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
else:
B = torch.randint(-1, 2, (L, N, K), device="cuda", dtype=ab_dtype).permute(0, 2, 1)
if c_major == "n":
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
else:
D = torch.empty((L, N, M), device="cuda", dtype=c_dtype).permute(0, 2, 1)
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
@@ -100,16 +117,12 @@ def test_gemm_sm100(
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.parametrize(
"use_tvm_ffi",
[True, False],
)
@pytest.mark.skipif(
not is_device_cc_supported({100})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_sm100_2d(use_tvm_ffi: bool):
def test_gemm_sm100_2d(fixture_toggle_tvm_ffi):
ab_dtype = torch.float16
c_dtype = torch.float16
accumulator_type = torch.float32
@@ -120,8 +133,6 @@ def test_gemm_sm100_2d(use_tvm_ffi: bool):
B = torch.randint(-1, 2, (K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((M, N), device="cuda", dtype=c_dtype)
GlobalOptions().use_tvm_ffi = use_tvm_ffi
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
kernels = cutlass_api.get_kernels(args, cc=100)
@@ -141,7 +152,7 @@ def test_gemm_sm100_2d(use_tvm_ffi: bool):
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a"]),
reason="Requires compute capability 100 and to be compiled with sm_100a",
)
def test_gemm_sm100_int8():
def test_gemm_sm100_int8(fixture_toggle_tvm_ffi):
M, N, K = 256, 512, 128
A = torch.randint(-1, 2, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(-1, 2, (K, N), device="cuda", dtype=torch.int8)
@@ -159,27 +170,48 @@ def test_gemm_sm100_int8():
reference = torch._int_mm(A, B)
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.skipif(
find_spec("tvm_ffi") is None,
reason="FP8 currently requires TVM FFI to be installed",
@pytest.mark.parametrize(
"problem_size",
[
(256, 512, 128),
(256, 512, 128, 1),
(256, 512, 128, 2),
]
)
@pytest.mark.skipif(
not is_device_cc_supported({100})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a"]),
reason="Requires compute capability 100 and to be compiled with sm_100a",
)
def test_gemm_sm100_fp8():
# Currently, FP8 is only supported via TVM FFI.
GlobalOptions().use_tvm_ffi = True
def test_gemm_sm100_fp8(
problem_size: tuple[int, ...],
fixture_enable_tvm_ffi, # FP8 currently requires TVM FFI to be installed
):
M, N, K = problem_size[:3]
M, N, K = 256, 512, 128
# Create torch fp8 tensors for A and B
A = torch.randint(-1, 2, (M, K), device="cuda").to(torch.float8_e4m3fn)
D = torch.empty((M, N), device="cuda", dtype=torch.float32)
identity_scale = torch.ones(1, device="cuda", dtype=torch.float32)
# Transpose B because torch._scaled_mm expects B to be column-major
B = torch.randint(-1, 2, (N, K), device="cuda").to(torch.float8_e4m3fn).T
if len(problem_size) == 3:
A = torch.randint(-1, 2, (M, K), device="cuda").to(torch.float8_e4m3fn)
D = torch.empty((M, N), device="cuda", dtype=torch.float32)
# Transpose B because torch._scaled_mm expects B to be column-major
B = torch.randint(-1, 2, (N, K), device="cuda").to(torch.float8_e4m3fn).T
reference = torch._scaled_mm(A, B, identity_scale, identity_scale, out_dtype=torch.float32)
else:
L = problem_size[3]
A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
# Transpose B because torch._scaled_mm expects B to be column-major
B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).permute(0, 2, 1)
reference = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
for l in range(L):
reference[l, :, :] = torch._scaled_mm(
A[l, :, :], B[l, :, :], identity_scale, identity_scale, out_dtype=torch.float32
)
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float32)
kernels = cutlass_api.get_kernels(args, cc=100)
@@ -190,10 +222,6 @@ def test_gemm_sm100_fp8():
compiled_artifact = kernel.compile(args)
kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)
identity_scale = torch.ones(1, device="cuda", dtype=torch.float32)
reference = torch._scaled_mm(
A, B, identity_scale, identity_scale, out_dtype=torch.float32
)
torch.testing.assert_close(D, reference)
@@ -207,7 +235,7 @@ def test_no_gemms_available():
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float32)
kernels = cutlass_api.get_kernels(args, cc=70)
# There are currenlty no kernels available for compute capability 70.
# There are currently no kernels available for compute capability 70.
assert len(kernels) == 0

View File

@@ -33,8 +33,8 @@ import pytest
import torch
import cutlass_api
from cutlass_api.utils import device_cc
from cutlass_api.config import GlobalOptions
from cutlass_api.utils import is_device_cc_supported
torch.manual_seed(2025)
@@ -62,7 +62,7 @@ def base_data_types():
def supports_sm100af():
return device_cc() == 100 and (
return is_device_cc_supported({100}) and (
os.getenv("CUTE_DSL_ARCH", "") in ["", "sm_100a", "sm_100f"]
)
@@ -485,9 +485,7 @@ def test_gemm_fusion_return_acc():
epi_str = "def epi(accum, C): D = relu(accum) + C; return D, accum"
epi_args = cutlass_api.arguments.EpilogueArguments(
epi_str, C=C, D=D, accum=accum
)
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C=C, D=D, accum=accum)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
@@ -628,19 +626,13 @@ def test_gemm_fusion_matmul_input_as_aux():
@pytest.mark.parametrize(
"ab_dtype, c_dtype, d_dtype, accumulator_type", base_data_types()
)
@pytest.mark.parametrize(
"use_tvm_ffi",
[True, False],
)
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_alpha_beta(
M, N, K, L, ab_dtype, c_dtype, d_dtype, accumulator_type, use_tvm_ffi
M, N, K, L, ab_dtype, c_dtype, d_dtype, accumulator_type, fixture_toggle_tvm_ffi
):
GlobalOptions().use_tvm_ffi = use_tvm_ffi
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
@@ -680,7 +672,7 @@ def test_gemm_alpha_beta(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_big_epi():
def test_gemm_big_epi(fixture_toggle_tvm_ffi):
M, N, K, L = 256, 512, 128, 2
ab_dtype = torch.float16
c_dtype = torch.float32
@@ -809,7 +801,7 @@ def test_gemm_big_epi():
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_not_available():
def test_gemm_fusion_not_available(fixture_toggle_tvm_ffi):
M = 256
N = 512
K = 1024

View File

@@ -0,0 +1,211 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import logging
import os
from importlib.util import find_spec
from pprint import pformat
import pytest
import torch
import cutlass
import cutlass_api
from cutlass_api.utils import is_device_cc_supported
logger = logging.getLogger(__name__)
# Set to None to test all kernels, otherwise set to the number of kernels to test
_MAX_KERNELS_TO_TEST = 1
# ab_dtype, c_dtype, accumulator_type
# NOTE: The current Sm80TensorOpGemmKernel implementation only generates kernels
# for Float32 accumulators (see _metadata_operand_combinations in sm80_tensorop_gemm.py),
# so we restrict tests to those combinations for now.
_DTYPES = [
(torch.float16, torch.float16, torch.float32),
(torch.bfloat16, torch.bfloat16, torch.float32),
]
# M, N, K, L and L is optional
def get_sizes_mnkl(level: int = 0) -> list[tuple[int, int, int, int]]:
    """Return the (M, N, K, L) problem sizes exercised at the given test level.

    Only level 0 is currently defined; any other value raises ValueError.
    """
    if level != 0:
        raise ValueError(f"Invalid level: {level}")
    return [
        (128, 128, 64, 1),
        (256, 256, 128, 1),
        (512, 256, 256, 1),
        (256, 512, 384, 2),
    ]
# Layout combinations to test: (A_layout, B_layout, C_layout)
# 't' = transposed/row-major (last dim contiguous, stride=1)
# 'n' = normal/column-major (middle dim contiguous, stride=1)
def get_layouts() -> list[tuple[str, str, str]]:
    """Return every (A, B, C) layout combination to test.

    't' = row-major (last dim contiguous), 'n' = column-major (middle dim
    contiguous). All 8 combinations are produced, C-layout varying slowest.
    """
    return [(a, b, c) for c in "tn" for a in "tn" for b in "tn"]
def create_layout_tensor(shape: tuple, layout: str, dtype: torch.dtype, device: str = "cuda") -> torch.Tensor:
    """Create a random small-integer tensor of ``shape`` with the given layout.

    For layout 't' (row-major) the last dimension is contiguous; otherwise
    (column-major) the middle dimension is contiguous. Values are in
    {-1, 0, 1} so low-precision GEMM references stay exact.
    """
    L, rows, cols = shape
    if layout == "t":
        # Row-major: allocate directly; torch tensors are last-dim contiguous.
        return torch.randint(-1, 2, (L, rows, cols), device=device, dtype=dtype)
    # Column-major: allocate with the trailing two dims swapped, then view
    # back so the returned shape matches but the middle dim has stride 1.
    backing = torch.randint(-1, 2, (L, cols, rows), device=device, dtype=dtype)
    return backing.permute(0, 2, 1)
@pytest.mark.parametrize(
    "M, N, K, L, ab_dtype, c_dtype, accumulator_type",
    [(*sizes, *dtypes) for sizes in get_sizes_mnkl() for dtypes in _DTYPES],
)
def test_gemm_sm80(
    M: int,
    N: int,
    K: int,
    L: int,
    ab_dtype: torch.dtype,
    c_dtype: torch.dtype,
    accumulator_type: torch.dtype,
    fixture_toggle_tvm_ffi,  # parametrized fixture: runs once with TVM FFI on, once off
):
    """Batched GEMM correctness test for CC-80 kernels.

    Builds (L, M, K) x (L, K, N) inputs, discovers kernels for cc=80, runs up
    to ``_MAX_KERNELS_TO_TEST`` of them, and checks each output against a
    float32 PyTorch matmul reference.
    """
    # Small integer values in {-1, 0, 1} keep the reference exact in low precision.
    A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
    B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
    D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
    args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
    kernels = cutlass_api.get_kernels(args, cc=80)
    assert len(kernels) > 0, "No kernels returned for the given configuration"
    # _MAX_KERNELS_TO_TEST is None means "test every discovered kernel".
    max_kernels = len(kernels) if _MAX_KERNELS_TO_TEST is None else _MAX_KERNELS_TO_TEST
    kernels_to_test = kernels[:max_kernels]
    # Compute reference using PyTorch's GPU matmul
    reference = (A.to(torch.float32) @ B.to(torch.float32)).to(c_dtype)
    # Test the selected kernels
    for idx, kernel in enumerate(kernels_to_test):
        logger.debug(f"Testing kernel {idx+1}/{len(kernels_to_test)}: {kernel.metadata.kernel_name}")
        logger.debug(f"Kernel metadata:\n{pformat(kernel.metadata)}")
        # Create fresh output tensor for each kernel to avoid interference
        D_test = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
        args_test = cutlass_api.arguments.GemmArguments(A, B, D_test, accumulator_type)
        try:
            kernel.run(args_test)
            # Verify correctness against PyTorch reference
            torch.testing.assert_close(D_test, reference, atol=1e-2, rtol=1e-2)
            logger.debug(f"Kernel {idx+1}/{len(kernels_to_test)} passed")
        except Exception as e:
            # Convert any failure into a pytest failure naming the kernel.
            pytest.fail(
                f"Kernel {idx+1}/{len(kernels_to_test)} ({kernel.metadata.kernel_name}) failed: {e}"
            )
@pytest.mark.parametrize(
    "M, N, K, L, ab_dtype, c_dtype, accumulator_type, layouts",
    [
        (*sizes, *dtypes, layouts)
        for sizes in [(248, 264, 248, 1)]  # Test with one size to keep test time reasonable
        for dtypes in _DTYPES
        for layouts in get_layouts()
    ],
)
def test_gemm_sm80_layouts(
    M: int,
    N: int,
    K: int,
    L: int,
    ab_dtype: torch.dtype,
    c_dtype: torch.dtype,
    accumulator_type: torch.dtype,
    layouts: tuple[str, str, str],
    fixture_disable_tvm_ffi,  # forces TVM FFI off for this test (see docstring)
):
    """Test different tensor layouts (row-major vs column-major).

    Note: Currently only tests with use_tvm_ffi=False because TVM-FFI has strict
    stride alignment requirements that are violated by transposed column-major layouts.
    """
    layout_A, layout_B, layout_C = layouts
    # Create tensors with specified layouts
    A = create_layout_tensor((L, M, K), layout_A, ab_dtype)
    B = create_layout_tensor((L, K, N), layout_B, ab_dtype)
    D = create_layout_tensor((L, M, N), layout_C, c_dtype)
    args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
    kernels = cutlass_api.get_kernels(args, cc=80)
    assert len(kernels) > 0, f"No kernels returned for layout {layout_A}{layout_B}{layout_C}"
    # _MAX_KERNELS_TO_TEST is None means "test every discovered kernel".
    max_kernels = len(kernels) if _MAX_KERNELS_TO_TEST is None else _MAX_KERNELS_TO_TEST
    kernels_to_test = kernels[:max_kernels]
    # Compute reference using PyTorch's GPU matmul
    reference = (A.to(torch.float32) @ B.to(torch.float32)).to(c_dtype)
    # Test the selected kernels
    for idx, kernel in enumerate(kernels_to_test):
        logger.debug(f"Testing kernel {idx+1}/{len(kernels_to_test)}: {kernel.metadata.kernel_name}")
        # Create fresh output tensor for each kernel with same layout
        D_test = create_layout_tensor((L, M, N), layout_C, c_dtype)
        args_test = cutlass_api.arguments.GemmArguments(A, B, D_test, accumulator_type)
        try:
            kernel.run(args_test)
            # Verify correctness against PyTorch reference
            torch.testing.assert_close(D_test, reference, atol=1e-2, rtol=1e-2)
            logger.debug(f"Kernel {idx+1}/{len(kernels_to_test)} passed for layout {layout_A}{layout_B}{layout_C}")
        except Exception as e:
            # Convert any failure into a pytest failure naming kernel and layout.
            pytest.fail(
                f"Kernel {idx+1}/{len(kernels_to_test)} ({kernel.metadata.kernel_name}) failed for layout {layout_A}{layout_B}{layout_C}: {e}"
            )

View File

@@ -27,6 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
from pathlib import Path
import pytest
@@ -36,10 +37,11 @@ import cutlass_api
@pytest.mark.parametrize(
"notebook_name, supported_ccs",
[
("000_gemm.ipynb", [100, 103]),
("000_gemm.ipynb", [80, 89, 90, 100, 103]),
("001_gemm_with_fused_epilogue.ipynb", [100, 103]),
("002_bring_your_own_kernel.ipynb", [80, 89, 90, 100, 103, 120, 121]),
("003_host_latency_best_practices.ipynb", [100, 103]),
("003_host_latency_best_practices.ipynb", [80, 89, 90, 100, 103]),
("004_fake_tensors.ipynb", [80, 89, 90, 100, 103]),
],
)
def test_notebooks(notebook_name, supported_ccs):
@@ -74,12 +76,21 @@ def test_notebooks(notebook_name, supported_ccs):
notebook_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
full_notebook_path = os.path.join(notebook_dir, notebook_name)
import nbconvert
import subprocess
import sys
from nbconvert.preprocessors import ExecutePreprocessor
import nbformat
with open(full_notebook_path, "r") as file:
# Register the current Python interpreter as the python3 kernel
subprocess.run(
[sys.executable, "-m", "ipykernel", "install", "--user", "--name", "python3"],
check=True,
capture_output=True,
)
with Path(full_notebook_path).open() as file:
notebook = nbformat.read(file, as_version=4)
ep = nbconvert.preprocessors.ExecutePreprocessor(timeout=600, kernel_name="python3")
ep = ExecutePreprocessor(timeout=600, kernel_name="python3")
# Execute the notebook. This call will error out on any assertions or errors in the
# notebook itself. Allow these to propagate up so the test will fail on notebook failure.

View File

@@ -32,7 +32,6 @@ import torch
import cutlass_api
pytestmark = pytest.mark.arch("80")
@@ -69,7 +68,7 @@ def epi(accum, C, alpha, beta):
except ValueError as e:
assert "F" in str(e)
else:
assert False, "Failed to catch missing keyword"
raise AssertionError("Failed to catch missing keyword")
def test_extra_keywords():
@@ -92,4 +91,4 @@ def epi(accum, C, alpha, beta):
except ValueError as e:
assert "gamma" in str(e)
else:
assert False, "Failed to catch extra keyword"
raise AssertionError("Failed to catch extra keyword")

View File

@@ -0,0 +1,151 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
import torch
import cutlass
from cutlass import cute
import cutlass_api
from cutlass_api.arguments import ElementwiseArguments
from cutlass_api.config import GlobalOptions
from cutlass_api.metadata import (
ElementwiseOperandsMetadata,
TensorAttributes,
)
class NoopKernelForTesting(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):
    """Minimal CuteDSL kernel used to exercise argument-alignment checks.

    The body does no device work (only a host-side printf), so the tests in
    this module probe whether argument construction, ``supports``, and ``run``
    accept or reject tensors — independent of any real computation.
    """

    @cute.jit
    def impl(self, A, B, out, stream):
        # Intentionally a no-op: reaching this point is the success condition.
        cute.printf("Called kernel from host successfully!")
        return

    def compile(self, args: ElementwiseArguments):
        # A fake stream suffices because impl launches no device work.
        stream = cute.runtime.make_fake_stream()
        return self.cute_compile(self.impl, args.A, args.B, args.out, stream)

    def _run(
        self,
        args: ElementwiseArguments,
        compiled_artifact,
        stream,
        workspace=None,
    ):
        self.cute_run(compiled_artifact, args.A, args.B, args.out, stream)

    # NOTE(review): invoked unbound as
    # NoopKernelForTesting.generate_kernels(None, None, None); the three
    # parameters presumably mirror a (filter, epilogue_args, cc) provider
    # interface and are deliberately ignored here — confirm against the
    # CuteDslKernel base class.
    def generate_kernels(_ignored_filter, _ignored_epilogue_args, _ignored_cc):
        # divisibility=8: alignment requirement the misalignment tests below
        # size their tensors against.
        attrs = TensorAttributes(
            stride=(0, 1),
            dtype=cutlass.Float16,
            divisibility=8,
        )
        metadata = cutlass_api.KernelMetadata(
            kernel_name="NoopKernelForTesting",
            kernel_class=NoopKernelForTesting,
            min_cc=80,
            operands=ElementwiseOperandsMetadata(
                A=attrs,
                B=attrs,
                out=attrs,
            ),
        )
        return [NoopKernelForTesting(metadata)]
kernel = NoopKernelForTesting.generate_kernels(None, None, None)[0]
def test_perfectly_aligned():
    """A run succeeds when every tensor exactly meets the divisibility requirement."""
    div = kernel.metadata.operands.A.divisibility
    tensors = [
        torch.randn(div, div * 2, dtype=torch.float16, device="cuda")
        for _ in range(3)
    ]
    A, B, out = tensors
    kernel.run(ElementwiseArguments(A=A, B=B, out=out))
def test_overaligned():
    """Tensors far larger than the required divisibility are accepted."""
    A, B, out = (
        torch.randn(1024, 1024, dtype=torch.float16, device="cuda") for _ in range(3)
    )
    kernel.run(ElementwiseArguments(A=A, B=B, out=out))
def _check_misaligned_args(use_tvm_ffi: bool, error_match_string: str, **tensors):
    """Assert that misaligned tensors are rejected.

    With TVM-FFI enabled, argument construction may succeed, but
    ``kernel.supports`` must reject the args — and running anyway (bypassing
    ``supports``) must still raise. With TVM-FFI disabled, the error must be
    caught early, during argument construction itself.
    """
    GlobalOptions().use_tvm_ffi = use_tvm_ffi
    if not use_tvm_ffi:
        # Eager path: construction itself must fail.
        with pytest.raises(Exception, match=error_match_string):
            ElementwiseArguments(**tensors)
        return
    # TVM-FFI path: construction succeeds, but the args are unsupported.
    args = ElementwiseArguments(**tensors)
    assert not kernel.supports(args), "Unsupported args should be rejected"
    with pytest.raises(Exception, match=error_match_string):
        kernel.run(args, assume_supported_args=True)
@pytest.mark.parametrize("use_tvm_ffi", [True, False])
def test_underaligned(use_tvm_ffi: bool):
    """Shapes not divisible by the required alignment must be rejected."""
    div = kernel.metadata.operands.A.divisibility
    rows = div + div // 2
    cols = div + div // 4
    A, B, out = (
        torch.randn(rows, cols, dtype=torch.float16, device="cuda") for _ in range(3)
    )
    _check_misaligned_args(use_tvm_ffi, "divisible", A=A, B=B, out=out)
@pytest.mark.parametrize("use_tvm_ffi", [True, False])
def test_ptr_misaligned(use_tvm_ffi: bool):
    """A tensor whose base pointer is offset off alignment must be rejected."""
    side = kernel.metadata.operands.A.divisibility * 4
    offset = 117  # odd element offset -> misaligned base address
    backing = torch.randn(side * side + offset, dtype=torch.float16, device="cuda")
    B = torch.randn(side, side, dtype=torch.float16, device="cuda")
    out = torch.randn(side, side, dtype=torch.float16, device="cuda")
    # View into the backing buffer starting past the offset: same shape/strides
    # as a well-formed matrix, but a misaligned data pointer.
    misaligned_A = torch.as_strided(backing[offset:], (side, side), (side, 1))
    _check_misaligned_args(use_tvm_ffi, "align", A=misaligned_A, B=B, out=out)