Initial commit

This commit is contained in:
jkosaian
2025-12-16 10:00:46 -08:00
parent d4e16f5d4e
commit ead2fbfe13
81 changed files with 19407 additions and 0 deletions


@@ -0,0 +1,89 @@
# CUTLASS API
**NOTE: This is an experimental/early-access version of the CUTLASS API. All APIs, names, and paths are subject to change.**
The CUTLASS API provides high-level, universal interfaces to discover, compile,
and execute GEMMs (including grouped GEMMs and scaled GEMMs) targeting NVIDIA GPUs.
It allows GEMM kernels written in different DSLs to be integrated and discovered under
a uniform API.
```python
import cutlass_api
import torch
A, B, out = (torch.randn(128, 128, device="cuda", dtype=torch.float16) for _ in range(3))
# Create arguments for the GEMM operation
args = cutlass_api.arguments.GemmArguments(A, B, out, accumulator_type=torch.float32)
# Query for kernels that support our provided arguments on SM100
kernels = cutlass_api.get_kernels(args, cc=100)
# JIT compile and execute one of the returned kernels
kernels[0].run(args)
```
## Deep dive
To learn more about the API, follow the in-depth tutorials in the [examples directory](./examples).
## Directory structure
* [cutlass_api](./cutlass_api/): source for the CUTLASS API
* [examples](./examples/): examples of using the CUTLASS API
* [test](./test/): tests of the CUTLASS API
## Installation
There is currently no wheel for the CUTLASS API.
Install in editable mode from the `python/cutlass_api` directory of the CUTLASS repository:
```bash
# Run from the directory containing this README
# Install the package and required dependencies
pip install -e .
# Install additional dependencies for use with torch
pip install -e ".[torch]"
# To install all dependencies to develop, run tests, etc.
pip install -e ".[dev]"
```
## Running examples and tests
[Tests](./test/) use [pytest](https://docs.pytest.org/en/stable/). An example of running a test is:
```bash
pytest test/integration/test_gemm.py
```
[Examples](./examples/) are [Jupyter notebooks](https://jupyter.org/). They can be launched via:
```bash
cd examples
jupyter notebook
```
## Requirements, compatibility, and current support
Please see [pyproject.toml](pyproject.toml) for dependencies.
**Compatibility:** CUTLASS API has the same compatibility requirements as the CUTLASS project.
See CUTLASS's [Compatibility section](https://github.com/NVIDIA/cutlass?tab=readme-ov-file#compatibility).
### Current support
* Dense GEMM: `out = A @ B`
- Compute capabilities: 100, 103
- Input precisions (A and B must be of the same type): F16, BF16, TF32, INT8
- Output precisions: F32, F16, BF16, INT32
- Epilogue operations:
- Auxiliary load of tensors equal in rank and shape to `out`
- Auxiliary store of tensors equal in rank and shape to `out`
- Auxiliary load of scalar
- Tensor-tensor elementwise or tensor-scalar addition, multiplication, subtraction, division
- Elementwise tensor exponent, relu, sigmoid, tanh
* Planned additions
* Block-scaled GEMMs
* Grouped GEMMs
* Additional epilogue operations
* Reductions
* Row/column broadcasts
* SM90 kernels
* Convolutions
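To make the supported epilogue operations concrete, here is a dependency-free reference of the math a fused epilogue computes, using plain nested lists in place of tensors. This is only an illustration of the semantics (a tensor-scalar multiply, a tensor-tensor add, and an elementwise relu); the actual kernels fuse these operations on the GPU, and none of the helper names below come from the CUTLASS API.

```python
# Pure-Python reference of a fused GEMM epilogue: D = relu(accum * alpha + C * beta).
# Illustration only; the CUTLASS API performs this fused inside the GEMM kernel.

def matmul(A, B):
    """Naive (M, K) @ (K, N) reference matmul over nested lists."""
    M, K, N = len(A), len(A[0]), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)] for i in range(M)]

def epilogue(accum, C, alpha, beta):
    """Tensor-scalar multiply, tensor-tensor add, then elementwise relu."""
    return [
        [max(accum[i][j] * alpha + C[i][j] * beta, 0.0) for j in range(len(accum[0]))]
        for i in range(len(accum))
    ]

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[1.0, 0.0], [0.0, 1.0]]
C = [[-10.0, 1.0], [1.0, -10.0]]
accum = matmul(A, B)  # A @ B == A here, since B is the identity
D = epilogue(accum, C, alpha=1.0, beta=1.0)
print(D)  # [[0.0, 3.0], [4.0, 0.0]]
```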


@@ -0,0 +1,71 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
__version__ = "0.1.0a1"
from collections.abc import Callable
from cutlass_api import fusion
from cutlass_api.arguments import RuntimeArguments
from cutlass_api.kernel import Kernel
from cutlass_api.manifest import Manifest
from cutlass_api.metadata import KernelMetadata
from cutlass_api.status import Status
def get_kernels(
args: RuntimeArguments | None = None,
metadata_filter: Callable[[KernelMetadata], bool] | None = None,
cc: int | None = None,
providers: list[str] | None = None,
) -> list[Kernel]:
"""
Get the kernels that match the given arguments, metadata filter, and compute capability.
:param args: the arguments of the kernel
:type args: RuntimeArguments
:param metadata_filter: a boolean function that takes in KernelMetadata and returns whether
a Kernel from this metadata should be included
:type metadata_filter: Callable[[KernelMetadata], bool]
:param cc: the compute capability
:type cc: int
:param providers: the providers to use
:type providers: list[str]
:return: the kernels that match the given arguments, metadata filter, and compute capability
:rtype: list[Kernel]
"""
return Manifest.get_kernels(args, metadata_filter, cc, providers)
__all__ = [
"Manifest",
"Status",
"fusion",
"get_kernels",
]
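The shape of the `get_kernels` interface above — narrow a registry by compute capability, then by a user-supplied metadata predicate — can be sketched in isolation. All names below are hypothetical stand-ins, not part of `cutlass_api`; the real implementation delegates to `Manifest.get_kernels`.

```python
# Hypothetical sketch of the filtering get_kernels() performs: a registry of
# kernel metadata is narrowed by compute capability and a user predicate.
from dataclasses import dataclass

@dataclass
class FakeKernelMetadata:
    name: str
    cc: int

REGISTRY = [
    FakeKernelMetadata("gemm_f16_sm100", 100),
    FakeKernelMetadata("gemm_bf16_sm100", 100),
    FakeKernelMetadata("gemm_f16_sm103", 103),
]

def get_kernels_sketch(metadata_filter=None, cc=None):
    results = REGISTRY
    if cc is not None:
        results = [m for m in results if m.cc == cc]
    if metadata_filter is not None:
        results = [m for m in results if metadata_filter(m)]
    return results

names = [m.name for m in get_kernels_sketch(lambda m: m.name.startswith("gemm_f16"), cc=100)]
print(names)  # ['gemm_f16_sm100']
```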


@@ -0,0 +1,363 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, get_type_hints
if TYPE_CHECKING:
from collections.abc import Callable
from cutlass_api.metadata import (
ElementwiseOperandsMetadata,
GemmOperandsMetadata,
)
import cutlass
import cutlass.cute as cute
from cutlass_api.fusion import EmptyTensor, trace, trace_in_out
from cutlass_api.fusion.library import LayoutType
from cutlass_api.typing import NumericLike, TensorLike
from cutlass_api.utils import (
TensorWrapper,
add_batch_mode,
is_torch_tensor,
to_cutlass_type,
)
@dataclass
class PerformanceControls:
pass
class EpilogueArguments:
def __init__(self, epilogue_fn: Callable | str | None = None, **kwargs):
"""
Encapsulation of the epilogue function and its arguments needed to
determine the functional operation of an epilogue pattern.
To support flexible definition of epilogues, `EpilogueArguments` is
defined generically as taking in an `epilogue_fn` and additional `kwargs`.
Under the hood, the AST for `epilogue_fn` is parsed to determine the
operands and outputs of the epilogue. `kwargs` must contain Tensors or scalars
for all operands and outputs in the provided epilogue.
Structure of `epilogue_fn`
--------------------------
Epilogue fusion patterns are defined via Python functions that perform Tensor-level
operations -- using a `torch.Tensor` (for example) resulting from matrix multiplication,
the function must be able to compute the desired results of the epilogue.
The structure of these functions is as follows:
```python
def custom_epi_name(accum, *args) -> Union[TensorType, tuple[TensorType]]:
'''
:param accum: result of matrix multiplication, convolution, etc. before the epilogue
:type accum: TensorType
:param args: additional arguments to be used in the epilogue (e.g., aux tensors)
:type args: list[Union[TensorType, ScalarType]]
:returns: at least one tensor resulting from the operation of the epilogue
:rtype: Union[TensorType, tuple[TensorType]]
'''
# Do some compute
return D # and potentially other values
```
`epilogue_fn` must be a Python function or string representation of a Python function
that **must** satisfy the following constraints:
- Take in a first positional argument named `accum` that represents the result
of operation just before the epilogue is to be performed. For example, in a
GEMM, `accum = A @ B`.
- Return at least one tensor that results from computing the epilogue.
Currently, the return list must contain at least one output named `D`,
though this constraint may be loosened in the future.
Each additional argument following `accum` in the function definition must be
a Tensor or scalar to be loaded. Each variable in the return statement represents
a Tensor or scalar to be stored. The underlying implementation of the epilogue in
the kernel will determine how operands are loaded and stored.
Structure of `kwargs`
----------------------
`kwargs` must contain Tensors or scalars for all operands and outputs in the provided epilogue.
For example, with an epilogue of:
```python
def my_epi(accum, alpha, C, beta):
F = (accum * alpha) + (C * beta)
D = relu(F)
return D, F
```
A user would need to construct epilogue arguments as follows:
```python
epi_args = EpilogueArguments(my_epi, alpha=..., C=..., beta=..., D=..., F=...)
```
:param epilogue_fn: The epilogue function to be traced.
:type epilogue_fn: Optional[Union[Callable, str]]
:param kwargs: Additional keyword arguments consisting of the metadata
for operands and outputs of the epilogue function.
:type kwargs: dict
"""
epilogue_inputs: list[str] = []
epilogue_outputs: list[str] = []
if epilogue_fn is not None:
# Parse the epilogue_fn AST to get the required input and output arguments
epilogue_inputs, epilogue_outputs = trace_in_out(epilogue_fn)
# Get required input and output arguments from kwargs
self.tensors = OrderedDict()
for kw in epilogue_inputs + epilogue_outputs:
if kw not in kwargs:
raise ValueError(
f"Argument {kw} is not provided in the kwargs of the EpilogueArguments constructor"
)
self.tensors[kw] = kwargs[kw]
del kwargs[kw]
if len(kwargs) > 0:
raise ValueError(
f"Unexpected keyword arguments for epilogue: {kwargs.keys()}"
)
self.epilogue_fn = epilogue_fn
@property
def parameters(self) -> list[cute.Tensor | cutlass.Numeric]:
"""Returns the list of input and output parameters of the epilogue"""
return list(self.tensors.values())
@property
def parameter_names(self) -> list[str]:
"""Returns the list of names of the input and output parameters of the epilogue"""
return list(self.tensors.keys())
def add_batch_modes(self):
for name in self.tensors:
self.tensors[name] = add_batch_mode(self.tensors[name])
def to_tensor_wrappers(self, permute: list[int] | None = None):
"""Converts the input and output parameters of the epilogue to TensorWrappers"""
for k, v in self.tensors.items():
if is_torch_tensor(v):
if permute is not None:
v = v.permute(permute)
self.tensors[k] = TensorWrapper(v)
def trace(
self, accumulator_shape: tuple[int, ...], accumulator_type: cutlass.Numeric
):
"""
Traces the epilogue function and generates an internal representation of the epilogue.
:param accumulator_shape: The shape of the accumulator tensor. For example, for a GEMM, this would be the shape of the output tensor.
:type accumulator_shape: tuple[int, ...]
:param accumulator_type: The datatype of the accumulator tensor.
:type accumulator_type: cutlass.Numeric
"""
accumulator = EmptyTensor(
element=accumulator_type,
shape=accumulator_shape,
layout_tag=LayoutType.RowMajor,
)
tensors_for_tracing = {**self.tensors, "accum": accumulator}
# Parse the AST of the epilogue_fn again, this time with the set of required
# tensors. This pass converts the epilogue into an internal representation and
# performs a limited set of correctness checks (e.g., shape matches)
#
# Since none of the current providers are based on C++ EVT, we do not need to convert
# the DAG to a tree. If a provider is added that tightly matches the C++ EVT template
# structure, this will need to be revisited.
self.traced_epilogue = trace(
self.epilogue_fn, tensors_for_tracing, requires_conversion_to_tree=False
)
@dataclass
class RuntimeArguments:
"""
Base class for container of all runtime operands and other runtime parameters needed
by a kernel. This includes primary runtime operands to the operation, as well as
any custom epilogue fusions and runtime performance knobs.
Subclasses map to an operation type (e.g., GEMM, elementwise). These subclasses have
additional fields that are specific to the operation type.
:param epilogue: Optional custom epilogue fusion to be performed after the operation.
:type epilogue: Optional[EpilogueArguments]
:param performance: Optional performance controls to be used by the kernel.
:type performance: Optional[PerformanceControls]
"""
epilogue: EpilogueArguments | None = field(default=None, kw_only=True)
performance: PerformanceControls | None = field(default=None, kw_only=True)
def _validate(self):
"""
Checks that the arguments are valid.
This is run before all fields have been converted to TensorWrapper and cutlass.Numeric.
"""
def __post_init__(self):
self._validate()
self._convert_to_internal_types()
def _convert_to_internal_types(self):
"""
Converts all fields to their internal types.
"""
type_hints = get_type_hints(type(self))
for f in fields(self):
hint = type_hints.get(f.name)
value = getattr(self, f.name)
# Find all fields that are annotated as TensorLike,
# and wrap them in TensorWrapper
if hint is TensorLike:
setattr(self, f.name, TensorWrapper(value))
elif hint is NumericLike:
setattr(self, f.name, to_cutlass_type(value))
@dataclass
class GemmArguments(RuntimeArguments):
"""
Arguments for a Generalized Matrix Multiplication (GEMM) operation: out = A @ B
The tensors must all be rank 2 or all be rank 3.
L: Number of batches
M: Number of rows in A and out
K: Number of columns in A and rows in B
N: Number of columns in B and out
:param A: Input tensor A of shape (L, M, K) or (M, K)
:type A: TensorWrapper
:param B: Input tensor B of shape (L, K, N) or (K, N)
:type B: TensorWrapper
:param out: Output tensor of shape (L, M, N) or (M, N)
:type out: TensorWrapper
:param accumulator_type: Data type of the accumulator
:type accumulator_type: cutlass.Numeric
"""
A: TensorLike
B: TensorLike
out: TensorLike
accumulator_type: NumericLike
def to_operands_metadata(self) -> GemmOperandsMetadata:
from cutlass_api.metadata import GemmOperandsMetadata
return GemmOperandsMetadata.from_args(self)
def _validate(self):
"""
Checks that the arguments are valid.
"""
if len(self.A.shape) < 2 or len(self.A.shape) > 3:
raise ValueError(
f"A must be a tensor of rank 2 or 3 (L=1, M, K), got {self.A.shape}"
)
if len(self.B.shape) < 2 or len(self.B.shape) > 3:
raise ValueError(
f"B must be a tensor of rank 2 or 3 (L=1, K, N), got {self.B.shape}"
)
if len(self.out.shape) < 2 or len(self.out.shape) > 3:
raise ValueError(
f"out must be a tensor of rank 2 or 3 (L=1, M, N), got {self.out.shape}"
)
if self.A.shape[-1] != self.B.shape[-2]:
raise ValueError(
f"A's K dimension ({self.A.shape[-1]}) must be equal to B's K dimension ({self.B.shape[-2]}). A shape (L, M, K): {self.A.shape}, B shape (L, K, N): {self.B.shape}"
)
if self.out.shape[-2] != self.A.shape[-2]:
raise ValueError(
f"out's M dimension ({self.out.shape[-2]}) must be equal to A's M dimension ({self.A.shape[-2]}). A shape (L, M, K): {self.A.shape}, out shape (L, M, N): {self.out.shape}"
)
if self.out.shape[-1] != self.B.shape[-1]:
raise ValueError(
f"out's N dimension ({self.out.shape[-1]}) must be equal to B's N dimension ({self.B.shape[-1]}). B shape (L, K, N): {self.B.shape}, out shape (L, M, N): {self.out.shape}"
)
if self.A.shape[:-2] != self.B.shape[:-2]:
raise ValueError(
f"A & B must have the same rank and batch dimension (if any). A shape (L, M, K): {self.A.shape}, B shape (L, K, N): {self.B.shape}"
)
if self.out.shape[:-2] != self.A.shape[:-2]:
raise ValueError(
f"out & A must have the same rank and batch dimension (if any). out shape (L, M, N): {self.out.shape}, A shape (L, M, K): {self.A.shape}"
)
if isinstance(self.epilogue, EpilogueArguments):
L = self.A.shape[0] if len(self.A.shape) == 3 else 1
M, N = self.A.shape[-2], self.B.shape[-1]
accum_shape = (L, M, N)
self.epilogue.trace(accum_shape, self.accumulator_type)
self.epilogue.to_tensor_wrappers()
@dataclass
class ElementwiseArguments(RuntimeArguments):
"""
Arguments needed for an elementwise operation.
:param A: The input tensor A.
:type A: TensorLike
:param B: The input tensor B.
:type B: TensorLike
:param out: The output tensor.
:type out: TensorLike
"""
A: TensorLike
B: TensorLike
out: TensorLike
def to_operands_metadata(self) -> ElementwiseOperandsMetadata:
from cutlass_api.metadata import ElementwiseOperandsMetadata
return ElementwiseOperandsMetadata.from_args(self)
def _validate(self):
"""
Checks that the arguments are valid.
"""
if self.A.shape != self.B.shape:
raise ValueError(
f"A.shape ({self.A.shape}) must be equal to B.shape ({self.B.shape})"
)
if self.out.shape != self.A.shape:
raise ValueError(
f"out.shape ({self.out.shape}) must be equal to A.shape ({self.A.shape})"
)
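The checks in `GemmArguments._validate` reduce to rank and trailing-dimension comparisons. Here is a standalone, torch-free sketch of the same logic over plain shape tuples; the function name and the tuple-based interface are inventions for illustration, not part of the API.

```python
# Standalone sketch of the shape checks GemmArguments._validate performs,
# using plain shape tuples in place of tensors: out = A @ B requires
# (L, M, K) @ (L, K, N) -> (L, M, N), with all ranks 2 or all ranks 3.
def validate_gemm_shapes(a_shape, b_shape, out_shape):
    for name, shape in (("A", a_shape), ("B", b_shape), ("out", out_shape)):
        if not 2 <= len(shape) <= 3:
            raise ValueError(f"{name} must be rank 2 or 3, got shape {shape}")
    if a_shape[-1] != b_shape[-2]:
        raise ValueError(f"K mismatch: A {a_shape} vs B {b_shape}")
    if out_shape[-2] != a_shape[-2] or out_shape[-1] != b_shape[-1]:
        raise ValueError(f"out {out_shape} does not match (M, N) of A {a_shape} and B {b_shape}")
    if not (a_shape[:-2] == b_shape[:-2] == out_shape[:-2]):
        raise ValueError("ranks and batch dimensions must agree across A, B, and out")

validate_gemm_shapes((4, 128, 64), (4, 64, 256), (4, 128, 256))  # OK, batched
try:
    validate_gemm_shapes((128, 64), (32, 256), (128, 256))
except ValueError as e:
    print(e)  # K mismatch: A (128, 64) vs B (32, 256)
```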


@@ -0,0 +1,47 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
import cutlass_api.kernel
@dataclass
class CompiledArtifact:
"""
Lightweight wrapper over the result of compiling a kernel (e.g., via
`cute.compile` when using CuTe DSL).
:param compiled_obj: The result of compiling the kernel.
:type compiled_obj: Any
:param kernel_obj: The Kernel object on which `.compile()` was called to generate this artifact.
:type kernel_obj: cutlass_api.kernel.Kernel
"""
compiled_obj: Any
kernel_obj: "cutlass_api.kernel.Kernel"


@@ -0,0 +1,83 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from importlib.util import find_spec
class GlobalOptions:
"""
Singleton class that configures global options for CUTLASS API.
This can be used to enable or disable certain features of the API. For example,
the user can enable TVM-FFI support to allow for the use of framework-level tensors
at run time.
Example:
```python
GlobalOptions().use_tvm_ffi = True
```
"""
_instance = None
def __new__(cls):
"""
Create a new singleton instance of the GlobalOptions class only once.
"""
if cls._instance is None:
cls._instance = super().__new__(cls)
has_tvm_ffi = find_spec("tvm_ffi") is not None
cls._instance._options = {
"use_tvm_ffi": has_tvm_ffi,
}
return cls._instance
@property
def use_tvm_ffi(self) -> bool:
"""
Check if TVM FFI support is enabled.
Default: True if `tvm_ffi` is installed.
When enabled, conversions from DLPack-compatible tensors to cute.Tensor are performed via TVM FFI.
Additionally, the compiled kernel is invoked via TVM FFI.
Both can offer significant (3x-10x) speedups.
Dependencies:
- Required: `tvm_ffi` (pip install apache-tvm-ffi)
- Optional: `torch_c_dlpack_ext` (pip install torch-c-dlpack-ext)
"""
return self._options["use_tvm_ffi"]
@use_tvm_ffi.setter
def use_tvm_ffi(self, value: bool) -> None:
if value and not find_spec("tvm_ffi"):
raise ImportError(
"TVM FFI is not installed, please install it via `pip install apache-tvm-ffi`."
)
self._options["use_tvm_ffi"] = value
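`GlobalOptions` relies on the classic `__new__`-based singleton: the first construction creates and initializes the instance, and every later construction returns that same object, so option state is shared process-wide. A minimal standalone sketch of the pattern (the class name and the single boolean option below are placeholders, not the real `GlobalOptions`):

```python
# Minimal sketch of the __new__-based singleton pattern GlobalOptions uses.
# The _options dict is created exactly once, on first construction.
class OptionsSketch:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._options = {"use_tvm_ffi": False}
        return cls._instance

    @property
    def use_tvm_ffi(self):
        return self._options["use_tvm_ffi"]

    @use_tvm_ffi.setter
    def use_tvm_ffi(self, value):
        self._options["use_tvm_ffi"] = value

a, b = OptionsSketch(), OptionsSketch()
a.use_tvm_ffi = True          # setting through one handle...
print(a is b, b.use_tvm_ffi)  # True True  ...is visible through the other
```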


@@ -0,0 +1,127 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
import ast
import textwrap
from typing import Callable, Union
from cutlass_api.fusion.ir.tensor import Tensor as EmptyTensor
from cutlass_api.fusion.frontend import PythonASTFrontend, PythonASTInOutProcessor
def trace_in_out(fn: Union[Callable, str]) -> tuple[list[str], list[str]]:
"""
Trace the input function and return the names of its inputs and outputs.
:param fn: The function to trace.
:type fn: Union[Callable, str]
:return: A tuple of lists of input and output names.
:rtype: tuple[list[str], list[str]]
"""
processor = PythonASTInOutProcessor()
return processor.trace(fn)
def trace(fn, example_tensors, **kwargs):
"""
Trace `fn(**example_tensors)` and generate an epilogue visitor.
:param fn: Python callable or string representation of the epilogue function
:type fn: Union[Callable, str]
:param example_tensors: example inputs for fn
:type example_tensors: dict
.. highlight:: python
.. code-block:: python
# Define epilogue function as Python callable
def example_fn(accum, C, alpha, beta, gamma):
D = ((accum + C) * alpha - gamma) / beta
return D
# Define the example tensors
example_inputs = {
"accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
"C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
"alpha": 1.5,
"beta": 0.5,
"gamma": 2.5,
"D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda")
}
# Generate the epilogue functor
epilogue_visitor = cutlass_api.fusion.trace(example_fn, example_inputs)
"""
if callable(fn):
class EpilogueFunctor(PythonASTFrontend):
def __init__(self, cc=None, **kwargs):
# Since we are currently only using the trace() method for generating an
# intermediate representation (which is not CC specific), we can hardcode the cc to 100
if not cc:
cc = 100
super().__init__(cc, **kwargs)
setattr(EpilogueFunctor, "__call__", staticmethod(fn))
epilogue_functor = EpilogueFunctor(**kwargs)
epilogue_functor.trace(example_tensors)
return epilogue_functor
elif isinstance(fn, str):
class EpilogueFunctor(PythonASTFrontend):
def __init__(self, cc=None, **kwargs):
self.source = textwrap.dedent(fn)
# Since we are currently only using the trace() method for generating an
# intermediate representation (which is not CC specific), we can hardcode the cc to 100
if not cc:
cc = 100
super().__init__(cc, **kwargs)
def parse(self, example_inputs) -> None:
self.example_inputs = example_inputs
self.ast = ast.parse(self.source)
self.visit(self.ast)
epilogue_functor = EpilogueFunctor(**kwargs)
epilogue_functor.trace(example_tensors)
return epilogue_functor
else:
raise NotImplementedError(
"`fn` must be a Python callable or a string representation of a Python function"
)
__all__ = [
"EmptyTensor",
"trace_in_out",
"trace",
]
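The AST pass that `trace_in_out` performs can be approximated in a few lines: collect the function's argument names after `accum` as inputs, and the names in its `return` statement as outputs. The sketch below is a deliberately simplified stand-in; the real `PythonASTInOutProcessor` performs more validation than this.

```python
# Simplified, standalone version of what trace_in_out() extracts from an
# epilogue given as a string: input names (arguments after `accum`) and
# output names (variables in the return statement).
import ast
import textwrap

SOURCE = """
def my_epi(accum, alpha, C, beta):
    F = (accum * alpha) + (C * beta)
    D = relu(F)
    return D, F
"""

def trace_in_out_sketch(source):
    fndef = ast.parse(textwrap.dedent(source)).body[0]
    inputs = [a.arg for a in fndef.args.args if a.arg != "accum"]
    outputs = []
    for node in ast.walk(fndef):
        if isinstance(node, ast.Return):
            value = node.value
            elts = value.elts if isinstance(value, ast.Tuple) else [value]
            outputs = [e.id for e in elts if isinstance(e, ast.Name)]
    return inputs, outputs

print(trace_in_out_sketch(SOURCE))  # (['alpha', 'C', 'beta'], ['D', 'F'])
```

Note the epilogue is only parsed, never executed, so `relu` need not be defined here; this mirrors how `EpilogueArguments` accepts a string representation of the function.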


@@ -0,0 +1,237 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Helpers for common activation functions
"""
import ctypes
from cutlass_api.fusion.ir.c_types import dtype2ctype, to_ctype_value
from cutlass_api.fusion.library import ActivationOp
from cutlass_api.utils import (
is_torch_available,
is_numpy_available,
is_numpy_tensor,
is_torch_tensor,
)
if is_torch_available():
import torch
import torch.nn.functional as F
if is_numpy_available():
import numpy as np
class ActivationFunctor:
"""
Base class for frequently used activation functions
"""
@staticmethod
def numpy(x):
raise NotImplementedError()
@staticmethod
def epilogue_output_op(element_epilogue):
c_element_epilogue = dtype2ctype[element_epilogue]
class _EpilogueOutputOpParams(ctypes.Structure):
_fields_ = [
("alpha", c_element_epilogue),
("beta", c_element_epilogue),
("alpha_ptr", ctypes.c_void_p),
("beta_ptr", ctypes.c_void_p),
]
def __init__(self, alpha, beta, *args) -> None:
self.alpha = to_ctype_value(alpha, element_epilogue)
self.beta = to_ctype_value(beta, element_epilogue)
return _EpilogueOutputOpParams
class ActivationMeta(type):
@classmethod
def __call__(cls, x, *args):
if is_numpy_tensor(x):
return cls.numpy(x, *args)
elif is_torch_tensor(x):
return cls.torch(x, *args)
else:
raise NotImplementedError("Unsupported tensor type")
@classmethod
def numpy(cls, *args):
raise NotImplementedError(
f"Numpy reference for {cls.__name__[:-4]} is not implemented."
)
@classmethod
def torch(cls, *args):
raise NotImplementedError(
f"PyTorch reference for {cls.__name__[:-4]} is not implemented."
)
##############################################################################
# identity operator
class identityMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
return x
@classmethod
def torch(cls, x):
return x
class identity(ActivationFunctor, metaclass=identityMeta):
binding_type = ActivationOp.Identity
##############################################################################
# ReLu operator
class reluMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
return np.where(x > 0, x, 0)
@classmethod
def torch(cls, x):
return F.relu(x)
class relu(ActivationFunctor, metaclass=reluMeta):
binding_type = ActivationOp.ReLU
##############################################################################
# Leaky ReLU operator
class leakyReLUMeta(ActivationMeta):
@classmethod
def numpy(cls, x, leaky_alpha):
return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha
@classmethod
def torch(cls, x, leaky_alpha):
return F.leaky_relu(x, leaky_alpha)
class leaky_relu(ActivationFunctor, metaclass=leakyReLUMeta):
binding_type = ActivationOp.LeakyReLU
##############################################################################
# Tanh operator
class tanhMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
return np.tanh(x)
@classmethod
def torch(cls, x):
return torch.tanh(x)
class tanh(ActivationFunctor, metaclass=tanhMeta):
binding_type = ActivationOp.Tanh
##############################################################################
# Sigmoid operator
class sigmoidMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
return 1.0 / (1.0 + np.exp(-x))
@classmethod
def torch(cls, x):
return torch.sigmoid(x)
class sigmoid(ActivationFunctor, metaclass=sigmoidMeta):
binding_type = ActivationOp.Sigmoid
##############################################################################
# SiLU operator
class siluMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
return x * sigmoidMeta.numpy(x)
@classmethod
def torch(cls, x):
return F.silu(x)
class silu(ActivationFunctor, metaclass=siluMeta):
binding_type = ActivationOp.SiLU
##############################################################################
# Hardswish operator
class hardswishMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
relu6 = np.minimum(np.maximum(x + 3.0, 0), 6.0)
return x * relu6 / 6.0
@classmethod
def torch(cls, x):
return F.hardswish(x)
class hardswish(ActivationFunctor, metaclass=hardswishMeta):
binding_type = ActivationOp.HardSwish
##############################################################################
# GELU operator
class geluMeta(ActivationMeta):
@classmethod
def numpy(cls, x):
from scipy.special import erf
return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))
@classmethod
def torch(cls, x):
return F.gelu(x)
class gelu(ActivationFunctor, metaclass=geluMeta):
binding_type = ActivationOp.Gelu
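The dispatch mechanism above can be illustrated with a self-contained sketch. The `_DispatchMeta` and `Sigmoid` names below are hypothetical stand-ins for `ActivationMeta` and its subclasses, using a plain Python list instead of a numpy/torch tensor so the example has no external dependencies:

```python
import math

class _DispatchMeta(type):
    # Calling the class itself (not an instance) routes to a
    # backend-specific reference implementation, mirroring how
    # ActivationMeta.__call__ dispatches on the tensor type.
    def __call__(cls, x, *args):
        if isinstance(x, list):
            return cls.reference(x, *args)
        raise NotImplementedError("Unsupported tensor type")

class Sigmoid(metaclass=_DispatchMeta):
    @classmethod
    def reference(cls, x):
        return [1.0 / (1.0 + math.exp(-v)) for v in x]

# The class is callable directly: Sigmoid([0.0]) -> [0.5]
out = Sigmoid([0.0])
```

In the real module, `relu`, `gelu`, etc. are used the same way: the metaclass inspects the argument and picks the numpy or torch reference implementation.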

@@ -0,0 +1,47 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.backend.sm80_emitter import Sm80Emitter
import cutlass_api.fusion.backend.sm80_nodes as sm80_nodes
from cutlass_api.fusion.backend.sm90_emitter import Sm90Emitter
import cutlass_api.fusion.backend.sm90_nodes as sm90_nodes
from cutlass_api.fusion.backend.sm100_emitter import Sm100Emitter
import cutlass_api.fusion.backend.sm100_nodes as sm100_nodes
__all__ = [
"Sm80Emitter",
"sm80_nodes",
"Sm90Emitter",
"sm90_nodes",
"Sm100Emitter",
"sm100_nodes",
]

@@ -0,0 +1,165 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################
"""
Base class for Epilogue Visitor Emitter
"""
from cutlass_api.fusion.library import DataTypeTag
from cutlass_api.fusion.ir import TopoVisitorNode, DAGIR
class FusionCallbacks:
def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None:
"""
Emit the EVT fusion callbacks
:param dag_ir: the DAG IR holding the epilogue visitor
:param cc: compute capability
:param emit_CD: whether to emit nodes C & D as a part of the fusion callbacks
For Sm90, set emit_CD=False, as Tensor C & D are hardcoded in the collective API
so that their shared memory can be explicitly reused
For Sm89, set emit_CD=True as they are treated as normal AuxLoad & AuxStore nodes.
"""
self.dag_ir = dag_ir
self.emit_CD = emit_CD
self.cc = cc
self.evt_cc = 90 if cc >= 90 else cc
if self.cc < 90:
self.namespace = "threadblock"
else:
self.namespace = "fusion"
#
# Helper functions
#
def get_visitor_name(self, node: str):
"""
Get the visitor name
"""
meta = self.dag_ir.get_node_meta(node)
if not isinstance(meta, TopoVisitorNode) and self.dag_ir.in_degree(node) > 0:
return f"EVT{meta.name_camel}"
else:
return meta.name_camel
def emit(self):
node_metas = self.dag_ir.node_metas_topological_order()
epilogue_str = ""
# Step 1: emit individual node type decl
# emit the EVT & DAG connector
for meta in node_metas:
if not meta.disabled:
epilogue_str += self.emit_node(meta)
if not self.emit_CD and meta.name == "D":
continue
if isinstance(meta, TopoVisitorNode):
epilogue_str += self.emit_dag(meta)
else:
epilogue_str += self.emit_evt(meta)
# Step 2: post-processing & get callback name
if not self.emit_CD:
if not self.dag_ir.has_node("C"):
epilogue_str += "using ElementC = void;\nusing StrideC = StrideD;\n"
output_node = self.dag_ir.get_all_inputs("D")[0]
# The callback is the src of node D
callback_name = self.get_visitor_name(output_node)
else:
# The callback is the last node in the topological order
callback_name = self.get_visitor_name(node_metas[-1].name)
return epilogue_str, callback_name
def emit_evt(self, node):
if self.dag_ir.in_degree(node.name) == 0:
return ""
evt_tmp = f"""
using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}EVT<
{node.name_camel},
"""
sorted_children = self.dag_ir.get_all_inputs(node.name)
evt_node_strs = [
f" {self.get_visitor_name(child_name)}" for child_name in sorted_children
]
evt_tmp += ",\n".join(evt_node_strs) + ">;\n"
return evt_tmp
def emit_dag(self, node):
subgraph = node.subgraph
subgraph_nodes = subgraph.nodes_topological_order()
# Emit the Edge Tuple
edge_tuples = "cute::tuple<\n"
for n in subgraph_nodes[:-1]:
in_edges = subgraph.in_edges(n)
edge_weights = [
subgraph.get_edge_weight(edge[0], edge[1]) for edge in in_edges
]
sorted_children = [
edge[0] for _, edge in sorted(zip(edge_weights, in_edges))
]
edge_tuple = " cute::seq<"
edge_str = [str(subgraph_nodes.index(child)) for child in sorted_children]
edge_tuple += ", ".join(edge_str) + ">,\n"
edge_tuples += edge_tuple
edge_tuples += " >"
# Emit the node list
dag_nodes = ""
dag_node_strs = []
for n in subgraph_nodes[:-1]:
n_meta = subgraph.get_node_meta(n)
if n_meta.disabled:
dag_node_strs.append(f" {self.get_visitor_name(n)}")
else:
dag_node_strs.append(f" {n_meta.name_camel}")
dag_nodes = ",\n".join(dag_node_strs)
return f"""
using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}TopologicalVisitor<
{DataTypeTag[node.subgraph.element_compute]},
{edge_tuples},
{dag_nodes}
>;
"""
def emit_node(self, node):
if isinstance(node, TopoVisitorNode):
emission = ""
for meta in node.subgraph.node_metas_topological_order():
if not meta.disabled:
emission += self.emit_node(meta)
return emission
else:
return node.underlying_impl.type_decl
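The nesting of `Sm90EVT<...>` aliases that `emit_evt` produces can be sketched with a toy DAG. The node names (`Accum`, `Alpha`, `Compute0`, `D`) and the dict-based graph below are hypothetical simplifications of `DAGIR`; the real class tracks node metadata and edge weights:

```python
# Hypothetical mini-DAG: node -> list of its input nodes, in argument order.
dag = {
    "Accum": [],
    "Alpha": [],
    "Compute0": ["Alpha", "Accum"],  # e.g. alpha * accum
    "D": ["Compute0"],
}

def visitor_name(node):
    # Nodes with inputs get wrapped in an EVT<...> alias, as in
    # FusionCallbacks.get_visitor_name.
    return f"EVT{node}" if dag[node] else node

def emit_evt(node):
    # Leaf nodes emit nothing; interior nodes emit a `using` alias
    # whose template arguments are the (already-emitted) children.
    if not dag[node]:
        return ""
    children = ",\n".join(f"  {visitor_name(c)}" for c in dag[node])
    return f"using EVT{node} = Sm90EVT<\n  {node},\n{children}>;\n"

# Walking nodes in topological order guarantees every child alias
# exists before it is referenced.
decl = "".join(emit_evt(n) for n in ["Accum", "Alpha", "Compute0", "D"])
```

The key property, as in `FusionCallbacks.emit`, is that the topological walk makes each `using` declaration refer only to aliases declared earlier.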

@@ -0,0 +1,151 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################
"""
Emitter for Sm100 Epilogue Visitor
"""
from cutlass_api.fusion.library import (
DataType,
DataTypeTag,
EpilogueScheduleTag,
KernelScheduleSuffixes,
OpcodeClassTag,
)
from cutlass_api.fusion.backend.emitter_base import FusionCallbacks
def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule):
blackwell_threadblock_shape = tile_description.threadblock_shape
is_2sm = (
False
if kernel_schedule is None
else ("2sm" in KernelScheduleSuffixes[kernel_schedule])
)
if cluster_shape[0] > 0:
blackwell_threadblock_shape = [
tile_description.threadblock_shape[0] // cluster_shape[0],
tile_description.threadblock_shape[1] // cluster_shape[1],
tile_description.threadblock_shape[2] // cluster_shape[2],
]
if is_2sm:
blackwell_threadblock_shape[0] *= 2
else:
blackwell_threadblock_shape = (
tile_description.math_instruction.instruction_shape
)
return blackwell_threadblock_shape, is_2sm
class Sm100CollectiveEpilogue:
def __init__(
self,
tile_description,
kernel_schedule,
epilogue_schedule,
element_accumulator,
element_d,
fusion_callbacks,
) -> None:
self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(
tile_description, tile_description.cluster_shape, kernel_schedule
)
self.element_accumulator = element_accumulator
if fusion_callbacks.dag_ir.has_node("C"):
self.element_c = fusion_callbacks.dag_ir.get_node_meta("C").element
else:
self.element_c = DataType.void
self.element_d = element_d
self.schedule = epilogue_schedule
self.fusion_callbacks = fusion_callbacks
self.opclass = tile_description.math_instruction.opcode_class
@property
def CtaTileMNK(self) -> str:
"""
The threadblock shape
"""
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
@property
def EpilogueTileType(self) -> str:
"""
The epilogue tile type
"""
return "cutlass::epilogue::collective::EpilogueTileAuto"
@property
def Schedule(self) -> str:
return EpilogueScheduleTag[self.schedule]
def emit(self):
stride_D_str = self.fusion_callbacks.dag_ir.get_node_meta(
"D"
).underlying_impl.stride_mnl
stride_C_str = stride_D_str
if self.fusion_callbacks.dag_ir.has_node("C"):
stride_C_str = self.fusion_callbacks.dag_ir.get_node_meta(
"C"
).underlying_impl.stride_mnl
callback_decl, callback_name = self.fusion_callbacks.emit()
return (
callback_name,
f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor<
{OpcodeClassTag[self.opclass]},
{self.CtaTileMNK}, {self.EpilogueTileType},
{DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
{self.Schedule}, {stride_C_str}, {stride_D_str},
false /* IsPerColScaleSupported */,
false /* IsBlockScaleSupported */
>;
{callback_decl}
""",
)
class Sm100Emitter:
def __init__(self, operation, graph) -> None:
fusion_callbacks = FusionCallbacks(graph, cc=100, emit_CD=False)
self.collective_epilogue = Sm100CollectiveEpilogue(
tile_description=operation.tile_description,
kernel_schedule=operation.tile_description.kernel_schedule,
epilogue_schedule=operation.tile_description.epilogue_schedule,
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
element_d=fusion_callbacks.dag_ir.get_node_meta("D").element,
fusion_callbacks=fusion_callbacks,
)
def emit(self):
return self.collective_epilogue.emit()
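The arithmetic in `to_blackwell_threadblock_shape` can be checked in isolation. This is a minimal sketch with hypothetical tile values; it reproduces only the cluster-division and 2SM-doubling logic, not the instruction-shape fallback:

```python
def cta_tile_from_cluster(tb_shape, cluster_shape, is_2sm):
    # Divide the threadblock tile by the cluster shape to get the
    # per-CTA tile, then double M for "2sm" kernel schedules, where
    # two CTAs in a cluster cooperate on one MMA.
    m, n, k = (t // c for t, c in zip(tb_shape, cluster_shape))
    if is_2sm:
        m *= 2
    return [m, n, k]

# A 256x128x64 threadblock tile on a 2x1x1 cluster with a 2SM
# schedule yields a 256x128x64 CTA tile (M halved, then doubled).
tile = cta_tile_from_cluster([256, 128, 64], [2, 1, 1], is_2sm=True)
```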

@@ -0,0 +1,140 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################
from cutlass_api.fusion.ir import AuxLoadImpl, AuxStoreImpl
from cutlass_api.fusion.library import DataTypeSize, DataTypeTag, FloatRoundStyleTag
import cutlass_api.fusion.backend.sm90_nodes as sm90_nodes
from cutlass_api.fusion.pycute import product
Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl
Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl
Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl
Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl
Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl
Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl
Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl
Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl
Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl
Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl
class Sm100AuxLoadImpl(AuxLoadImpl):
@property
def descriptor(self) -> str:
"""
Descriptor for Aux Load
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
return self._type_decl
def get_smem_size(
self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles
):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
return (
DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8,
128,
)
class Sm100AuxStoreImpl(AuxStoreImpl):
@property
def descriptor(self) -> str:
"""
Descriptor for Aux Store
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor<
EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
typename {self.descriptor}::CopyOpR2S
>;
"""
return self._type_decl
def get_smem_size(
self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles
):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
return (
DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8,
128,
)
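The shared-memory sizing returned by `get_smem_size` is a simple bits-to-bytes product. A standalone sketch of the formula, with hypothetical inputs (fp16 element, 2 stages, a 64x32 epilogue tile):

```python
def aux_smem_bytes(dtype_bits, stages, epilogue_tile_mn):
    # One epilogue tile buffered per stage; DataTypeSize is in bits,
    # so divide by 8 for bytes. The nodes above additionally report
    # a 128-byte alignment alongside this size.
    m, n = epilogue_tile_mn
    return dtype_bits * stages * m * n // 8

size = aux_smem_bytes(16, 2, (64, 32))  # fp16, 2 stages, 64x32 tile
```

Note that the load path scales with `stages_c` and the store path with `stages_d`; the formula is otherwise identical.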

@@ -0,0 +1,46 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################
"""
Emitter for Sm80 Epilogue Visitor
"""
from cutlass_api.fusion.backend.emitter_base import FusionCallbacks
class Sm80Emitter:
def __init__(self, operation, graph) -> None:
self.fusion_callbacks = FusionCallbacks(graph, cc=80)
def emit(self):
callback_decl, callback_name = self.fusion_callbacks.emit()
return callback_name, callback_decl

@@ -0,0 +1,247 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################
from cutlass_api.fusion.library import (
DataTypeTag,
FloatRoundStyleTag,
FunctionalOp,
op_tag,
)
from cutlass_api.fusion.ir import (
# Load Node
AccumulatorImpl,
AuxLoadImpl,
ColumnBroadcastImpl,
LoadNode,
RowBroadcastImpl,
ScalarBroadcastImpl,
# Compute Node
ComputeImpl,
# Store Node
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl,
)
class Sm80AccumulatorImpl(AccumulatorImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"""
return self._type_decl
class Sm80AuxLoadImpl(AuxLoadImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad<
OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl}
>;
"""
return self._type_decl
class Sm80LoadSrcImpl(Sm80AuxLoadImpl):
pass
class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl):
def __init__(self, node: LoadNode) -> None:
super().__init__(node)
self.broadcast_count = 1
self.reduction_fn = FunctionalOp.Multiplies
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
return self._type_decl
class Sm80RowBroadcastImpl(RowBroadcastImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, {DataTypeTag[self.element]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, {DataTypeTag[self.element]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80ComputeImpl(ComputeImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute<
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}
>;
"""
return self._type_decl
class Sm80AuxStoreImpl(AuxStoreImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80StoreDImpl(Sm80AuxStoreImpl):
pass
class Sm80ColumnReductionImpl(ColumnReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80RowReductionImpl(RowReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm80ScalarReductionImpl(ScalarReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
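Every node impl above follows the same memoized `type_decl` pattern: build the C++ `using` alias from an f-string once, cache it on the instance, and return the cached string thereafter. A minimal sketch of that pattern (the `_DeclCache` class, `AuxC` name, and `StrideMNL` placeholder are hypothetical):

```python
class _DeclCache:
    # Sketch of the memoized type_decl property shared by the Sm80
    # node impls: the C++ `using` alias is rendered once and cached.
    def __init__(self, name, element_tag):
        self.name = name
        self.element_tag = element_tag
        self._type_decl = None

    @property
    def type_decl(self):
        if self._type_decl is not None:
            return self._type_decl
        self._type_decl = (
            f"using {self.name} = "
            f"cutlass::epilogue::threadblock::VisitorAuxLoad<\n"
            f"  OutputTileThreadMap, {self.element_tag}, StrideMNL\n>;\n"
        )
        return self._type_decl

decl = _DeclCache("AuxC", "cutlass::half_t").type_decl
```

Caching matters because `type_decl` may be queried repeatedly while the emitter assembles the full epilogue declaration.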

@@ -0,0 +1,97 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#################################################################################################
"""
Emitter for Sm90 Epilogue Visitor
"""
from cutlass_api.fusion.library import DataTypeTag, EpilogueScheduleTag
from cutlass_api.fusion.backend.emitter_base import FusionCallbacks
class CollectiveEpilogue:
def __init__(
self, tile_description, schedule, element_c, element_d, fusion_callbacks
) -> None:
self.cta_tile_mnk = tile_description.threadblock_shape
self.element_c = element_c
self.element_d = element_d
self.schedule = schedule
self.fusion_callbacks = fusion_callbacks
@property
def CtaTileMNK(self) -> str:
"""
The threadblock shape
"""
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
@property
def EpilogueTileType(self) -> str:
"""
The epilogue tile type
"""
return "cutlass::epilogue::collective::EpilogueTileAuto"
@property
def Schedule(self) -> str:
return EpilogueScheduleTag[self.schedule]
def emit(self):
callback_decl, callback_name = self.fusion_callbacks.emit()
return (
callback_name,
f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
{self.CtaTileMNK}, {self.EpilogueTileType},
{DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
{self.Schedule}
>;
{callback_decl}
""",
)
class Sm90Emitter:
def __init__(self, operation, graph) -> None:
fusion_callbacks = FusionCallbacks(graph, cc=90, emit_CD=False)
self.collective_epilogue = CollectiveEpilogue(
tile_description=operation.tile_description,
schedule=operation.tile_description.epilogue_schedule,
element_c=operation.C.element,
element_d=operation.D.element,
fusion_callbacks=fusion_callbacks,
)
def emit(self):
return self.collective_epilogue.emit()


@@ -0,0 +1,327 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.library import (
DataTypeSize,
DataTypeTag,
FloatRoundStyleTag,
FunctionalOp,
op_tag,
)
from cutlass_api.fusion.ir import (
# Load Node
AccumulatorImpl,
AuxLoadImpl,
ColumnBroadcastImpl,
LoadNode,
LoadSrcImpl,
RowBroadcastImpl,
ScalarBroadcastImpl,
# Compute Node
ComputeImpl,
# Store Node
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl,
StoreDImpl,
)
from cutlass_api.fusion.pycute import product
class Sm90AccumulatorImpl(AccumulatorImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n"""
return self._type_decl
class Sm90LoadSrcImpl(LoadSrcImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using ElementC = {DataTypeTag[self.element]};
using StrideC = {self.stride_mnl};
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>;
"""
return self._type_decl
class Sm90AuxLoadImpl(AuxLoadImpl):
@property
def descriptor(self) -> str:
"""
Descriptor for Aux Load
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
return self._type_decl
def get_smem_size(
self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles
):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
return (
DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8,
128,
)
class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl):
def __init__(self, node: LoadNode) -> None:
super().__init__(node)
self.broadcast_count = 1
self.reduction_fn = FunctionalOp.Multiplies
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
return self._type_decl
class Sm90RowBroadcastImpl(RowBroadcastImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm90ComputeImpl(ComputeImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute<
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}
>;
"""
return self._type_decl
class Sm90AuxStoreImpl(AuxStoreImpl):
@property
def descriptor(self) -> str:
"""
        Descriptor for Aux Store
"""
return f"{self.name_camel}Descriptor"
def decl_descriptor(self) -> str:
"""
Declare the descriptor type
"""
return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = self.decl_descriptor()
self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
typename {self.descriptor}::CopyOpR2S
>;
"""
return self._type_decl
def get_smem_size(
self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles
):
"""
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
"""
return (
DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8,
128,
)
class Sm90StoreDImpl(StoreDImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
return f"""
using ElementD = {DataTypeTag[self.element]};
using StrideD = {self.stride_mnl};
"""
class Sm90ColumnReductionImpl(ColumnReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0,
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm90RowReductionImpl(RowReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */,
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
return self._type_decl
class Sm90ScalarReductionImpl(ScalarReductionImpl):
@property
def type_decl(self):
"""
Return the string defining the type
"""
if self._type_decl is not None:
return self._type_decl
self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
{DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}
>;
"""
return self._type_decl
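Each `*Impl` class above follows the same lazy-emission pattern: render a C++ `using` alias from node metadata with an f-string and cache the result in `_type_decl`. A minimal standalone sketch of that pattern (the class and type names below are illustrative, not part of the API):

```python
class TypeDeclEmitter:
    """Minimal sketch of the lazy f-string emission pattern used above."""

    def __init__(self, name: str, element_tag: str, stride: str) -> None:
        self.name = name
        self.element_tag = element_tag
        self.stride = stride
        self._type_decl = None  # cached after the first render

    @property
    def type_decl(self) -> str:
        # Render the C++ alias once, then return the cached string
        if self._type_decl is not None:
            return self._type_decl
        self._type_decl = (
            f"using {self.name} = cutlass::epilogue::fusion::Sm90AuxLoad<\n"
            f"  {self.element_tag}, {self.stride}\n"
            f">;\n"
        )
        return self._type_decl


emitter = TypeDeclEmitter("AuxLoad0", "cutlass::half_t", "Stride<_1, _0, _0>")
print(emitter.type_decl)
```

The caching matters because `type_decl` may be queried several times while assembling the full epilogue declaration.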


@@ -0,0 +1,88 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Epilogue Visitor interface for compiling and running visitor-based epilogues.
"""
import cutlass_api.fusion.backend
from cutlass_api.fusion.passes.util import cc_map
class EpilogueFunctorVisitor:
"""
Apply an epilogue functor described by the epilogue EVT
:param cc: compute capability
    :param visitor: the user-provided visitor frontend
"""
def __init__(self, cc: int, visitor, element_compute) -> None:
# Type of Emitter based on CC
self.emit_cls = getattr(
cutlass_api.fusion.backend, f"Sm{cc_map[cc]}Emitter"
)
# Visitor Types
self.visitor = visitor
self.graph = visitor.dag_ir
# Data types
self.element_epilogue = element_compute # element compute
self.element_output = self.graph.get_node_meta("D").underlying_impl.element
# Epilogue Thread Type
if cc_map[cc] in [90, 100]:
self.arg_c_type = self.visitor.arg_c_type
self.arg_d_type = self.visitor.arg_d_type
# Epilogue stages specialized for sm80 kernel
if cc == 80:
if hasattr(self.visitor, "epilogue_stages"):
self.epilogue_stages = self.visitor.epilogue_stages
                assert self.epilogue_stages <= 2, (
                    "The SM80 epilogue supports at most 2 stages"
                )
def emit(self, operation):
"""
Emit the C++ code
"""
emitter = self.emit_cls(operation, self.graph)
return emitter.emit()
def get_smem_size(self, tile_description):
"""
Get the shared memory size in bytes
"""
return self.visitor.get_smem_size(tile_description)


@@ -0,0 +1,120 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Collection of builtin functions used for host reference in EVT
"""
from cutlass_api.utils import (
is_numpy_available,
is_numpy_tensor,
is_torch_available,
is_torch_tensor,
)
if is_numpy_available():
import numpy as np
if is_torch_available():
import torch
def multiply_add(x, y, z):
return x * y + z
def sum(x, dim):
if is_numpy_tensor(x):
return x.sum(axis=tuple(dim))
elif is_torch_tensor(x):
return torch.sum(x, dim)
else:
raise TypeError(f"Unsupported tensor type: {type(x)}")
def max(x, dim):
if is_numpy_tensor(x):
return x.max(axis=tuple(dim))
elif is_torch_tensor(x):
return torch.amax(x, dim)
else:
raise TypeError(f"Unsupported tensor type: {type(x)}")
def maximum(x, y):
if is_numpy_tensor(x):
return np.maximum(x, y)
elif is_torch_tensor(x):
        return torch.maximum(x, torch.as_tensor(y, dtype=x.dtype, device=x.device))
else:
raise TypeError(f"Unsupported tensor type: {type(x)}")
def minimum(x, y):
if is_numpy_tensor(x):
return np.minimum(x, y)
elif is_torch_tensor(x):
        return torch.minimum(x, torch.as_tensor(y, dtype=x.dtype, device=x.device))
else:
raise TypeError(f"Unsupported tensor type: {type(x)}")
def exp(x):
if is_numpy_tensor(x):
return np.exp(x)
elif is_torch_tensor(x):
return torch.exp(x)
else:
raise TypeError(f"Unsupported tensor type: {type(x)}")
##############################################################################
# Layout manipulation nodes
##############################################################################
def permute(x, indices: tuple):
if is_numpy_tensor(x):
return np.transpose(x, axes=indices)
elif is_torch_tensor(x):
return x.permute(*indices)
else:
raise TypeError(f"Unsupported tensor type: {type(x)}")
def reshape(x, new_shape: tuple):
    if is_numpy_tensor(x):
        return np.reshape(x, new_shape)
    elif is_torch_tensor(x):
        # reshape (rather than view) also handles non-contiguous tensors
        return x.reshape(new_shape)
    else:
        raise TypeError(f"Unsupported tensor type: {type(x)}")
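To illustrate how these host-reference helpers behave, here is a NumPy-only sketch that inlines the NumPy branches of `multiply_add` and `sum` above (the torch paths are omitted so it runs without cutlass_api; the helper names mirror the ones in this module):

```python
import numpy as np

# NumPy-only versions of the reference ops above (torch branches omitted)
def multiply_add(x, y, z):
    return x * y + z

def reduce_sum(x, dim):
    return x.sum(axis=tuple(dim))

x = np.arange(6, dtype=np.float32).reshape(2, 3)  # [[0,1,2],[3,4,5]]

# Elementwise fused multiply-add: x * 2 + 1
fma = multiply_add(x, 2.0, 1.0)

# Row reduction: sum over the last axis
row_sums = reduce_sum(x, dim=[1])
print(fma)       # [[1. 3. 5.] [7. 9. 11.]]
print(row_sums)  # [ 3. 12.]
```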


@@ -0,0 +1,41 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.frontend.python_ast import (
PythonASTFrontend,
PythonASTInOutProcessor,
)
__all__ = [
"PythonASTFrontend",
"PythonASTInOutProcessor",
]


@@ -0,0 +1,300 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base class for Python EVT Frontend
"""
from typing import Union
from cutlass_api.fusion.library import DataType
from cutlass_api.fusion.ir import (
ComputeNode,
DAGIR,
LayoutNode,
LoadNode,
StoreNode,
)
from cutlass_api.fusion.passes import (
EVTGraphDrawer,
EVTPassManager,
GetSmemSize,
PassDAG2Tree,
PassGetArgumentType,
PassGetImpl,
PassFixElementD,
PassLayoutManipulateElimination,
PassPreprocessRed,
PassShapeTypePropagation,
)
from cutlass_api.fusion.passes.util import cc_map
from cutlass_api.fusion.evt_ops import permute, reshape
class EVTFrontendBase:
layout_fns = {"permute": permute, "reshape": reshape}
def __init__(
self,
cc,
element_compute=DataType.f32,
additional_passes=None,
requires_conversion_to_tree=True,
**kwargs,
) -> None:
self.cc = cc
self.element_compute = element_compute
self.dag_ir = DAGIR(self.cc, self.element_compute)
self.compute_cnt = 0
self.layout_cnt = 0
self.imm_cnt = 0
if additional_passes is None:
additional_passes = []
passes = [
PassPreprocessRed,
PassGetArgumentType,
PassShapeTypePropagation,
PassLayoutManipulateElimination,
PassGetImpl,
PassDAG2Tree,
PassFixElementD,
] + additional_passes
soft_dependencies = []
if not requires_conversion_to_tree:
passes.remove(PassDAG2Tree)
soft_dependencies.append(PassDAG2Tree)
self.pass_manager = EVTPassManager(
self.dag_ir, passes, soft_dependencies=soft_dependencies
)
if self.cc == 80:
self._epilogue_stages = 1
else:
self._epilogue_stages = None
@property
def epilogue_stages(self):
return self._epilogue_stages
@epilogue_stages.setter
def epilogue_stages(self, stages):
self._epilogue_stages = stages
def parse(self, *args, **kwargs):
raise NotImplementedError(
            "The 'parse' method must be overridden by the frontend subclass"
)
def trace(self, *args, **kwargs):
# Parse the input
self.parse(*args, **kwargs)
# Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
if self.cc >= 90:
if self.dag_ir.out_degree("D") != 0:
raise RuntimeError(
                    f"On SM90 or higher, D is expected to be an output node with 0 users to "
f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}"
)
# Run the passes
self.pass_manager()
# Set the epilogue type
self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
if cc_map[self.cc] in [90, 100]:
self.arg_c_type = self.dag_ir.arg_c_type
self.arg_d_type = self.dag_ir.arg_d_type
self.reduction_names = self.dag_ir.reduction_names
#
# Helper functions for DAG IR manipulation
#
def add_node(self, node):
self.dag_ir.add_node(node)
def add_edge(self, src, tgt, weight=0):
self.dag_ir.add_edge(src, tgt, weight=weight)
def set_tensor(self, node_name, example):
"""
Add an example tensor to node {node_name} in the DAG IR
"""
meta = self.dag_ir.get_node_meta(node_name)
meta.tensor = {"tensor": example}
def set_store_tensor(self, node_name, example):
"""
        Add an example store tensor to node {node_name} in the DAG IR
"""
meta = self.dag_ir.get_node_meta(node_name)
meta.store_tensor = {"tensor": example}
def mark_output(self, node_name):
"""
Mark a store node as output
"""
meta = self.dag_ir.get_node_meta(node_name)
# Allow accum to also be used as an output so that
# a user can include it in the return line
if not isinstance(meta, StoreNode) and not meta.name == "accum":
raise ValueError(
f"Only StoreNodes can be marked as output. "
f"Got {type(meta).__name__}: {node_name}"
)
meta.is_output = True
# Add node with specific type
def add_load_node(self, name, example):
"""
Add a Load node to DAG IR
:param name: name of the loaded variable
:type name: str
:param example: example input
:type example: np.ndarray|torch.Tensor|cupy.ndarray|float
"""
if name is None:
raise ValueError("Name is not provided.")
if example is None:
raise ValueError(f"Example input for {name} is not provided.")
load_node = LoadNode(name)
load_node.tensor = {"tensor": example}
# Special logics for accumulator
if name == "accum":
if load_node.tensor.rank == 2:
                new_shape = (1, *load_node.tensor.shape)
load_node.tensor.broadcast(new_shape)
elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3:
raise ValueError(
f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}."
)
self.add_node(load_node)
def add_imm(self, value: Union[float, int]):
"""
Add an immediate scalar value to DAG IR
:param value: the value of the immediate scalar
:type value: float
"""
try:
value = float(value)
except Exception as e:
raise ValueError(
f"{type(value).__name__} cannot be converted to float."
) from e
name = f"imm_{value}_k{self.imm_cnt}".replace(".", "_")
self.imm_cnt += 1
load_node = LoadNode(name)
load_node.tensor = {"tensor": value, "is_constant": True}
self.add_node(load_node)
return name
def add_compute_node(self, op, name=None):
"""
Add a compute node.
:param op: the computation op
:param name: the node name (optional)
:type name: str
:return: the name of the compute node
"""
if name is None:
name = f"compute_{self.compute_cnt}"
self.compute_cnt += 1
compute_node = ComputeNode(
name=name,
fn=op,
element_output=self.element_compute,
element_compute=self.element_compute,
)
self.add_node(compute_node)
return compute_node.name
def add_layout_node(self, op, kwargs, name=None):
"""
Add a layout node.
:param op: the layout op
:type op: evt_ops
:param name: the node name (optional)
:type name: str
:return: the name of the layout node
"""
if name is None:
name = f"layout_{self.layout_cnt}"
self.layout_cnt += 1
layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
self.add_node(layout_node)
return layout_node.name
def add_store_node(self, name):
store_node = StoreNode(name)
self.add_node(store_node)
#
    # Visualize the DAG IR
#
def visualize(self, name="dag_ir"):
"""
Visualize the dag ir with svg file
:param name: the name of the graph
"""
drawer = EVTGraphDrawer(self.dag_ir, name)
try:
for name, graph in drawer.get_dot_graph():
graph.write_svg(f"./{name}.svg")
        except Exception as e:
            raise RuntimeError(
                "'dot' was not found in PATH; graph drawing requires Graphviz. "
                "Install it, e.g. with 'sudo apt-get install graphviz'."
            ) from e
#
# Get shared memory size
#
def get_smem_size(self, tile_description):
"""
Get the shared memory size of the epilogue
"""
smem_size = GetSmemSize(self.dag_ir)(tile_description)
return smem_size
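The frontend above ultimately builds a DAG IR in which load nodes feed compute nodes and terminate in store nodes. A toy, dependency-free sketch of that structure and its topological evaluation order (the real `DAGIR` is considerably richer; `MiniDAG` here is purely illustrative):

```python
from collections import defaultdict, deque

class MiniDAG:
    """Toy stand-in for the DAG IR: loads -> computes -> stores."""

    def __init__(self):
        self.edges = defaultdict(list)  # src -> [(dst, weight)]
        self.in_deg = defaultdict(int)
        self.nodes = set()

    def add_edge(self, src, dst, weight=0):
        # Edge weights order a compute node's input arguments
        self.nodes.update((src, dst))
        self.edges[src].append((dst, weight))
        self.in_deg[dst] += 1

    def topo_order(self):
        # Kahn's algorithm: emit nodes whose inputs are all satisfied
        deg = {n: self.in_deg[n] for n in self.nodes}
        ready = deque(n for n in self.nodes if deg[n] == 0)
        order = []
        while ready:
            n = ready.popleft()
            order.append(n)
            for dst, _ in self.edges[n]:
                deg[dst] -= 1
                if deg[dst] == 0:
                    ready.append(dst)
        return order

# D = relu(accum * alpha)
dag = MiniDAG()
dag.add_edge("accum", "mul", weight=0)
dag.add_edge("alpha", "mul", weight=1)
dag.add_edge("mul", "relu")
dag.add_edge("relu", "D")
print(dag.topo_order())
```

The store node `D` is always last in the topological order, which is the invariant the SM90 check in `trace` enforces (D must have out-degree 0).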


@@ -0,0 +1,262 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Python AST frontend that parses input into DAG IR
"""
import ast
import inspect
import textwrap
from typing import Callable, Union
from cutlass_api.fusion.activation import (
identity,
relu,
tanh,
sigmoid,
silu,
hardswish,
gelu,
)
from cutlass_api.fusion.frontend.frontend_base import EVTFrontendBase
from cutlass_api.fusion.library import DataType, FunctionalOp
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
def __init__(self, cc, element_compute=DataType.f32, **kwargs):
super().__init__(cc, element_compute, **kwargs)
# Flags
# If this state is True, visit_Constant returns values without creating imm node
self.no_imm = False
self.visiting_return = False
def parse(self, example_inputs):
self.example_inputs = example_inputs
self.source = textwrap.dedent(inspect.getsource(self.__call__))
self.ast = ast.parse(self.source)
self.visit(self.ast)
#
# Helper functions
#
@staticmethod
def ast_op_to_bindings(op):
mapping = {
ast.Add: FunctionalOp.Plus,
ast.Sub: FunctionalOp.Minus,
ast.Mult: FunctionalOp.Multiplies,
ast.Div: FunctionalOp.Divides,
"maximum": FunctionalOp.Maximum,
"minimum": FunctionalOp.Minimum,
"identity": identity.binding_type,
"relu": relu.binding_type,
"tanh": tanh.binding_type,
"sigmoid": sigmoid.binding_type,
"silu": silu.binding_type,
"hardswish": hardswish.binding_type,
"gelu": gelu.binding_type,
"multiply_add": FunctionalOp.MultiplyAdd,
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
"exp": FunctionalOp.Exp,
}
return mapping[op]
#
# Visiting different node types
#
def visit_FunctionDef(self, node: ast.FunctionDef):
# Visit args and register load nodes
for arg in node.args.args:
self.visit(arg)
for expr in node.body:
self.visit(expr)
def visit_arg(self, node: ast.arg):
# Name of the argument
name = node.arg
try:
example_tensor = self.example_inputs[name]
except Exception as e:
raise RuntimeError(f"Example input for {name} is not provided.") from e
self.add_load_node(name, example_tensor)
def visit_Name(self, node: ast.Name):
return node.id
def visit_Constant(self, node: ast.Constant):
if self.no_imm:
return node.value
else:
name = self.add_imm(node.value)
return name
def visit_Tuple(self, node: ast.Tuple):
results = []
for elt in node.elts:
results.append(self.visit(elt))
return tuple(results)
def visit_keyword(self, node: ast.keyword):
return {node.arg: self.visit(node.value)}
def visit_BinOp(self, node: ast.BinOp):
if self.visiting_return:
raise SyntaxError("Return value cannot be an expression")
lhs = self.visit(node.left)
rhs = self.visit(node.right)
op = self.ast_op_to_bindings(type(node.op))
name = self.add_compute_node(op)
# Add edges
# The edge weights are used to sort the input args
self.add_edge(lhs, name, weight=0)
self.add_edge(rhs, name, weight=1)
return name
def visit_Assign(self, node: ast.Assign):
target = self.visit(node.targets[0])
value = self.visit(node.value)
# Create the assign node
self.add_store_node(target)
# Add edges
self.add_edge(value, target)
return target
def visit_Call(self, node: ast.Call):
if self.visiting_return:
raise SyntaxError("Return value cannot be an expression")
func = self.visit(node.func)
args = [self.visit(arg) for arg in node.args]
if func in self.layout_fns.keys():
# Parse kwargs
# By default, visiting imm automatically creates a load node
# However, in function call, keyword args are used to set
# specific function attributes such as indices for permute
# So no_imm is set to True temporarily
self.no_imm = True
kwargs = {}
for kw in node.keywords:
kwargs.update(self.visit(kw))
self.no_imm = False
op = self.layout_fns[func]
name = self.add_layout_node(op, kwargs)
else:
op = self.ast_op_to_bindings(func)
name = self.add_compute_node(op)
# Add edges
for idx, arg in enumerate(args):
self.add_edge(arg, name, weight=idx)
return name
def visit_Return(self, node: ast.Return):
self.visiting_return = True
results = self.visit(node.value)
self.visiting_return = False
self.return_names = results
if not isinstance(results, tuple):
results = (results,)
for rst in results:
try:
example_tensor = self.example_inputs[rst]
except Exception as e:
raise RuntimeError(f"Example input for {rst} is not provided.") from e
self.set_store_tensor(rst, example_tensor)
self.mark_output(rst)
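The visitor methods above can be illustrated with a much smaller standalone sketch (the names `expr_to_dag` and `op_N` are illustrative, not part of this API): it parses a single assignment and emits nodes plus weight-annotated edges, mirroring how `visit_BinOp` and `visit_Assign` record operand order.

```python
import ast

def expr_to_dag(src):
    """Compile one assignment statement into (nodes, weighted edges): each
    BinOp becomes a compute node, and edge weights record operand order."""
    stmt = ast.parse(src).body[0]
    nodes, edges = [], []
    counter = [0]

    def visit(node):
        if isinstance(node, ast.Name):
            if node.id not in nodes:
                nodes.append(node.id)
            return node.id
        if isinstance(node, ast.BinOp):
            name = f"op_{counter[0]}"
            counter[0] += 1
            nodes.append(name)
            edges.append((visit(node.left), name, 0))   # weight 0: left operand
            edges.append((visit(node.right), name, 1))  # weight 1: right operand
            return name
        raise NotImplementedError(type(node).__name__)

    target = stmt.targets[0].id
    edges.append((visit(stmt.value), target, 0))
    nodes.append(target)
    return nodes, edges
```

For `D = alpha * accum + C`, the multiply becomes one compute node feeding the add, and the final edge stores the result into `D`.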
class PythonASTInOutProcessor(ast.NodeVisitor):
"""
Simple processor that traces the input AST and returns the names of inputs
and outputs
:param source: The source code of the function to trace.
:type source: Union[Callable, str]
:return: A tuple of lists of input and output names.
:rtype: tuple[list[str], list[str]]
"""
def trace(self, source: Union[Callable, str]) -> tuple[list[str], list[str]]:
if callable(source):
self.source = textwrap.dedent(inspect.getsource(source))
elif isinstance(source, str):
self.source = source
else:
raise TypeError("source must be a callable or a string of source code")
self.ast = ast.parse(self.source)
self.inputs = []
self.outputs = []
self.visiting_return = False
self.visit(self.ast)
if "accum" not in self.inputs:
raise ValueError("accum must be an input to the epilogue function")
# Users do not need to specify accum. Remove it.
self.inputs.remove("accum")
return self.inputs, self.outputs
def visit_FunctionDef(self, node: ast.FunctionDef):
for arg in node.args.args:
self.visit(arg)
for expr in node.body:
self.visit(expr)
def visit_arg(self, node: ast.arg):
self.inputs.append(node.arg)
def visit_Constant(self, node: ast.Constant):
if self.visiting_return:
raise SyntaxError("Return value cannot be a constant")
def visit_Name(self, node: ast.Name):
return node.id
def visit_Tuple(self, node: ast.Tuple):
results = []
for elt in node.elts:
results.append(self.visit(elt))
return tuple(results)
def visit_Return(self, node: ast.Return):
self.visiting_return = True
results = self.visit(node.value)
self.visiting_return = False
if not isinstance(results, tuple):
results = (results,)
self.outputs.extend(results)
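The core of this in/out tracing can be sketched standalone (a simplified illustration, not the class above: it handles only plain-name returns):

```python
import ast
import textwrap

SRC = """
def epilogue(accum, alpha, C):
    D = alpha * accum + C
    return D
"""

def trace_io(source: str):
    """Return (argument names, returned variable names) of a function's source."""
    fdef = ast.parse(textwrap.dedent(source)).body[0]
    inputs = [a.arg for a in fdef.args.args]
    outputs = []
    for node in ast.walk(fdef):
        if isinstance(node, ast.Return):
            # A tuple return yields several outputs; a bare name yields one
            elts = node.value.elts if isinstance(node.value, ast.Tuple) else [node.value]
            outputs.extend(e.id for e in elts if isinstance(e, ast.Name))
    return inputs, outputs
```

Here `trace_io(SRC)` reports inputs `["accum", "alpha", "C"]` and output `["D"]`; the real processor then strips `accum`, since users never supply it.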



@@ -0,0 +1,75 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.ir.compute_nodes import ComputeNode, ComputeImpl
from cutlass_api.fusion.ir.dag_ir import DAGIR
from cutlass_api.fusion.ir.layout_nodes import LayoutNode
from cutlass_api.fusion.ir.load_nodes import (
LoadNode,
AccumulatorImpl,
LoadSrcImpl,
AuxLoadImpl,
RowBroadcastImpl,
ColumnBroadcastImpl,
ScalarBroadcastImpl,
)
from cutlass_api.fusion.ir.node import TopoVisitorNode, NoOpImpl
from cutlass_api.fusion.ir.store_nodes import (
StoreNode,
StoreDImpl,
AuxStoreImpl,
ColumnReductionImpl,
RowReductionImpl,
ScalarReductionImpl,
)
__all__ = [
"ComputeNode",
"ComputeImpl",
"DAGIR",
"LayoutNode",
"LoadNode",
"AccumulatorImpl",
"LoadSrcImpl",
"AuxLoadImpl",
"RowBroadcastImpl",
"ColumnBroadcastImpl",
"ScalarBroadcastImpl",
"TopoVisitorNode",
"NoOpImpl",
"StoreNode",
"StoreDImpl",
"AuxStoreImpl",
"ColumnReductionImpl",
"RowReductionImpl",
"ScalarReductionImpl",
]


@@ -0,0 +1,246 @@
import ctypes
from cutlass_api.fusion.library import DataType
from cutlass_api.utils import is_numpy_tensor, is_torch_tensor, is_numpy_available
dtype2ctype = {
DataType.f16: ctypes.c_uint16,
DataType.bf16: ctypes.c_uint16,
DataType.f32: ctypes.c_float,
DataType.f64: ctypes.c_double,
DataType.s8: ctypes.c_int8,
DataType.s32: ctypes.c_int32,
}
def get_scalar(value):
"""
Returns a scalar value from a container (e.g., np.ndarray)
"""
if is_numpy_tensor(value) or is_torch_tensor(value):
# torch tensors report their element count via numel(); numpy arrays via size
num_elements = value.numel() if is_torch_tensor(value) else value.size
if num_elements != 1:
raise ValueError("Scalars used in epilogue must be of size 1")
return value.reshape(-1)[0]
else:
return value
def to_ctype_value(value, dtype):
"""
Converts ``value`` to the corresponding storage needed for the ctype that
will store ``value``.
"""
scalar = get_scalar(value)
if dtype == DataType.f16:
if is_numpy_available():
import numpy as np
# Convert f16 value into an integer
return int.from_bytes(np.float16(scalar).tobytes(), "little")
else:
raise NotImplementedError("Numpy is not available")
else:
return scalar
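As a sanity check on the f16 path, the same bit pattern can be produced without numpy via `struct`'s binary16 format code `'e'` (a standalone sketch, not part of this module's API):

```python
import struct

def f16_bits(value: float) -> int:
    """Pack a Python float into IEEE-754 binary16 and return the raw 16-bit
    integer pattern, matching the numpy-based conversion above."""
    return int.from_bytes(struct.pack("<e", value), "little")

# 1.0 in binary16 is 0x3C00: sign 0, exponent 01111, mantissa 0
```

This integer is what gets stored in the `ctypes.c_uint16` slot that `dtype2ctype` assigns to `DataType.f16`.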
class Empty(ctypes.Structure):
_fields_ = []
def __init__(self, *arg) -> None:
pass
class EmptyByte(ctypes.Structure):
_fields_ = [("byte", ctypes.c_byte)]
def __init__(self, *arg) -> None:
pass
class EBO:
def __init__(self, index: int, type) -> None:
self.index = index
self.type = type
def __eq__(self, other) -> bool:
if isinstance(other, EBO):
return self.index == other.index and self.type == other.type
return False
def __hash__(self) -> int:
return hash((self.index, self.type))
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self) -> str:
return f"<{self.index}, {self.type}>"
def tuple_factory_(input_tuple, dtype, constants: list[int] = None):
"""
The factory function generating a cute::Tuple from the input tuple
:param input_tuple: the input tuple
:type input_tuple: tuple
:param dtype: the ctypes type used for non-constant values (e.g., ctypes.c_int32)
:param constants: the values that will be treated as compile-time constants
:type constants: list[int]
:return: ctype structure representing the cute::Tuple
:return: the empty base classes of the tuple
"""
if constants is None:
constants = [0, 1]
# The empty base classes of the current tuple
empty_bases = []
# The first non empty base class
first_non_empty_base = None
# The ctype fields of the current tuple
ctype_fields = []
for idx, entry in enumerate(input_tuple):
# For nested tuples
if isinstance(entry, tuple):
sub_tuple_ctype, sub_empty_bases = tuple_factory_(entry, dtype, constants)
if ctypes.sizeof(sub_tuple_ctype) == 0:
# The empty tuple base class is also an empty EBO
empty_bases.append(EBO(idx, entry))
else:
if first_non_empty_base is None:
first_non_empty_base = sub_empty_bases
ctype_fields.append((f"entry_{idx}", sub_tuple_ctype))
else:
if entry in constants:
empty_bases.append(EBO(idx, entry))
ctype_fields.append((f"entry_{idx}", Empty))
else:
ctype_fields.append((f"entry_{idx}", dtype))
if first_non_empty_base is None:
first_non_empty_base = []
# Create the ctype tuple
class TupleType(ctypes.Structure):
_fields_ = ctype_fields
def __init__(self, args) -> None:
fields = self._fields_
assert len(fields) == len(args)
for field, arg in zip(fields, args):
name = field[0]
field_type = field[1]
setattr(self, name, field_type(arg))
return TupleType, empty_bases
def tuple_factory(input_tuple, dtype: str, constants=(0, 1)):
"""
The factory function generating a cute::Tuple from the input tuple
:param input_tuple: the input tuple
:type input_tuple: tuple
:param dtype: the data type for non-constant values
:type dtype: str, "int32_t", "int", "int64_t"
:param constants: the values that will be treated as compile-time constants
:type constants: list[int]
:return: ctype structure representing the cute::Tuple
"""
# Step 1: convert the dtype
if dtype == "int64_t":
dtype = ctypes.c_longlong
elif dtype in ["int", "int32_t"]:
dtype = ctypes.c_int32
else:
raise NotImplementedError(f"Type {dtype} is not supported")
tuple_type, _ = tuple_factory_(input_tuple, dtype, constants)
if ctypes.sizeof(tuple_type) == 0:
return EmptyByte
return tuple_type
def visitor_factory(node_types, node_names):
"""
Creates the argument type of epilogue visitor type
:param node_types: list of argument types under ctypes
:param node_names: list of argument names under str
:return: tuple type in ctypes.Structure
"""
ctypes_field = []
# A struct is used when the number of nodes is <= 4,
# because Sm90VisitorImplBase has specializations for up to 4 nodes
# in `include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp`
if len(node_types) <= 4:
for idx, node_type in enumerate(node_types):
if ctypes.sizeof(node_type) == 0:
# Special case for empty struct
# 1 byte placeholder is used for correct alignment
ctypes_field.append((node_names[idx], ctypes.c_byte))
else:
ctypes_field.append((node_names[idx], node_type))
class VisitorType(ctypes.Structure):
_fields_ = ctypes_field
def __init__(self, kwargs) -> None:
for field in self._fields_:
fname, ftype = field
if ftype != ctypes.c_byte:
setattr(self, fname, ftype(kwargs))
# For cases with more than 4 nodes, tuple is used
else:
for idx, node_type in enumerate(node_types):
ctypes_field.append((node_names[idx], node_type))
class VisitorType(ctypes.Structure):
_fields_ = ctypes_field
def __init__(self, kwargs) -> None:
for field in self._fields_:
fname, ftype = field
setattr(self, fname, ftype(kwargs))
return VisitorType


@@ -0,0 +1,99 @@
"""
Python registration for compute nodes in EVT
"""
from cutlass_api.fusion.ir.node import NodeBase, ImplBase
from cutlass_api.fusion.library import FloatRoundStyle
class ComputeImplBase(ImplBase):
"""
Base class for compute implementation
"""
def __init__(self, node) -> None:
super().__init__(node)
class ComputeImpl(ComputeImplBase):
"""
Implementation for Compute Node
"""
def __init__(self, node) -> None:
super().__init__(node)
self.fn = node.fn
self.element_output = node.element_output
self.element_compute = node.element_compute
self.round_style = node.round_style
@staticmethod
def match(node, problem_size: tuple):
return True
class ComputeNode(NodeBase):
"""
Compute Node in DAG IR
"""
possible_impls = [
ComputeImpl,
]
def __init__(
self,
name: str,
fn,
element_output,
element_compute,
round_style=FloatRoundStyle.ToNearest,
) -> None:
super().__init__(name)
self.op = "compute"
self.fn = fn
self.element_compute = element_compute
self.round_style = round_style
def type_propagation(self, *args, **kwargs):
"""
Compute nodes perform arithmetic in `element_compute`; by default the output element type also equals `element_compute`.
"""
self.element = self.element_compute
# In general, the compute nodes have element_output = element_compute
# In certain cases like producer of D it is overwritten by other passes
if not hasattr(self, "element_output"):
self.element_output = self.element


@@ -0,0 +1,258 @@
"""
DAG IR used by Python EVT
"""
import networkx as nx
from cutlass_api.fusion.ir.compute_nodes import ComputeNode
from cutlass_api.fusion.library import DataType
from cutlass_api.fusion.ir.node import NodeBase
from cutlass_api.fusion.library import ActivationOp
class DAGIR:
"""
``DAGIR`` is the main data structure of the EVT intermediate representation.
It consists of a series of ``Node`` objects, each representing an epilogue visitor node.
In the DAG IR, a ``node`` is the string name of a node; ``node_meta`` is the underlying node object.
"""
def __init__(self, cc, element_compute=DataType.f32) -> None:
# The EVT DAG IR is managed through the networkx DiGraph class
self._graph = nx.DiGraph()
self.element_compute = element_compute
self.reduction_names = []
self.cc = cc
self.identity_counter = 0
#
# IR manipulator
#
def add_node(self, meta: NodeBase):
"""
Add a node to dag ir
"""
if self.has_node(meta.name):
raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.")
self._graph.add_node(meta.name, meta=meta)
def add_edge(self, src: str, dst: str, weight: int = 0):
"""
Add an edge src -> dst to dag ir with weight
"""
if not self.has_node(src):
raise SyntaxError(f"Variable '{src}' is undefined.")
if not self.has_node(dst):
raise SyntaxError(f"Variable '{dst}' is undefined.")
if self._graph.has_edge(src, dst):
# The DiGraph doesn't support multiple edges between two nodes
# We insert an identity node in such case as a workaround
identity_name = f"autogen_identity_{self.identity_counter}"
self.identity_counter += 1
compute_node = ComputeNode(
name=identity_name,
fn=ActivationOp.Identity,
element_output=self.element_compute,
element_compute=self.element_compute,
)
self.add_node(compute_node)
self.add_edge(src, identity_name, 0)
self.add_edge(identity_name, dst, weight)
else:
self._graph.add_edge(src, dst, weight=weight)
def remove_node(self, node: str):
"""
Remove node from dag ir
"""
self._graph.remove_node(node)
def remove_edge(self, src: str, dst: str):
"""
Remove edge src -> dst
"""
self._graph.remove_edge(src, dst)
#
# Helper functions for getting attrs
#
def has_node(self, node: str) -> bool:
"""
Check if the node is in the graph
"""
return self._graph.has_node(node)
def in_degree(self, node: str):
"""
Get the input degree of node
"""
return self._graph.in_degree(node)
def in_edges(self, node: str):
"""
Get the input edges of node
"""
return list(self._graph.in_edges(node))
def out_degree(self, node: str):
"""
Get the output degree of node
"""
return self._graph.out_degree(node)
def out_edges(self, node: str):
"""
Get the output edges of node
"""
return list(self._graph.out_edges(node))
def get_node_meta(self, node: str):
"""
Get the meta data of the node
"""
return self._graph.nodes[node]["meta"]
def get_edge_weight(self, src, dst):
"""
Get the edge weight of edge src->dst
"""
return self._graph.get_edge_data(src, dst)["weight"]
#
# High-level helper functions
#
def all_reachable_nodes(self, node: str):
"""
Get all nodes reachable from ``node`` in DFS preorder, including ``node`` itself
"""
return list(nx.dfs_preorder_nodes(self._graph, source=node))
def get_users(self, node: str):
"""
Get all users of the current node
"""
return [edge[1] for edge in self.out_edges(node)]
def get_all_inputs(self, node: str):
"""
Get all the input nodes sorted by edge weight
"""
in_edges = self.in_edges(node)
edge_weights = [self.get_edge_weight(*edge) for edge in in_edges]
return [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
def get_all_inputs_meta(self, node: str):
"""
Get all the input node metas sorted by edge weight
"""
return [
self.get_node_meta(input_node) for input_node in self.get_all_inputs(node)
]
def replace_all_uses_with(self, node1, node2):
"""
Replace all uses of node1 with node2
"""
for edge in self.out_edges(node1):
weight = self.get_edge_weight(*edge)
user = edge[1]
self.add_edge(node2, user, weight)
self.remove_edge(node1, user)
self.remove_node(node1)
#
# Node accessor
#
def nodes_topological_order(self):
"""
Get the nodes in the unique lexicographical topological order
It generates a unique ordering of nodes by first sorting topologically
and then additionally by sorting lexicographically.
Although topological_sort alone also works, this generates a unique key
for each epilogue visitor pattern and ensures the compilation cache can be reused.
:return: list[str]
"""
return list(nx.lexicographical_topological_sort(self._graph))
def node_metas_topological_order(self):
"""
Get the node metas in topological order
:return: list[NodeBase]
"""
return [self.get_node_meta(node) for node in self.nodes_topological_order()]
@property
def nodes(self):
"""
Get all nodes
:return: list[str]
"""
return list(self._graph.nodes)
@property
def nodes_meta(self):
"""
Get all node metas
:return: list[NodeBase]
"""
return [data[1]["meta"] for data in self._graph.nodes.data()]
@property
def edges(self):
"""
Get all edges
:return: list[(str, str)]
"""
return list(self._graph.edges)
#
# Path
#
def has_path(self, src: str, target: str) -> bool:
"""
Return True if a path exists from src to target
"""
return nx.has_path(self._graph, src, target)
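Why lexicographical order matters for `nodes_topological_order` can be shown with a dependency-free sketch (the name `lex_topo_sort` is illustrative; the real implementation delegates to `nx.lexicographical_topological_sort`):

```python
import heapq

def lex_topo_sort(nodes, edges):
    """Topological sort that breaks ties lexicographically, so the same DAG
    always yields the same node order (and hence a stable compilation-cache key)."""
    indeg = {n: 0 for n in nodes}
    adj = {n: [] for n in nodes}
    for src, dst in edges:
        adj[src].append(dst)
        indeg[dst] += 1
    # Among all ready nodes, always emit the lexicographically smallest
    ready = [n for n, d in indeg.items() if d == 0]
    heapq.heapify(ready)
    order = []
    while ready:
        n = heapq.heappop(ready)
        order.append(n)
        for m in adj[n]:
            indeg[m] -= 1
            if indeg[m] == 0:
                heapq.heappush(ready, m)
    return order
```

A plain topological sort could emit `alpha` before or after `accum` depending on insertion order; the heap makes the choice deterministic regardless of how the graph was built.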


@@ -0,0 +1,362 @@
"""
Layout algebras
"""
from cutlass_api.fusion.pycute import (
Layout,
composition,
make_layout,
flatten,
product,
)
def _infer_split(old_shape, new_shape):
old_shape = _tuple_to_list(old_shape)
new_shape = _tuple_to_list(new_shape)
if len(old_shape) == 0 and len(new_shape) == 0:
return []
if len(old_shape) == 0:
if product(tuple(new_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return new_shape
if len(new_shape) == 0:
if product(tuple(old_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return old_shape
# This is done recursively by only process the last dimension at each time
old_dim = old_shape[-1]
new_dim = new_shape[-1]
# Exact match
if old_dim == new_dim:
return _infer_split(old_shape[:-1], new_shape[:-1]) + [
new_dim,
]
# Needs split
if old_dim > new_dim and old_dim % new_dim == 0:
residual = old_dim // new_dim
return _infer_split(
old_shape[:-1] + [residual],
new_shape[:-1],
) + [new_dim]
# Needs merge
if old_dim < new_dim and new_dim % old_dim == 0:
residual = new_dim // old_dim
return _infer_split(
old_shape[:-1],
new_shape[:-1] + [residual],
) + [old_dim]
raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
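A standalone mirror of this recursion (illustrative name `infer_split`) makes the trailing-dimension factoring concrete: both shapes are refined until their last dimensions agree, producing a common refinement such as `[6, 4] -> [2, 3, 4]`.

```python
def infer_split(old_shape, new_shape):
    """Factor trailing dimensions so old_shape and new_shape share a common
    refinement (mirrors _infer_split above, without nested-tuple handling)."""
    old, new = list(old_shape), list(new_shape)
    if not old:
        return new
    if not new:
        return old
    o, n = old[-1], new[-1]
    if o == n:                      # exact match: keep the dimension
        return infer_split(old[:-1], new[:-1]) + [n]
    if o > n and o % n == 0:        # split the old trailing dimension
        return infer_split(old[:-1] + [o // n], new[:-1]) + [n]
    if n > o and n % o == 0:        # split the new trailing dimension
        return infer_split(old[:-1], new[:-1] + [n // o]) + [o]
    raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
```

Note the result is symmetric: `[2, 3, 4]` is the refinement in both directions of the reshape.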
def _infer_merge(flatten_shape, shape):
flatten_shape = _tuple_to_list(flatten_shape)
shape = _tuple_to_list(shape)
idx_flat = 0
merged_shape = []
for dim in shape:
# Exact match
if dim == flatten_shape[idx_flat]:
merged_shape.append(dim)
idx_flat += 1
# Need group
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
residual = dim
group = []
while residual > 1:
group.append(flatten_shape[idx_flat])
residual = residual // flatten_shape[idx_flat]
idx_flat += 1
merged_shape.append(group)
else:
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
return merged_shape
def _list_to_tuple(nested_list):
if isinstance(nested_list, list) or isinstance(nested_list, tuple):
return tuple(_list_to_tuple(item) for item in nested_list)
return nested_list
def _tuple_to_list(nested_tuple):
if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple):
return list(_tuple_to_list(item) for item in nested_tuple)
return nested_tuple
def _reverse_tuple(nested_tuple: tuple):
if isinstance(nested_tuple, tuple):
return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
return nested_tuple
def _get_first_lhs_nonzero_stride(stride_list, idx):
# Return the first nonzero stride to the left of idx (None if there is none)
for i in reversed(range(idx)):
if stride_list[i] != 0:
return stride_list[i]
return None
def _get_first_rhs_nonzero_stride(stride_list, idx):
# Return the first nonzero stride to the right of idx (None if there is none)
for i in range(idx + 1, len(stride_list)):
if stride_list[i] != 0:
return stride_list[i]
return None
def reshape(layout, new_shape):
"""
General reshape of the input layout.
It takes two steps:
1. split the dimensions of the old layout
2. merge the split dimensions according to the new shape
"""
#
# Step 1: Split the dimensions of the old layout
#
# 1.1 Flatten the old and new shapes
old_flatten_shape = list(flatten(layout.shape))
new_flatten_shape = list(flatten(new_shape))
# 1.2 Infer the flattened split shape
splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)
# 1.3 Unflatten the split shape based on the old shape
splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)
# 1.4 Infer the type of each split
# If the split type is row-major (R), the dimension list is reversed because
# cute::composition only supports column-major splits
split_type = [] # the type of each split (ColumnMajor or RowMajor)
permuted_splitted_shape = []
old_flatten_stride = list(flatten(layout.stride))
for idx, dim in enumerate(splited_shape):
if not isinstance(dim, list):
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
# Special case for single tuple
# Use column-major by default
if lhs_stride is None and rhs_stride is None:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
if lhs_stride is not None and rhs_stride is not None:
# We consider shape[idx]:stride[idx]
# Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major
if (
lhs_stride <= old_flatten_stride[idx]
and old_flatten_stride[idx] <= rhs_stride
):
permuted_splitted_shape.append(dim)
split_type.append("C")
# Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major
elif (
lhs_stride > old_flatten_stride[idx]
and old_flatten_stride[idx] > rhs_stride
):
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
# Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave
elif (
lhs_stride <= old_flatten_stride[idx]
and old_flatten_stride[idx] > rhs_stride
):
if lhs_stride >= rhs_stride:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
# Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave
elif (
lhs_stride > old_flatten_stride[idx]
and old_flatten_stride[idx] <= rhs_stride
):
if lhs_stride >= rhs_stride:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
raise NotImplementedError()
elif lhs_stride is None:
# If this dim's stride exceeds the next nonzero stride, expand in row-major order
if old_flatten_stride[idx] > rhs_stride:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
permuted_splitted_shape.append(dim)
split_type.append("C")
else:
# If this dim's stride is smaller than the previous nonzero stride, expand in row-major order
if old_flatten_stride[idx] < lhs_stride:
permuted_splitted_shape.append([d for d in reversed(dim)])
split_type.append("R")
else:
permuted_splitted_shape.append(dim)
split_type.append("C")
# 1.5 Generate the split layout
permuted_splitted_layout = composition(
layout, Layout(_list_to_tuple(permuted_splitted_shape))
)
# 1.6 Reverse the permutation applied in step 1.4 before merging
splitted_shape = []
splitted_stride = []
for shape_dim, stride_dim, type in zip(
permuted_splitted_layout.shape, permuted_splitted_layout.stride, split_type
):
if type == "C":
splitted_shape.append(shape_dim)
splitted_stride.append(stride_dim)
else:
splitted_shape.append(tuple([d for d in reversed(shape_dim)]))
splitted_stride.append(tuple([d for d in reversed(stride_dim)]))
splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))
#
# Step 2: Merge the splitted dimensions according to the new shape
#
# 2.1 Merge layout
merged_layout = composition(splitted_layout, Layout(new_shape))
# 2.2 Cleaning up
output_layout = composition(merged_layout, Layout(new_shape))
return output_layout
def permutation(layout, permutation):
"""
Permute the layout
"""
new_shape = tuple([layout.shape[idx] for idx in permutation])
new_stride = tuple([layout.stride[idx] for idx in permutation])
return Layout(new_shape, new_stride)
def _broadcast(layout, new_shape):
if len(layout) == 1 and isinstance(new_shape, int):
old_dim = layout.shape
old_stride = layout.stride
new_dim = new_shape
if old_dim == new_dim:
return Layout(old_dim, old_stride)
elif old_dim == 1:
return Layout(new_dim, 0)
else:
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")
# Align the dimensions
old_shape = layout.shape
if isinstance(old_shape, int):
old_shape = (old_shape,)
sub_layouts = [
layout,
]
else:
sub_layouts = [sub_layout for sub_layout in layout]
rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))
# Get the broadcasted layout
broadcast_layouts = []
try:
layout = make_layout(*sub_layouts, *rhs_broadcast_layouts)
for idx, sub_layout in enumerate(layout):
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
except NotImplementedError:
# Retry with the broadcast dimensions prepended; discard any partial results
broadcast_layouts = []
layout = make_layout(*rhs_broadcast_layouts, *sub_layouts)
for idx, sub_layout in enumerate(layout):
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
return make_layout(*broadcast_layouts)
def broadcast(layout, new_shape):
"""
Broadcast the input layout to ``new_shape``.
The resulting shape equals ``new_shape``; broadcast dimensions get stride 0.
"""
return _broadcast(layout, new_shape)
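The per-dimension rule is the familiar one from array broadcasting, sketched here in plain tuples (the name `broadcast_strides` is illustrative; this uses numpy-style right alignment for intuition, whereas `_broadcast` above tries appending the padding dimensions on the right first and falls back to the left):

```python
def broadcast_strides(shape, strides, new_shape):
    """Per-dimension broadcast rule: an extent equal to the target keeps its
    stride; an extent of 1 broadcasts with stride 0; anything else is invalid."""
    pad = len(new_shape) - len(shape)
    shape = (1,) * pad + tuple(shape)
    strides = (0,) * pad + tuple(strides)
    out = []
    for extent, stride, target in zip(shape, strides, new_shape):
        if extent == target:
            out.append(stride)
        elif extent == 1:
            out.append(0)  # broadcast dimension: every index maps to offset 0
        else:
            raise NotImplementedError(f"Invalid broadcast: {extent} -> {target}")
    return tuple(out)
```

For example, broadcasting a length-3 vector with stride 1 into a `(4, 3)` shape yields strides `(0, 1)`: the new leading dimension repeats the same data.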
def debroadcast(layout, dims):
"""
Squeeze out the given 0-stride dimensions
"""
for dim in dims:
if layout.stride[dim] != 0:
raise ValueError(
f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}"
)
new_shape = tuple([s for idx, s in enumerate(layout.shape) if idx not in dims])
new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims])
return Layout(new_shape, new_stride)
def canonicalization_(shapes, strides):
if isinstance(shapes, tuple):
c_shapes = []
c_strides = []
for shape, stride in zip(shapes, strides):
c_shape, c_stride = canonicalization_(shape, stride)
c_shapes.append(c_shape)
c_strides.append(c_stride)
return tuple(c_shapes), tuple(c_strides)
else:
if shapes == 1:
return 1, 0
else:
return shapes, strides
def canonicalization(layout):
"""
Canonicalize the input layout:
1. set the stride of any extent-1 dimension to 0
"""
new_shape, new_stride = canonicalization_(layout.shape, layout.stride)
return Layout(new_shape, new_stride)
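A standalone sketch of the canonicalization rule (helper name hypothetical): any extent-1 dimension gets stride 0, recursing through nested hierarchical tuples:

```python
def canonicalize(shape, stride):
    """Set the stride of every extent-1 dimension to 0, recursively."""
    if isinstance(shape, tuple):
        pairs = [canonicalize(s, d) for s, d in zip(shape, stride)]
        return tuple(p[0] for p in pairs), tuple(p[1] for p in pairs)
    return (1, 0) if shape == 1 else (shape, stride)

# The inner extent-1 mode keeps its shape but its stride is zeroed out.
print(canonicalize((4, (1, 8)), (8, (99, 1))))  # ((4, (1, 8)), (8, (0, 1)))
```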


@@ -0,0 +1,351 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Layout manipulation nodes and implementations
The layout Nodes change the layout of intermediate nodes in epilogue visitor graph
"""
from copy import deepcopy
from cutlass_api.fusion.library import LayoutType
from cutlass_api.fusion.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
from cutlass_api.fusion.ir.node import NodeBase
from cutlass_api.fusion.ir.tensor import Tensor
from cutlass_api.fusion.pycute import product, flatten
class PermutationImpl:
"""
Detailed implementation and helper functions for permutation
"""
def __init__(self, node) -> None:
assert "indices" in node.kwargs.keys()
self.indices = list(node.kwargs["indices"])
self.inverse_indices = self.get_inverse_indices(self.indices)
def get_inverse_impl(self):
inverse_impl = deepcopy(self)
inverse_impl.indices = self.inverse_indices
inverse_impl.inverse_indices = self.indices
return inverse_impl
def update(self, shape):
num_dim = len(shape)
indices = self.indices
num_old_dim = len(indices)
# Add offset
for i, idx in enumerate(indices):
indices[i] = idx + num_dim - num_old_dim
# Add broadcast dims
for i in range(num_dim - num_old_dim):
indices = [i] + indices
self.indices = indices
self.inverse_indices = self.get_inverse_indices(self.indices)
def get_inverse_indices(self, indices):
"""
Get the indices for inverse permutation
"""
num_dim = len(indices)
inverse_indices = [0] * num_dim
for i in range(num_dim):
inverse_indices[indices[i]] = i
return inverse_indices
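The inverse-index computation can be exercised in isolation; this sketch (standalone, names hypothetical) shows that applying a permutation followed by its inverse restores the original order:

```python
def inverse_indices(indices):
    """Compute the index list of the inverse permutation."""
    inv = [0] * len(indices)
    for i, idx in enumerate(indices):
        inv[idx] = i
    return inv

perm = [2, 0, 1]
inv = inverse_indices(perm)          # [1, 2, 0]
data = ["a", "b", "c"]
permuted = [data[i] for i in perm]   # ['c', 'a', 'b']
restored = [permuted[i] for i in inv]
print(restored)  # ['a', 'b', 'c']
```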
def shape_propagation(self, input_node_meta):
input_shape = input_node_meta.tensor.shape
output_shape = tuple([input_shape[idx] for idx in self.indices])
return output_shape
def broadcast(self, shape, node_meta: NodeBase):
"""
Broadcast the inputs based on current shape
"""
self.update(shape)
inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
node_meta.tensor.broadcast(inverse_shape)
def apply_to_user(self, usr_meta: NodeBase):
"""
Propagate the permutation to the users of the current node
"""
usr_meta.tensor.permute(self.inverse_indices)
if hasattr(usr_meta, "store_tensor"):
if usr_meta.store_tensor is not None:
usr_meta.store_tensor.permute(self.inverse_indices)
def apply_to_input(self, input_meta: NodeBase):
"""
Propagate the permutation to inputs of the current nodes
"""
input_meta.tensor.permute(self.indices)
if hasattr(input_meta, "store_tensor"):
if input_meta.store_tensor is not None:
input_meta.store_tensor.permute(self.indices)
class ReshapeImpl:
"""
Detailed implementation and helper functions for reshape
"""
def __init__(self, node) -> None:
self.node = node
assert "new_shape" in node.kwargs.keys()
self.output_shape = _list_to_tuple(node.kwargs["new_shape"])
def get_inverse_impl(self):
inverse_impl = deepcopy(self)
inverse_impl.output_shape = self.input_shape
inverse_impl.input_shape = self.output_shape
return inverse_impl
def shape_propagation(self, input_node_meta):
self.input_shape = input_node_meta.tensor.shape
return _list_to_tuple(self.output_shape)
def broadcast(self, shape, node_meta: NodeBase):
"""
Broadcast the inputs based on current shape.
"""
# Step 1: infer split
flatten_split_shape = self.infer_split(
flatten(self.input_shape), flatten(self.output_shape)
)
split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)
# broadcast shape -> split_output_shape -> flatten_split_shape
if len(shape) - len(split_output_shape) > 0:
for _ in range(len(shape) - len(split_output_shape)):
split_output_shape = [1] + split_output_shape
flatten_split_shape = [1] + flatten_split_shape
split_input_shape = [1] + split_input_shape
broadcast_factor = []
for dim, old_dim in zip(shape, split_output_shape):
if not isinstance(dim, list):
dim = [dim]
if not isinstance(old_dim, list):
old_dim = [old_dim]
if product(tuple(dim)) == product(tuple(old_dim)):
broadcast_factor += [1] * len(old_dim)
elif product(tuple(old_dim)) == 1:
assert len(dim) == 1
broadcast_factor.append(dim[0])
else:
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")
# flatten_split_shape -> split_input_shape
factor_idx = 0
broadcast_split_input_shape = []
for dim in split_input_shape:
if isinstance(dim, list):
new_dim = []
for d in dim:
new_dim.append(d * broadcast_factor[factor_idx])
factor_idx += 1
broadcast_split_input_shape.append(new_dim)
else:
broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
factor_idx += 1
broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
node_meta.tensor.broadcast(broadcast_split_input_shape)
# Last reshape op to clean up
broadcast_input_shape = tuple(
[product(dim) for dim in broadcast_split_input_shape]
)
node_meta.tensor.reshape(broadcast_input_shape)
# Update the input shape and output shape
self.input_shape = _list_to_tuple(node_meta.tensor.shape)
self.output_shape = _list_to_tuple(shape)
def apply_to_user(self, user_meta: NodeBase):
"""
Propagate the reshape to user nodes
"""
user_meta.tensor.reshape(tuple(self.input_shape))
if hasattr(user_meta, "store_tensor"):
if user_meta.store_tensor is not None:
user_meta.store_tensor.reshape(tuple(self.input_shape))
def apply_to_input(self, input_meta: NodeBase):
"""
Propagate the reshape to input nodes
"""
input_meta.tensor.reshape(tuple(self.output_shape))
if hasattr(input_meta, "store_tensor"):
if input_meta.store_tensor is not None:
input_meta.store_tensor.reshape(tuple(self.output_shape))
#
# Helper functions
#
def infer_split(self, input_shape, output_shape):
"""
Infer the flattened split shape that can be merged into both input_shape and output_shape
"""
input_shape = _tuple_to_list(input_shape)
output_shape = _tuple_to_list(output_shape)
if len(input_shape) == 0 and len(output_shape) == 0:
return []
if len(input_shape) == 0:
if product(tuple(output_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return output_shape
if len(output_shape) == 0:
if product(tuple(input_shape)) != 1:
raise ValueError("Invalid reshape size")
else:
return input_shape
# This is done recursively, processing only the last dimension at each step
old_dim = input_shape[-1]
new_dim = output_shape[-1]
# Exact match
if old_dim == new_dim:
return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim]
# Needs split
if old_dim > new_dim and old_dim % new_dim == 0:
residual = old_dim // new_dim
return self.infer_split(
input_shape[:-1] + [residual],
output_shape[:-1],
) + [new_dim]
# Needs merge
if old_dim < new_dim and new_dim % old_dim == 0:
residual = new_dim // old_dim
return self.infer_split(
input_shape[:-1],
output_shape[:-1] + [residual],
) + [old_dim]
raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")
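As a standalone illustration of the recursion (tuple/list conversion and size checks trimmed), the factorization can be sketched as:

```python
def infer_split(inp, out):
    """Find a flat factor list that regroups into both shapes."""
    inp, out = list(inp), list(out)
    if not inp and not out:
        return []
    if not inp:
        return out
    if not out:
        return inp
    old, new = inp[-1], out[-1]
    if old == new:                          # exact match
        return infer_split(inp[:-1], out[:-1]) + [new]
    if old > new and old % new == 0:        # split the old dim
        return infer_split(inp[:-1] + [old // new], out[:-1]) + [new]
    if new > old and new % old == 0:        # split the new dim
        return infer_split(inp[:-1], out[:-1] + [new // old]) + [old]
    raise NotImplementedError(f"Unsupported split: {inp} -> {out}")

# (6, 4) and (2, 12) share the flat factorization [2, 3, 4]:
print(infer_split((6, 4), (2, 12)))  # [2, 3, 4]
```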
def infer_merge(self, flatten_shape, shape):
flatten_shape = _tuple_to_list(flatten_shape)
shape = _tuple_to_list(shape)
idx_flat = len(flatten_shape) - 1
merged_shape = []
for dim in reversed(shape):
# Exact match
if dim == flatten_shape[idx_flat]:
merged_shape.append(dim)
idx_flat -= 1
# need group
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
residual = dim
group = []
while residual > 1:
group.append(flatten_shape[idx_flat])
residual = residual // flatten_shape[idx_flat]
idx_flat -= 1
merged_shape.append(group[::-1])
else:
raise NotImplementedError(
f"Unsupported merge: {flatten_shape} -> {shape}"
)
return merged_shape[::-1]
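A matching standalone sketch of the regrouping step: it folds the flat factors back into the target shape, grouping factors into sublists where a target dimension covers several of them:

```python
def infer_merge(flat, shape):
    """Regroup flat factors so their products match `shape` right-to-left."""
    flat = list(flat)
    idx = len(flat) - 1
    merged = []
    for dim in reversed(shape):
        if dim == flat[idx]:                          # exact match
            merged.append(dim)
            idx -= 1
        elif dim > flat[idx] and dim % flat[idx] == 0:  # group several factors
            residual, group = dim, []
            while residual > 1:
                group.append(flat[idx])
                residual //= flat[idx]
                idx -= 1
            merged.append(group[::-1])
        else:
            raise NotImplementedError(f"Unsupported merge: {flat} -> {shape}")
    return merged[::-1]

# Regrouping the flat factors [2, 3, 4] back into shape (6, 4):
print(infer_merge([2, 3, 4], (6, 4)))  # [[2, 3], 4]
```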
class LayoutNode(NodeBase):
"""
Layout manipulation nodes
"""
fn_to_impl = {
"permute": PermutationImpl,
"reshape": ReshapeImpl,
}
def __init__(self, name: str, fn, kwargs: dict) -> None:
super().__init__(name)
self.op = "layout"
self.fn = fn
self.kwargs = kwargs
self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)
def get_inverse_node(self):
inverse_node = deepcopy(self)
inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
return inverse_node
def shape_propagation(self, input_node_metas):
if self._tensor is not None:
return
assert len(input_node_metas) == 1, "Layout node can only have one input node"
output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
self._tensor = Tensor(
element=self.element_output,
shape=output_shape,
layout_tag=LayoutType.RowMajor,
)
return super().shape_propagation(input_node_metas)
def type_propagation(self, input_node_metas: "list[NodeBase]"):
"""
Layout nodes have element_output = element_input
"""
assert len(input_node_metas) == 1, "Layout node can only have one input node"
self.element_output = input_node_metas[0].element_output
def broadcast_propagation(self, input_node_metas: "list[NodeBase]"):
"""
Propagate the broadcast in the reversed topological order
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
shape = self.tensor.shape
for child in input_node_metas:
self.underlying_impl.broadcast(shape, child)
def apply_to_user(self, usr_meta: NodeBase):
"""
Propagate the layout operation to user nodes
"""
self.underlying_impl.apply_to_user(usr_meta)
def apply_to_input(self, input_meta: NodeBase):
"""
Propagate the layout operation to input nodes
"""
self.underlying_impl.apply_to_input(input_meta)


@@ -0,0 +1,312 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Load nodes and implementations
"""
import ctypes
from cutlass_api.fusion.ir.c_types import dtype2ctype, to_ctype_value, tuple_factory
from cutlass_api.fusion.ir.node import NodeBase, ImplBase
class LoadImplBase(ImplBase):
"""
Base class for load node implementations
"""
reserved_names = ["accum", "C"]
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.element
self.element_output = node.element_output
self.stride = node.tensor.stride
class AccumulatorImpl(LoadImplBase):
"""
Accumulator node implementation
"""
@staticmethod
def match(node, problem_size: tuple):
return node.name == "accum" and node.tensor.shape == problem_size
class LoadSrcImpl(LoadImplBase):
"""
Load C implementation
"""
@property
def name_camel(self) -> str:
return "TensorC"
@property
def argument_type_c(self):
stride_mnl = self.get_stride_mnl()
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [("ptr_C", ctypes.c_void_p), ("stride_C", tuple_type)]
def __init__(self, ptr) -> None:
self.ptr_C = ptr
self.stride_C = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
return node.name == "C" and node.tensor.shape == problem_size
class AuxLoadImpl(LoadImplBase):
"""
Load arbitrary tensor
"""
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_aux", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dAux", tuple_type),
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_aux = ptr
self.null_default = to_ctype_value(0, element_type)
self.dAux = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if 1 not in strideMN or 0 in strideMN:
return False
return True
class RowBroadcastImpl(LoadImplBase):
"""
Broadcast a row vector
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_row", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dRow", tuple_type),
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_row = ptr
self.null_default = to_ctype_value(0, element_type)
self.dRow = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (0, 1):
return True
else:
return False
class ColumnBroadcastImpl(LoadImplBase):
"""
Broadcast a column vector
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
class _Argument(ctypes.Structure):
_fields_ = [
("ptr_col", ctypes.c_void_p),
("null_default", dtype2ctype[element_type]),
("dCol", tuple_type),
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_col = int(ptr)
self.null_default = to_ctype_value(0, element_type)
self.dCol = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (1, 0):
return True
else:
return False
class ScalarBroadcastImpl(LoadImplBase):
"""
Broadcast a scalar
"""
def __init__(self, node) -> None:
super().__init__(node)
self.stride_dtype = "int"
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_type = self.element
if self.tensor.is_constant:
value = self.tensor.value
class _Argument(ctypes.Structure):
_fields_ = [
("scalars", dtype2ctype[element_type]),
("scalar_ptrs", ctypes.c_void_p),
("dScalar", tuple_type),
]
def __init__(self, kwargs) -> None:
self.scalars = to_ctype_value(value, element_type)
self.scalar_ptrs = 0
self.dScalar = tuple_type(stride_mnl)
else:
class _Argument(ctypes.Structure):
_fields_ = [
("scalars", dtype2ctype[element_type]),
("scalar_ptrs", ctypes.c_void_p),
("dScalar", tuple_type),
]
def __init__(self, kwargs) -> None:
scalar_or_ptr = kwargs[name]
if isinstance(scalar_or_ptr, float):
self.scalars = to_ctype_value(scalar_or_ptr, element_type)
self.scalar_ptrs = 0
else:
self.scalar_ptrs = int(scalar_or_ptr)
self.dScalar = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name in LoadImplBase.reserved_names:
return False
strideMN = node.tensor.stride[-2:]
if strideMN == (0, 0):
return True
else:
return False
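The `match` rules above classify a load purely by its trailing (M, N) strides. A hedged standalone sketch of that dispatch (function name hypothetical):

```python
def classify_load(stride_mn):
    """Map the trailing (M, N) strides to the load-impl categories above."""
    if stride_mn == (0, 1):
        return "row_broadcast"     # one row, repeated down M
    if stride_mn == (1, 0):
        return "column_broadcast"  # one column, repeated across N
    if stride_mn == (0, 0):
        return "scalar_broadcast"  # a single value
    if 1 in stride_mn and 0 not in stride_mn:
        return "aux_load"          # a full, contiguous tensor
    raise NotImplementedError(f"Unsupported stride: {stride_mn}")

print(classify_load((128, 1)))  # aux_load
```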
class LoadNode(NodeBase):
"""
Load Node
"""
cnt = 0
possible_impls = [
AccumulatorImpl,
LoadSrcImpl,
AuxLoadImpl,
RowBroadcastImpl,
ColumnBroadcastImpl,
ScalarBroadcastImpl,
]
def __init__(self, name: str) -> None:
if name is None:
name = f"load{LoadNode.cnt}"
LoadNode.cnt += 1
super().__init__(name)
self.op = "load"
def type_propagation(self, *args, **kwargs):
"""
A load node loads a tensor of type `tensor.element` and returns an array of type `tensor.element`.
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
self.element = self.tensor.element
self.element_output = self.tensor.element


@@ -0,0 +1,330 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Base & visitor classes of DAGIR Nodes
"""
import ctypes
from re import sub
from cutlass_api.fusion.library import LayoutType
from cutlass_api.fusion.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
from cutlass_api.fusion.ir.tensor import Tensor
class TupleEmitter:
"""
Emit a Python stride tuple as a C++ cute::Stride type declaration
"""
def __init__(self, stride_dtype):
self.stride_dtype = stride_dtype
def emit(self, py_tuple):
if isinstance(py_tuple, int):
if py_tuple in [0, 1]:
return f"cute::Int<{py_tuple}>"
else:
return f"{self.stride_dtype}"
elif isinstance(py_tuple, tuple):
decl = "cute::Stride<"
for item in py_tuple:
decl += self.emit(item) + ", "
return decl[:-2] + ">"
else:
raise ValueError(
f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}"
)
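A minimal standalone version of the emitter (hypothetical `emit_stride` helper) shows how static 0/1 strides become `cute::Int<...>` while dynamic strides fall back to the stride dtype:

```python
def emit_stride(py_tuple, stride_dtype="int64_t"):
    """Emit a nested stride tuple as a cute::Stride type string."""
    if isinstance(py_tuple, int):
        # 0 and 1 are static strides; everything else is a runtime value.
        return f"cute::Int<{py_tuple}>" if py_tuple in (0, 1) else stride_dtype
    inner = ", ".join(emit_stride(item, stride_dtype) for item in py_tuple)
    return f"cute::Stride<{inner}>"

# A row-major (M, N, L) stride with a dynamic leading dimension:
print(emit_stride((1, 128, 0)))
# cute::Stride<cute::Int<1>, int64_t, cute::Int<0>>
```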
class ImplBase:
"""
Base class for Node Implementation
"""
def __init__(self, node) -> None:
self.node = node
self.name = node.name
self.tensor = node.tensor
self._type_decl = None
self.tuple_emitter = TupleEmitter("int64_t")
@property
def stride_dtype(self):
return self.tuple_emitter.stride_dtype
@stride_dtype.setter
def stride_dtype(self, stride_dtype):
self.tuple_emitter.stride_dtype = stride_dtype
@staticmethod
def match(node, problem_size: tuple):
"""
Match function used in get_underlying_impl
"""
raise NotImplementedError("The `match` function is not defined.")
@property
def argument_type(self):
"""
Default class for Argument Type
"""
class _Argument(ctypes.Structure):
_fields_ = []
def __init__(self, *args, **kwargs) -> None:
pass
return _Argument
@property
def name_camel(self) -> str:
"""
Return the CamelCase name.
"""
return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")
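The CamelCase conversion can be sketched on its own; this assumes only the `re.sub` call shown above:

```python
from re import sub

def name_camel(name):
    """Turn a snake/kebab-case node name into CamelCase."""
    return sub(r"(_|-)+", " ", name).title().replace(" ", "")

print(name_camel("aux_load-0"))  # AuxLoad0
```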
@property
def stride_mnl(self):
"""
Typename StrideMNL
"""
stride = _list_to_tuple(
[self.stride[-2], self.stride[-1]]
+ list(_reverse_tuple(tuple(self.stride[:-2])))
)
return self.tuple_emitter.emit(stride)
def get_non_constant_stride(self, py_tuple):
if isinstance(py_tuple, int):
if py_tuple not in [0, 1]:
return py_tuple
else:
return None
non_constant_stride = []
for item in py_tuple:
item_out = self.get_non_constant_stride(item)
if item_out:
non_constant_stride.append(item_out)
return tuple(non_constant_stride)
def get_stride_mnl(self):
"""
Get the non-zero stride mnl. This is used in argument construction
"""
stride = _list_to_tuple(
[self.stride[-2], self.stride[-1]]
+ list(_reverse_tuple(tuple(self.stride[:-2])))
)
return stride
def get_smem_size(self, *args, **kwargs):
"""
Get the shared memory size and alignment of current node
"""
return (0, 1)
class NoOpImpl(ImplBase):
"""
The NoOpImpl does nothing but forward its input to users
"""
def __init__(self, node) -> None:
super().__init__(node)
@staticmethod
def match(node, problem_size: tuple):
if node.op == "store":
# A store that is not an output is a no-op
return not node.is_output
return False
class NodeBase:
"""
Base class of DAG Node
"""
def __init__(self, name: str) -> None:
self.name = name
self.underlying_impl = None
self._tensor = None
# Whether the node is disabled for emit
self.disabled = False
@property
def name_camel(self) -> str:
"""
Return the CamelCase name.
"""
return self.underlying_impl.name_camel
@property
def tensor(self) -> Tensor:
"""
Return the output tensor (concept: cutlass_api.fusion.ir.tensor)
"""
return self._tensor
@tensor.setter
def tensor(self, kwargs):
"""
Setting the tensor
"""
self._tensor = Tensor(**kwargs)
#
# Helper functions for type/shape propagation
#
def shape_propagation(self, input_node_metas):
"""
Infer shape from input nodes
General Broadcasting Rules from NumPy
When operating on two arrays, we compare their shapes element-wise.
It starts with the trailing (i.e. rightmost) dimension and works its
way left. Two dimensions are compatible when
1. they are equal
2. one of them is 1
"""
if self._tensor is not None:
return
shape = None
for src in input_node_metas:
src_shape = src.tensor.shape
if shape is None:
shape = src_shape
else:
len_difference = len(shape) - len(src_shape)
if len_difference > 0:
for _ in range(len_difference):
src_shape = [1] + list(src_shape)
elif len_difference < 0:
for _ in range(-len_difference):
shape = [1] + list(shape)
broadcasted_shape = []
# Infer broadcast shape
for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
if shape_dim == 1:
broadcasted_shape = [src_dim] + list(broadcasted_shape)
elif src_dim == 1:
broadcasted_shape = [shape_dim] + list(broadcasted_shape)
elif shape_dim == src_dim:
broadcasted_shape = [shape_dim] + list(broadcasted_shape)
else:
error_msg = "Dimension mismatch between "
for src_ in input_node_metas:
error_msg += f"{src_.name}{src_.tensor.shape}, "
error_msg = error_msg[:-2] + "."
raise RuntimeError(error_msg)
shape = tuple(broadcasted_shape)
self._tensor = Tensor(
element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor
)
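The broadcasting rule described in the docstring can be sketched as a standalone helper (hypothetical name): shapes are right-aligned, and each pair of dimensions must either match or include a 1:

```python
def broadcast_shape(*shapes):
    """NumPy-style broadcast: align right; dims must be equal or 1."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise RuntimeError(f"Dimension mismatch between {shapes}")
        out.append(non_one.pop() if non_one else 1)
    return tuple(out)

print(broadcast_shape((3, 1), (2, 1, 4)))  # (2, 3, 4)
```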
def type_propagation(self, *args, **kwargs):
"""
Each node is associated with two data types: `element` and `element_output`.
The `element_output` is the type of return array of the node. The `element`
has specific meaning for different node types.
* Load Node: data type of tensor in gmem
* Compute Node: element compute
* Store Node: data type of tensor in gmem
This function must be overloaded in the derived classes
"""
raise NotImplementedError(
f"Function `type_propagation` is not overloaded in {self.__class__.__name__}"
)
def broadcast_propagation(self, input_node_metas: "list[NodeBase]"):
"""
Propagate the broadcast in the reversed topological order.
For example:
C[L, M, N] = A[M, 1] + B[L, M, N]
After broadcast propagation, this becomes
C[L, M, N] = A[L, M, N] + B[L, M, N]
and each tensor gets strides that correctly address its underlying storage
"""
if self.tensor is None:
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
for child in input_node_metas:
child.tensor.broadcast(self.tensor.shape)
def get_underlying_impl(self, problem_size: tuple):
"""
Get the underlying implementation of the current node.
"""
if self.tensor is None:
raise RuntimeError(
f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first."
)
for impl in self.possible_impls:
if impl.match(self, problem_size):
self.underlying_impl = impl(self)
break
if self.underlying_impl is None:
raise NotImplementedError(
f"No matching op for node {self.name} with stride {self.tensor.stride}."
)
#
# Visitor Nodes & Impls
#
class TopoVisitorImpl(ImplBase):
"""
Impl for topological visitor
"""
def __init__(self, node) -> None:
super().__init__(node.output_node)
self.name = node.name
self.element_output = node.output_node.element_output
class TopoVisitorNode(NodeBase):
def __init__(self, name: str, subgraph, output_node) -> None:
super().__init__(name)
self.subgraph = subgraph
self.output_node = output_node
self.op = "dag"
self.underlying_impl = TopoVisitorImpl(self)


@@ -0,0 +1,276 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Store node and implementations
"""
import ctypes
from cutlass_api.fusion.library import DataType, FloatRoundStyle, FunctionalOp
from cutlass_api.fusion.ir.c_types import dtype2ctype, to_ctype_value, tuple_factory
from cutlass_api.fusion.ir.node import NodeBase, ImplBase, NoOpImpl
from cutlass_api.fusion.ir.tensor import Tensor
class StoreImplBase(ImplBase):
"""
Base class for store node implementation
"""
reserved_names = ["D"]
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.element
self.element_output = node.element_output
self.stride = node.store_tensor.stride
class StoreDImpl(StoreImplBase):
"""
Store D implementation
"""
@property
def argument_type_d(self):
stride_mnl = self.get_stride_mnl()
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [("ptr_D", ctypes.c_void_p), ("stride_D", tuple_type)]
def __init__(self, ptr: int) -> None:
self.ptr_D = ptr
self.stride_D = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if node.name == "D" and node.store_tensor.shape == problem_size:
return True
return False
class AuxStoreImpl(StoreImplBase):
def __init__(self, node) -> None:
super().__init__(node)
self.round_style = FloatRoundStyle.ToNearest
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
class _Argument(ctypes.Structure):
_fields_ = [("ptr_aux", ctypes.c_void_p), ("dAux", tuple_type)]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr_aux = ptr
self.dAux = tuple_type(stride_mnl)
return _Argument
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if 1 not in strideMN or 0 in strideMN:
return False
return True
class ReductionImplBase(StoreImplBase):
def __init__(self, node) -> None:
super().__init__(node)
self.element = node.store_tensor.element
self.element_compute = node.element_compute
self.reg_reduce_fn = self.node.reg_reduce_fn
self.gmem_reduce_fn = self.node.gmem_reduce_fn
self.round_style = node.round_style
self.stride_dtype = "int"
def get_reduce_identity(self):
"""
Return the reduction identity of the current reduce_fn
"""
maxes = {
DataType.f32: (2**31) - 1,
DataType.f16: (2**15),
DataType.s32: (2**31) - 1,
DataType.s8: (2**7) - 1,
}
mins = {
DataType.f32: -maxes[DataType.f32],
DataType.f16: -maxes[DataType.f16],
DataType.s32: -maxes[DataType.s32],
DataType.s8: -maxes[DataType.s8],
}
if self.reg_reduce_fn == FunctionalOp.Maximum:
if self.element_compute not in mins:
raise Exception(f"No min entry for data type {self.element_compute}")
return to_ctype_value(mins[self.element_compute], self.element_compute)
elif self.reg_reduce_fn == FunctionalOp.Multiplies:
return to_ctype_value(1.0, self.element_compute)
elif self.reg_reduce_fn == FunctionalOp.Minimum:
if self.element_compute not in maxes:
raise Exception(f"No max entry for data type {self.element_compute}")
return to_ctype_value(maxes[self.element_compute], self.element_compute)
else:
return to_ctype_value(0.0, self.element_compute)
@property
def argument_type(self):
stride_mnl = self.get_stride_mnl()
name = self.name
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
element_compute = self.element_compute
reduce_identity = self.get_reduce_identity()
class _Argument(ctypes.Structure):
_fields_ = [
("ptr", ctypes.c_void_p),
("reduce_identity", dtype2ctype[element_compute]),
("dMNL", tuple_type),
]
def __init__(self, kwargs) -> None:
ptr = kwargs[name]
self.ptr = ptr
self.reduce_identity = reduce_identity
self.dMNL = tuple_type(stride_mnl)
return _Argument
class ColumnReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (1, 0):
return True
else:
return False
class RowReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (0, 1):
return True
else:
return False
class ScalarReductionImpl(ReductionImplBase):
@staticmethod
def match(node, problem_size: tuple):
if not node.is_output:
return False
if node.name in StoreImplBase.reserved_names:
return False
strideMN = node.store_tensor.stride[-2:]
if strideMN == (0, 0):
return True
else:
return False
class StoreNode(NodeBase):
"""
Store node
"""
possible_impls = [
AuxStoreImpl,
RowReductionImpl,
ColumnReductionImpl,
ScalarReductionImpl,
NoOpImpl,
StoreDImpl,
]
def __init__(self, name: str) -> None:
super().__init__(name)
self.op = "store"
self.is_output = False
self._store_tensor = None
@property
def store_tensor(self) -> Tensor:
"""
Return the output tensor (concept: cutlass_api.fusion.ir.tensor)
"""
return self._store_tensor
@store_tensor.setter
def store_tensor(self, kwargs):
"""
Set the store tensor
"""
self._store_tensor = Tensor(**kwargs)
def type_propagation(self, input_node_metas: "list[NodeBase]"):
"""
The store node has element_output = element_input
"""
if self.is_output:
if self.store_tensor is None:
raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
self.element = self.store_tensor.element
assert len(input_node_metas) == 1, "Store node can only have one input node"
self.element_output = input_node_metas[0].element_output
def broadcast_propagation(self, input_node_metas: "list[NodeBase]"):
super().broadcast_propagation(input_node_metas)
if self.is_output:
self._store_tensor.broadcast(self.tensor.shape)
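The `match` methods above dispatch a store node to an implementation based on the strides of the last two (M, N) dimensions of its store tensor. A minimal, self-contained sketch of that dispatch (the `classify_store` helper is illustrative, not library API):

```python
def classify_store(stride_mn: tuple) -> str:
    """Mirror the stride-based `match` dispatch of the store impls above."""
    if stride_mn == (0, 0):
        # constant along M and N: everything reduces to one scalar
        return "ScalarReductionImpl"
    if stride_mn == (1, 0):
        # varies along M, constant along N: a length-M column vector
        return "ColumnReductionImpl"
    if stride_mn == (0, 1):
        # constant along M, varies along N: a length-N row vector
        return "RowReductionImpl"
    if 1 in stride_mn and 0 not in stride_mn:
        # dense (non-broadcast) output with a unit stride in M or N
        return "AuxStoreImpl"
    return "unmatched"
```

Order matters: the reduction cases are tested first because any stride tuple containing a 0 is rejected by the dense auxiliary-store condition, just as in `AuxStoreImpl.match` above.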


@@ -0,0 +1,155 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
High-level class for tensor
"""
from cutlass_api.fusion.library import (
LayoutType,
get_datatype_and_layout,
get_tensor_shape,
library_type,
)
from cutlass_api.fusion.ir.layout_algorithm import (
Layout,
broadcast,
canonicalization,
permutation,
reshape,
_reverse_tuple,
)
class Tensor:
"""
The tensor abstracts the data type
"""
def __init__(
self,
tensor=None,
element=None,
shape=None,
stride=None,
layout_tag=None,
is_constant=False,
) -> None:
if element is not None and tensor is not None:
raise Exception("Must not specify both element and tensor")
elif shape is not None and tensor is not None:
raise Exception("Must not specify both shape and tensor")
elif layout_tag is not None and tensor is not None:
raise Exception("Must not specify both layout_tag and tensor")
elif (
element is None or (layout_tag is None and stride is None) or shape is None
) and (tensor is None):
raise Exception(
"Must specify one of (element, shape, layout/stride) or (tensor)"
)
elif stride is not None and tensor is not None:
raise Exception("Must not specify both stride and tensor")
elif stride is not None and layout_tag is not None:
raise Exception("Must not specify layout_tag when stride is provided")
if isinstance(tensor, Tensor):
# Directly copy all the attributes
self.__dict__.update(vars(tensor))
else:
if tensor is None:
self.element = library_type(element)
else:
self.element, layout_tag = get_datatype_and_layout(tensor)
shape = get_tensor_shape(tensor)
if stride is not None:
self.layout = Layout(shape[::-1], stride[::-1])
else:
if layout_tag == LayoutType.RowMajor:
self.layout = Layout(shape[::-1])
elif layout_tag == LayoutType.ColumnMajor:
self.layout = permutation(
Layout(shape), [idx for idx in reversed(range(len(shape)))]
)
self.layout = canonicalization(self.layout)
self.is_constant = is_constant
# Save the tensor value if it is constant
if is_constant and tensor is not None:
self.value = tensor
@property
def shape(self):
"""
Returns the RowMajor layout shape
"""
return _reverse_tuple(self.layout.shape)
@property
def stride(self):
"""
Returns the RowMajor layout stride
"""
return _reverse_tuple(self.layout.stride)
@property
def rank(self):
"""
Returns the rank of the tensor
"""
return len(self.shape)
#
# Layout Algorithms
#
def broadcast(self, shape):
"""
Broadcast self.layout to shape
"""
assert isinstance(shape, tuple)
self.layout = broadcast(self.layout, _reverse_tuple(shape))
def reshape(self, shape):
"""
Reshape self.layout to shape
"""
assert isinstance(shape, tuple)
reverse_shape = _reverse_tuple(shape)
self.layout = reshape(self.layout, reverse_shape)
def permute(self, indices):
"""
Permute self.layout according to indices
"""
length = len(indices)
indices = [length - idx - 1 for idx in indices]
self.layout = permutation(self.layout, indices[::-1])
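The index mirroring in `Tensor.permute` above compensates for the internally reversed (innermost-first) layout. A self-contained sketch, under the assumption that `permutation` places old axis `indices[i]` at new position `i`, showing that the mirrored and reversed indices reproduce an ordinary row-major permutation:

```python
def permute_rowmajor(shape: tuple, indices) -> tuple:
    """Plain row-major permutation: new axis i takes old axis indices[i]."""
    return tuple(shape[i] for i in indices)


def permute_via_reversed(shape: tuple, indices) -> tuple:
    """Apply the mirror trick from Tensor.permute to the reversed shape."""
    n = len(indices)
    reversed_shape = tuple(reversed(shape))
    # Mirror each index into the reversed coordinate system
    mirrored = [n - idx - 1 for idx in indices]
    # Permute the reversed shape with the reversed mirrored indices
    permuted = tuple(reversed_shape[i] for i in mirrored[::-1])
    # Reverse back into row-major order
    return tuple(reversed(permuted))
```

Both paths agree, which is why the public `shape`/`stride` properties can keep presenting the familiar row-major view.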


@@ -0,0 +1,441 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Copies of enum types from cutlass_library that are used in the fusion frontend.
"""
import enum
from enum import auto as enum_auto
import cutlass
from cutlass_api.utils import (
is_numpy_available,
is_numpy_tensor,
is_torch_available,
is_torch_tensor,
)
class DataType(enum.Enum):
void = enum_auto() # primarily used to disable C tensor for epilogues
b1 = enum_auto()
u2 = enum_auto()
u4 = enum_auto()
u8 = enum_auto()
u16 = enum_auto()
u32 = enum_auto()
u64 = enum_auto()
s2 = enum_auto()
s4 = enum_auto()
s8 = enum_auto()
s16 = enum_auto()
s32 = enum_auto()
s64 = enum_auto()
e4m3 = enum_auto()
e5m2 = enum_auto()
f8 = enum_auto()
f6 = enum_auto()
f4 = enum_auto()
e3m2 = enum_auto()
e2m3 = enum_auto()
e2m1 = enum_auto()
ue8m0 = enum_auto()
ue4m3 = enum_auto()
f16 = enum_auto()
bf16 = enum_auto()
f32 = enum_auto()
tf32 = enum_auto()
f64 = enum_auto()
cf16 = enum_auto()
cbf16 = enum_auto()
cf32 = enum_auto()
ctf32 = enum_auto()
cf64 = enum_auto()
cs2 = enum_auto()
cs4 = enum_auto()
cs8 = enum_auto()
cs16 = enum_auto()
cs32 = enum_auto()
cs64 = enum_auto()
cu2 = enum_auto()
cu4 = enum_auto()
cu8 = enum_auto()
cu16 = enum_auto()
cu32 = enum_auto()
cu64 = enum_auto()
invalid = enum_auto()
DataTypeSize = {
DataType.void: 0,
DataType.b1: 1,
DataType.u2: 2,
DataType.u4: 4,
DataType.u8: 8,
DataType.u16: 16,
DataType.u32: 32,
DataType.u64: 64,
DataType.s2: 2,
DataType.s4: 4,
DataType.s8: 8,
DataType.s16: 16,
DataType.s32: 32,
DataType.s64: 64,
DataType.e4m3: 8,
DataType.e5m2: 8,
DataType.f8: 8,
DataType.f6: 6,
DataType.f4: 4,
DataType.e2m3: 6,
DataType.e3m2: 6,
DataType.e2m1: 4,
DataType.ue8m0: 8,
DataType.ue4m3: 8,
DataType.f16: 16,
DataType.bf16: 16,
DataType.f32: 32,
DataType.tf32: 32,
DataType.f64: 64,
DataType.cf16: 32,
DataType.cbf16: 32,
DataType.cf32: 64,
DataType.ctf32: 32,
DataType.cf64: 128,
DataType.cu2: 4,
DataType.cu4: 8,
DataType.cu8: 16,
DataType.cu16: 32,
DataType.cu32: 64,
DataType.cu64: 128,
DataType.cs2: 4,
DataType.cs4: 8,
DataType.cs8: 16,
DataType.cs16: 32,
DataType.cs32: 64,
DataType.cs64: 128,
}
class LayoutType(enum.Enum):
ColumnMajor = enum_auto()
RowMajor = enum_auto()
ColumnMajorInterleaved2 = enum_auto()
RowMajorInterleaved2 = enum_auto()
ColumnMajorInterleaved32 = enum_auto()
RowMajorInterleaved32 = enum_auto()
ColumnMajorInterleaved64 = enum_auto()
RowMajorInterleaved64 = enum_auto()
TensorNWC = enum_auto()
TensorNHWC = enum_auto()
TensorNDHWC = enum_auto()
TensorNCHW = enum_auto()
TensorNGHWC = enum_auto()
TensorNC32HW32 = enum_auto()
TensorNC64HW64 = enum_auto()
TensorC32RSK32 = enum_auto()
TensorC64RSK64 = enum_auto()
TensorKCS = enum_auto()
TensorKCSR = enum_auto()
TensorKCSRT = enum_auto()
class EpilogueScheduleType(enum.Enum):
ScheduleAuto = enum_auto()
EpilogueTransposed = enum_auto()
NoSmemWarpSpecialized = enum_auto()
PtrArrayNoSmemWarpSpecialized = enum_auto()
NoSmemWarpSpecialized1Sm = enum_auto()
NoSmemWarpSpecialized2Sm = enum_auto()
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
BlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
BlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
TmaWarpSpecialized = enum_auto()
TmaWarpSpecializedCooperative = enum_auto()
TmaWarpSpecialized1Sm = enum_auto()
TmaWarpSpecialized2Sm = enum_auto()
PtrArrayTmaWarpSpecialized1Sm = enum_auto()
PtrArrayTmaWarpSpecialized2Sm = enum_auto()
PtrArrayTmaWarpSpecializedPingpong = enum_auto()
PtrArrayTmaWarpSpecializedCooperative = enum_auto()
TmaWarpSpecialized1SmNvf4 = enum_auto()
TmaWarpSpecialized2SmNvf4 = enum_auto()
TmaWarpSpecialized1SmMxf4 = enum_auto()
TmaWarpSpecialized2SmMxf4 = enum_auto()
TmaWarpSpecialized1SmMxf8f6f4 = enum_auto()
TmaWarpSpecialized2SmMxf8f6f4 = enum_auto()
SparseTmaWarpSpecializedCooperativeSm120 = enum_auto()
class ActivationOp(enum.Enum):
DGelu = enum_auto()
Gelu = enum_auto()
GeluTaylor = enum_auto()
HardSwish = enum_auto()
Identity = enum_auto()
LeakyReLU = enum_auto()
ReLU = enum_auto()
Sigmoid = enum_auto()
SiLU = enum_auto()
Tanh = enum_auto()
ActivationOpTag = {
ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
}
class FloatRoundStyle(enum.Enum):
ToNearest = enum_auto()
ToNearestSatfinite = enum_auto()
Indeterminate = enum_auto()
TowardZero = enum_auto()
TowardInfinity = enum_auto()
TowardNegInfinity = enum_auto()
HalfUlpTruncDntz = enum_auto()
HalfUlpTruncate = enum_auto()
class FunctionalOp(enum.Enum):
AtomicAdd = enum_auto()
AtomicMaximum = enum_auto()
Divides = enum_auto()
Maximum = enum_auto()
Minimum = enum_auto()
Minus = enum_auto()
Multiplies = enum_auto()
MultiplyAdd = enum_auto()
Plus = enum_auto()
Exp = enum_auto()
FunctionalOpTag = {
FunctionalOp.AtomicAdd: "cutlass::atomic_add",
FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
FunctionalOp.Divides: "cutlass::divides",
FunctionalOp.Maximum: "cutlass::maximum",
FunctionalOp.Minimum: "cutlass::minimum",
FunctionalOp.Minus: "cutlass::minus",
FunctionalOp.Multiplies: "cutlass::multiplies",
FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
FunctionalOp.Plus: "cutlass::plus",
FunctionalOp.Exp: "cutlass::fast_exp_op",
}
def op_tag(op) -> str:
"""
Dispatches `op` to the appropriate *Tag dictionary depending on whether
`op` is an ActivationOp or FunctionalOp. This is useful for cases in which
either type can be used.
:param op: operation to emit a tag for
:type op: ActivationOp | FunctionalOp
:return: tag corresponding to op
:rtype: str
"""
if isinstance(op, ActivationOp):
return ActivationOpTag[op]
elif isinstance(op, FunctionalOp):
return FunctionalOpTag[op]
else:
raise Exception(
f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp."
)
# The current EVT frontend also contains code needed for emitting C++ EVTs.
# C++ emission is not currently supported by cutlass_api, but we still
# need many of the utilities surrounding it. These utilities make use of dictionaries
# from cutlass_library types to strings containing C++ code. We define placeholders
# for empty versions of these dictionaries that raise an error if they are used.
class _UnimplementedDict:
def __init__(self, name):
self.name = name
def __getitem__(self, key):
raise NotImplementedError(
f"Dictionary {self.name} is not implemented. This code path should not have been reached."
)
DataTypeTag = _UnimplementedDict("DataTypeTag")
EpilogueScheduleTag = _UnimplementedDict("EpilogueScheduleTag")
FloatRoundStyleTag = _UnimplementedDict("FloatRoundStyleTag")
KernelScheduleSuffixes = _UnimplementedDict("KernelScheduleSuffixes")
OpcodeClassTag = _UnimplementedDict("OpcodeClassTag")
_torch_to_library_dict = None
if is_torch_available():
import torch
_torch_to_library_dict = {
torch.half: DataType.f16,
torch.float16: DataType.f16,
torch.bfloat16: DataType.bf16,
torch.float: DataType.f32,
torch.float32: DataType.f32,
}
def _tensor_from_numpy(np_tensor) -> tuple[DataType, LayoutType]:
dtype = library_type(np_tensor.dtype)
if np_tensor.flags.c_contiguous:
layout = LayoutType.RowMajor
elif np_tensor.flags.f_contiguous:
layout = LayoutType.ColumnMajor
else:
raise Exception("Only contiguous NumPy tensors are supported.")
return (dtype, layout)
def _tensor_from_torch(pt_tensor):
dtype = library_type(pt_tensor.dtype)
return (dtype, LayoutType.RowMajor)
def torch_library_type(inp) -> DataType:
if _torch_to_library_dict is None:
return None
return _torch_to_library_dict.get(inp, None)
def numpy_library_type(inp) -> DataType:
if is_numpy_available():
import numpy as np
if inp == np.float16:
return DataType.f16
elif inp == np.float32:
return DataType.f32
elif inp == np.float64:
return DataType.f64
elif inp == np.int8:
return DataType.s8
elif inp == np.int32:
return DataType.s32
return None
def cutlass_library_type(inp: cutlass.Numeric) -> DataType:
if inp == cutlass.Float32:
return DataType.f32
elif inp == cutlass.Float16:
return DataType.f16
elif inp == cutlass.BFloat16:
return DataType.bf16
elif inp == cutlass.Int32:
return DataType.s32
elif inp == cutlass.Int8:
return DataType.s8
elif inp == cutlass.Uint8:
return DataType.u8
elif inp == cutlass.Uint16:
return DataType.u16
elif inp == cutlass.Uint32:
return DataType.u32
elif inp == cutlass.Uint64:
return DataType.u64
elif inp == cutlass.Int16:
return DataType.s16
elif inp == cutlass.Int64:
return DataType.s64
elif inp == cutlass.Float8E5M2:
return DataType.e5m2
elif inp == cutlass.Float8E4M3FN:
return DataType.e4m3
else:
return None
def library_type(inp):
if inp in DataTypeSize:
return inp
for cvt_fn in [
numpy_library_type,
torch_library_type,
cutlass_library_type,
]:
out = cvt_fn(inp)
if out is not None:
return out
raise ValueError(f"No available conversion from type {inp} to a library type.")
def get_datatype_and_layout(tensor) -> tuple[DataType, LayoutType]:
if is_numpy_tensor(tensor):
return _tensor_from_numpy(tensor)
elif is_torch_tensor(tensor):
return _tensor_from_torch(tensor)
elif isinstance(tensor, float) or isinstance(tensor, int):
return (DataType.f32, LayoutType.RowMajor)
else:
raise Exception(
f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout."
)
def get_tensor_shape(tensor, op="GEMM") -> tuple:
if is_numpy_tensor(tensor):
return tensor.shape
elif is_torch_tensor(tensor):
size = tensor.size()
if op == "CONV":
# PyTorch convolution tensors are NCHW; reorder to NHWC
return (size[0], size[2], size[3], size[1])
else:
return tuple(tensor.size())
elif isinstance(tensor, float) or isinstance(tensor, int):
return (1,)
else:
raise Exception(
f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout."
)


@@ -0,0 +1,59 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from cutlass_api.fusion.passes.graph_drawer import EVTGraphDrawer
from cutlass_api.fusion.passes.pass_argument_type import PassGetArgumentType
from cutlass_api.fusion.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass_api.fusion.passes.pass_get_impl import PassGetImpl
from cutlass_api.fusion.passes.pass_fix_element_d import PassFixElementD
from cutlass_api.fusion.passes.pass_layout_elimination import (
PassLayoutManipulateElimination,
)
from cutlass_api.fusion.passes.pass_manager import EVTPassManager
from cutlass_api.fusion.passes.pass_preprocess_red import PassPreprocessRed
from cutlass_api.fusion.passes.pass_shape_type_propagation import (
PassShapeTypePropagation,
)
from cutlass_api.fusion.passes.smem_size_calculator import GetSmemSize
__all__ = [
"EVTGraphDrawer",
"PassGetArgumentType",
"PassDAG2Tree",
"PassGetImpl",
"PassFixElementD",
"PassLayoutManipulateElimination",
"EVTPassManager",
"PassPreprocessRed",
"PassShapeTypePropagation",
"GetSmemSize",
]


@@ -0,0 +1,133 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from __future__ import annotations
from cutlass_api.fusion.library import DataTypeTag
from cutlass_api.fusion.ir.dag_ir import DAGIR
_COLOR_MAP = {
"load": '"AliceBlue"',
"compute": "LemonChiffon1",
"accumulator": "LightGrey",
"store": "PowderBlue",
"layout": "lightseagreen",
"dag": "darkorange",
}
class EVTGraphDrawer:
"""
Visualize an EVT DAGIR with Graphviz
"""
def __init__(self, graph: DAGIR, name: str):
self._name = name
self._dot_graphs = {}
self._dot_graphs[name] = self._to_dot(graph, name)
def _get_node_style(self, node):
template = {
"shape": "record",
"fillcolor": "#CAFFE3",
"style": '"filled,rounded"',
"fontcolor": "#000000",
}
if node.op in _COLOR_MAP:
template["fillcolor"] = _COLOR_MAP[node.op]
else:
raise NotImplementedError("unknown node op")
if node.disabled:
template["fontcolor"] = "grey"
template["fillcolor"] = "white"
return template
def _get_node_label(self, node):
label = "{" + f"name={node.name}|op={node.op}"
if node.op == "layout":
label += f"|fn={node.fn.__name__}"
for key in node.kwargs:
label += f"|{key}={node.kwargs[key]}"
if node.underlying_impl is not None:
label += f"|impl={type(node.underlying_impl).__name__}"
if node.op == "load":
label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
elif node.op == "compute":
label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
elif node.op == "store":
label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
if node.tensor is not None:
shape = node.tensor.shape
stride = node.tensor.stride
label += f"|shape={shape}|stride={stride}"
if hasattr(node, "store_tensor") and node.store_tensor is not None:
store_shape = node.store_tensor.shape
store_stride = node.store_tensor.stride
label += f"|store_shape={store_shape}|store_stride={store_stride}"
label += "}"
return label
def _to_dot(self, graph: DAGIR, name: str):
import pydot
dot_graph = pydot.Dot(name, rankdir="TB")
for node in graph.nodes_meta:
style = self._get_node_style(node)
label = self._get_node_label(node)
dot_node = pydot.Node(node.name, label=label, **style)
dot_graph.add_node(dot_node)
if node.op == "dag":
dot_subgraph = self._to_dot(node.subgraph, name=node.name)
self._dot_graphs[node.name] = dot_subgraph
# Add edges
for src, dst in graph.edges:
weight = graph.get_edge_weight(src, dst)
dot_graph.add_edge(pydot.Edge(src, dst, label=weight))
return dot_graph
def get_dot_graph(self) -> "list[tuple[str, pydot.Dot]]":
return [
(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()
]
def get_dot_graph_by_name(self, name) -> "pydot.Dot":
return self._dot_graphs[name]
def get_main_dot_graph(self) -> "pydot.Dot":
return self._dot_graphs[self._name]
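`_get_node_label` above assembles a Graphviz record-shape label by joining `key=value` fields with `|` inside braces. A minimal sketch of that label format (`record_label` is an illustrative helper, not library API):

```python
def record_label(name: str, op: str, **fields) -> str:
    """Build a Graphviz record-shape label like the drawer above."""
    parts = [f"name={name}", f"op={op}"]
    # Append optional attributes (shape, stride, impl, ...) as extra fields
    parts += [f"{key}={value}" for key, value in fields.items()]
    return "{" + "|".join(parts) + "}"
```

Braces mark a record and `|` separates its cells, so each node renders as a compact attribute table in the rendered graph.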


@@ -0,0 +1,136 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Construct the epilogue visitor argument type
"""
from cutlass_api.fusion.ir import TopoVisitorNode
from cutlass_api.fusion.ir.c_types import visitor_factory
from cutlass_api.fusion.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass_api.fusion.passes.pass_get_impl import PassGetImpl
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
from cutlass_api.fusion.passes.pass_shape_type_propagation import (
PassShapeTypePropagation,
)
from cutlass_api.fusion.passes.util import cc_map
class PassGetArgumentType(EVTPassBase):
"""
Construct the epilogue visitor argument type
"""
dependencies = [
PassShapeTypePropagation, # The Layout of all nodes must be set
PassDAG2Tree, # The type of each node must be set
PassGetImpl, # The DAG subgraphs must be set
]
def requires(self) -> None:
# Check "D" is in the node list
if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")):
raise SyntaxError(
"Sm90+ EVT requires the epilogue to have a returned tensor D, "
"but the variable 'D' is not found in the return values."
)
def call(self):
nodes = self.dag_ir.nodes_topological_order()
self.argument_types = {}
for node in nodes:
meta = self.dag_ir.get_node_meta(node)
if not meta.disabled:
self.argument_types[node] = meta.underlying_impl.argument_type
if node == "D" and cc_map[self.cc] in [90, 100]:
continue
if isinstance(meta, TopoVisitorNode):
self.get_dag_argument_type(node)
else:
self.get_evt_argument_type(node)
self.cc_specific_method(self.set_argument_type)()
def get_evt_argument_type(self, node):
# Sort the input nodes by edge weight
input_types = [
self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)
]
if len(input_types) > 0:
self.argument_types[node] = visitor_factory(
input_types + [self.argument_types[node]],
self.dag_ir.get_all_inputs(node) + [node],
)
def get_dag_argument_type(self, node):
meta = self.dag_ir.get_node_meta(node)
subgraph = meta.subgraph
subgraph_nodes = subgraph.nodes_topological_order()
# Visit the unvisited nodes in subgraph
for n in subgraph_nodes:
M = subgraph.get_node_meta(n)
if M.disabled:
continue
else:
self.argument_types[n] = M.underlying_impl.argument_type
input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]]
if len(input_types) > 0:
self.argument_types[node] = visitor_factory(
input_types, subgraph_nodes[:-1]
)
def set_argument_type(self):
pass
def sm90_set_argument_type(self):
self.dag_ir.epilogue_thread_type = self.argument_types[
self.dag_ir.get_all_inputs("D")[0]
]
# Get the tensorD argument type
self.dag_ir.arg_d_type = self.dag_ir.get_node_meta(
"D"
).underlying_impl.argument_type_d
# Get the tensorC argument type
if self.dag_ir.has_node("C"):
self.dag_ir.arg_c_type = self.dag_ir.get_node_meta(
"C"
).underlying_impl.argument_type_c
else:
self.dag_ir.arg_c_type = self.dag_ir.arg_d_type
def sm100_set_argument_type(self):
self.sm90_set_argument_type()
def sm80_set_argument_type(self):
nodes = self.dag_ir.nodes_topological_order()
self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]
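The bottom-up assembly in `PassGetArgumentType` can be sketched with a toy model, assuming plain dicts in place of the real `DAGIR`/`visitor_factory` API: visiting nodes in topological order, a node with inputs nests its children's argument types followed by its own.

```python
from graphlib import TopologicalSorter

def build_argument_types(inputs, own_type):
    # `inputs` maps node -> list of its input nodes; TopologicalSorter
    # yields predecessors before their dependents.
    order = list(TopologicalSorter(inputs).static_order())
    arg_types = {}
    for node in order:
        children = inputs.get(node, [])
        if children:
            # A visitor node nests its children's argument types plus its
            # own, loosely mirroring what visitor_factory does above.
            arg_types[node] = tuple(arg_types[c] for c in children) + (own_type[node],)
        else:
            arg_types[node] = own_type[node]
    return arg_types

# Toy epilogue: D = store(mul(accum, alpha)); type names are illustrative.
types = build_argument_types(
    {"mul": ["accum", "alpha"], "D": ["mul"]},
    {"accum": "Accum", "alpha": "Scalar", "mul": "Mul", "D": "StoreD"},
)
```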

@@ -0,0 +1,176 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
by the topological visitor, while the rest of the graph will be implemented with the tree visitor.
"""
from copy import deepcopy
from cutlass_api.fusion.ir import DAGIR, TopoVisitorNode
from cutlass_api.fusion.passes.pass_get_impl import PassGetImpl
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
from cutlass_api.fusion.passes.pass_shape_type_propagation import (
PassShapeTypePropagation,
)
class PassDAG2Tree(EVTPassBase):
"""
Convert the DAG IR to Tree by fusing subgraphs
"""
dependencies = [PassShapeTypePropagation, PassGetImpl]
def call(self):
# Step 1: find the nodes that have multiple parents
multi_parent_nodes = []
for node in self.dag_ir.nodes_topological_order():
if self.dag_ir.out_degree(node) > 1:
multi_parent_nodes.append(node)
# Step 2: find the lowest common ancestor (LCA) of all its parents
for node in multi_parent_nodes:
# A multi-parent node could be already fused by the previous node
if not self.dag_ir.has_node(node):
continue
            # A node not covered by a previous fusion can still see its out-degree change:
            # Case 1: it has <= 1 edge into the previously fused subgraph: no degree change
            # Case 2: it has more than one edge into the previously fused subgraph: the degree drops
            if self.dag_ir.out_degree(node) <= 1:
                continue
            # Otherwise, the node still has multiple users and must be fused
reachable_nodes = []
# Complexity: O(Dout*N)
for parent in self.dag_ir.get_users(node):
reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
# get the common reachable objects
common_items = set.intersection(*reachable_nodes)
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
lca = None
# If common ancestor exists, find the lowest one
if len(common_items) > 0:
topo_order = self.dag_ir.nodes_topological_order()
topo_idx = -1
for item in common_items:
if lca is None:
lca = item
topo_idx = topo_order.index(item)
else:
if topo_idx > topo_order.index(item):
lca = item
topo_idx = topo_order.index(item)
else:
                # There is no common ancestor of all the parents, so we pack all the
                # reachable nodes into a single DAG node as a fallback. The lca is then
                # the input of one of the output nodes with out_degree = 0
                potential_output_nodes = []
                for candidate in node_to_fuse:
                    if self.dag_ir.out_degree(candidate) == 0:
                        potential_output_nodes.append(candidate)
if len(potential_output_nodes) == 0:
raise RuntimeError("No output node with out degree = 0 found.")
output_node = None
if self.dag_ir.cc >= 90:
# For SM90+, the lca should be the input node of D
if not self.dag_ir.has_node("D"):
raise RuntimeError("D is not a node in the DAG IR.")
output_node = "D"
else:
output_node = potential_output_nodes[0]
if output_node is None:
raise RuntimeError("No output node found.")
lca = self.dag_ir.get_all_inputs(output_node)[0]
node_to_fuse.remove(output_node)
# The lca is the output node of the DAG node
# Get the nodes to be fused
node_to_fuse.add(lca)
# Get all the input nodes
all_input_nodes = []
all_output_nodes = []
for node in node_to_fuse:
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
all_output_nodes.append(set(self.dag_ir.get_users(node)))
all_input_nodes = set.union(*all_input_nodes)
all_output_nodes = set.union(*all_output_nodes)
new_subgraph_nodes = set.union(
node_to_fuse, all_input_nodes, all_output_nodes
)
# Create the subgraph
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
subgraph = DAGIR(self.dag_ir.cc)
for node in subgraph_.nodes:
meta = deepcopy(self.dag_ir.get_node_meta(node))
if node not in node_to_fuse:
meta.disabled = True
subgraph.add_node(meta)
for edge in subgraph_.edges:
subgraph.add_edge(
edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1])
)
# Create the fused node
dag_node = TopoVisitorNode(
name=f"dag_{lca}",
subgraph=subgraph,
output_node=self.dag_ir.get_node_meta(lca),
)
self.dag_ir.add_node(dag_node)
# Add input edges
for idx, node in enumerate(all_input_nodes):
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
# Replace all uses with DAG node (only 1 output node)
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
# Remove all fused nodes
node_to_fuse.remove(lca)
for node in node_to_fuse:
self.dag_ir.remove_node(node)
def ensures(self) -> None:
# Ensure that after the pass, the resulting DAG becomes a tree
for node in self.dag_ir.nodes:
out_degree = self.dag_ir.out_degree(node)
if out_degree > 1:
raise RuntimeError(
                    f"PassDAG2Tree failed. Node {node} still has out-degree = {out_degree}"
)
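The detection step of `PassDAG2Tree` (Steps 1-2 above) can be illustrated on a plain adjacency dict, a simplification of the real `DAGIR` API: a node with out-degree > 1 triggers fusion of everything its users can reach, minus their common descendants.

```python
def reachable(succ, start):
    """All nodes reachable from `start` (inclusive) via adjacency dict `succ`."""
    seen, stack = set(), [start]
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(succ.get(n, []))
    return seen

def fusion_candidates(succ):
    """Map each multi-user node to the non-tree region that would be fused:
    the union of its users' reachable sets minus their common descendants."""
    out = {}
    for node, users in succ.items():
        if len(users) > 1:
            sets = [reachable(succ, u) for u in users]
            out[node] = set.union(*sets) - set.intersection(*sets)
    return out

# Diamond a -> {b, c} -> d: b and c form the non-tree region, d is the
# common descendant that stays outside the fused node.
candidates = fusion_candidates({"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []})
```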

@@ -0,0 +1,86 @@
"""
Fix the element_output of the producer of D.
In the Sm90 epilogue visitor, the node writing D to gmem does not have an internal
element converter, so the compute node producing D must have element_output = type(D).
"""
from cutlass_api.fusion.passes.pass_layout_elimination import (
PassLayoutManipulateElimination,
)
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
class PassFixElementD(EVTPassBase):
"""
    In the Sm90 epilogue visitor, the node writing D to gmem does not have an
    internal element converter, so the compute node producing D must have
    element_output = type(D)
"""
dependencies = [PassLayoutManipulateElimination]
def get_producer(self, node: str, element_D, visited=None):
if visited is None:
visited = set()
if node in visited:
raise RuntimeError(
f"Cycle detected while traversing to producer of D: {node}"
)
visited.add(node)
node_meta = self.dag_ir.get_node_meta(node)
if node_meta.op == "compute":
node_meta.element_output = element_D
elif node_meta.op == "store":
inputs = self.dag_ir.get_all_inputs(node)
if len(inputs) != 1:
raise RuntimeError(
f"Store node {node} has {len(inputs)} inputs, expected 1"
)
self.get_producer(inputs[0], element_D, visited)
elif node_meta.op == "load":
node_meta.element_output = element_D
else:
raise NotImplementedError(
f"Unsupported node op: {node_meta.op} when getting producer of D"
)
def call(self):
if self.dag_ir.has_node("D"):
node_d_meta = self.dag_ir.get_node_meta("D")
element_D = node_d_meta.store_tensor.element
self.get_producer("D", element_D)
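A minimal sketch of the traversal in `PassFixElementD.get_producer`, assuming a toy `name -> (op, inputs)` representation instead of the real node metadata: walk back through store nodes until a compute or load node is found; that node's `element_output` is the one pinned to the element type of D.

```python
def producer_of_d(nodes, node="D"):
    """nodes: name -> (op, [inputs]). Return the compute/load node whose
    element_output must be set to type(D)."""
    op, inputs = nodes[node]
    if op in ("compute", "load"):
        return node
    if op == "store":
        (inp,) = inputs  # a store node has exactly one input
        return producer_of_d(nodes, inp)
    raise NotImplementedError(f"Unsupported op {op}")

# Toy epilogue: D stores the result of mul, which reads the accumulator.
producer = producer_of_d(
    {"D": ("store", ["mul"]), "mul": ("compute", ["accum"]), "accum": ("load", [])}
)
```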

@@ -0,0 +1,93 @@
"""
Infer the underlying implementation of each node.
While the frontend only distinguishes between Load/Store/Compute nodes,
each of these nodes can have a different underlying implementation based
on its layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
This pass infers the underlying impl of each node.
"""
import cutlass_api.fusion.backend as evt_backend
from cutlass_api.fusion.ir import DAGIR, LoadNode
from cutlass_api.fusion.passes.pass_fix_element_d import PassFixElementD
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
from cutlass_api.fusion.passes.pass_no_op_elimination import PassNoOpElimination
from cutlass_api.fusion.passes.pass_shape_type_propagation import (
PassShapeTypePropagation,
)
from cutlass_api.fusion.passes.util import cc_map
class PassGetImpl(EVTPassBase):
"""
    While the frontend only distinguishes between Load/Store/Compute nodes,
    each of these nodes can have a different underlying implementation based
    on its layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
    This pass infers the underlying impl of each node.
"""
dependencies = [
PassShapeTypePropagation, # The shape and type info are required for inference
PassFixElementD,
]
def __init__(self, dag_ir: DAGIR) -> None:
super().__init__(dag_ir)
self.no_op_elimination = PassNoOpElimination(dag_ir)
def requires(self) -> None:
# Verify "accum" is in the arg list
if not self.dag_ir.has_node("accum"):
raise SyntaxError("Cannot find 'accum' in the argument list.")
def call(self):
# The loop structure of the epilogue is determined by the
# accumulator shape
accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
problem_size = accumulator.tensor.shape
for node_meta in self.dag_ir.node_metas_topological_order():
node_meta.get_underlying_impl(problem_size)
def ensures(self) -> None:
# Some nodes will be lowered to NoOp, eliminate them
self.no_op_elimination()
# Lower to cc-specific impl
for node_meta in self.dag_ir.nodes_meta:
node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes")
node_meta.underlying_impl = getattr(
node_impl_ccs,
f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__,
)(node_meta)
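The name-based lowering in `PassGetImpl.ensures` can be sketched as follows; `AuxLoadImpl`/`Sm90AuxLoadImpl` are illustrative stand-ins, not the real backend classes.

```python
class AuxLoadImpl:
    """Stand-in for a generic impl chosen by get_underlying_impl."""

class Sm90AuxLoadImpl:
    """Stand-in for a cc-specific impl that wraps the generic one."""
    def __init__(self, generic):
        self.generic = generic

def lower(impl, cc, namespace):
    # Resolve the cc-specific class by name, analogous to the
    # getattr(node_impl_ccs, f"Sm{cc}" + ...) lookup above.
    cls = namespace[f"Sm{cc}{type(impl).__name__}"]
    return cls(impl)

lowered = lower(AuxLoadImpl(), 90, {"Sm90AuxLoadImpl": Sm90AuxLoadImpl})
```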

@@ -0,0 +1,230 @@
"""
Eliminate layout manipulation nodes
"""
from copy import deepcopy
from cutlass_api.fusion.ir import DAGIR, LayoutNode
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
from cutlass_api.fusion.passes.pass_shape_type_propagation import (
PassShapeTypePropagation,
)
class PassLayoutManipulateElimination(EVTPassBase):
"""
Eliminate layout manipulation nodes
"""
dependencies = [PassShapeTypePropagation]
def __init__(self, dag_ir: DAGIR) -> None:
super().__init__(dag_ir)
self.copy_cnt = 0
def call(self):
        self.layout_nodes_worklist = self.get_all_layout_nodes()
        # Loop until all layout nodes are eliminated
        while len(self.layout_nodes_worklist) > 0:
            node = self.layout_nodes_worklist.pop(0)
            # Step 1: get the propagation direction
            direction = self.get_propagation_direction(node)
            self.visited = []
            getattr(self, f"propagate_to_{direction}")(
                self.dag_ir.get_node_meta(node), node
            )
            # Eliminate the current node
            input_node = self.dag_ir.get_all_inputs(node)[0]
            self.dag_ir.replace_all_uses_with(node, input_node)
def get_all_layout_nodes(self):
layout_nodes = []
for node_meta in reversed(self.dag_ir.node_metas_topological_order()):
if isinstance(node_meta, LayoutNode):
layout_nodes.append(node_meta.name)
return layout_nodes
def get_propagation_direction(self, node: str):
"""
        The logic propagates all layout nodes away from the accumulator node.
"""
self.visited = []
self.get_influenced_users(node)
nodes_influenced_dir_users = self.visited
self.visited = []
self.get_influenced_inputs(node)
nodes_influenced_dir_inputs = self.visited
if (
"accum" in nodes_influenced_dir_users
and "accum" not in nodes_influenced_dir_inputs
):
return "inputs"
elif (
"accum" not in nodes_influenced_dir_users
and "accum" in nodes_influenced_dir_inputs
):
return "users"
else:
            raise RuntimeError("Unresolved propagation direction")
# Get all influenced nodes if we propagate along the user direction
def get_influenced_users(self, node: str):
if node in self.visited:
return
self.visited.append(node)
users = self.dag_ir.get_users(node)
for user in users:
self.get_influenced_users(user)
user_inputs = []
for user in users:
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
if len(user_inputs) > 0:
user_inputs = set.union(*user_inputs)
user_inputs.remove(node)
for input_node in user_inputs:
self.get_influenced_inputs(input_node)
# Get all influenced nodes if we propagate along the input direction
def get_influenced_inputs(self, node: str):
if node in self.visited:
return
self.visited.append(node)
inputs = self.dag_ir.get_all_inputs(node)
for input_node in inputs:
self.get_influenced_inputs(input_node)
input_users = []
for input_node in inputs:
input_users.append(set(self.dag_ir.get_users(input_node)))
if len(input_users) > 0:
input_users = set.union(*input_users)
input_users.remove(node)
for user in input_users:
self.get_influenced_users(user)
def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
copied_node_meta = deepcopy(layout_node_meta)
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
self.copy_cnt += 1
copied_node_meta.name = copied_node
self.dag_ir.add_node(copied_node_meta)
# Add edges
target_inputs = self.dag_ir.get_all_inputs(target)
for src in target_inputs:
self.dag_ir.remove_edge(src, target)
self.dag_ir.add_edge(src, copied_node)
self.dag_ir.add_edge(copied_node, target)
self.layout_nodes_worklist.append(copied_node)
def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
copied_node_meta = deepcopy(layout_node_meta)
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
self.copy_cnt += 1
copied_node_meta.name = copied_node
self.dag_ir.add_node(copied_node_meta)
# Add edges
users = self.dag_ir.get_users(target)
for user in users:
self.dag_ir.remove_edge(target, user)
self.dag_ir.add_edge(copied_node, user)
self.dag_ir.add_edge(target, copied_node)
self.layout_nodes_worklist.append(copied_node)
# Propagate the layout `node` along the user direction
def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
"""
Propagate layout node to users
"""
if node in self.visited:
# Avoid applying twice
return
self.visited.append(node)
node_meta = self.dag_ir.get_node_meta(node)
if layout_node_meta.name != node:
            if isinstance(node_meta, LayoutNode):
                # A layout node cannot be propagated through another layout node
                self.add_copy_before(layout_node_meta, node)
return
else:
layout_node_meta.apply_to_user(node_meta)
users = self.dag_ir.get_users(node)
user_inputs = []
for user in users:
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
for user in users:
self.propagate_to_users(layout_node_meta, user)
if len(user_inputs) > 0:
user_inputs = set.union(*user_inputs)
user_inputs.remove(node)
for input_node in user_inputs:
self.propagate_to_inputs(
layout_node_meta.get_inverse_node(), input_node
)
# Propagate the layout `node` along the input direction
def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
"""
Propagate layout node to inputs
"""
if node in self.visited:
# Avoid applying twice
return
self.visited.append(node)
node_meta = self.dag_ir.get_node_meta(node)
if layout_node_meta.name != node:
            if isinstance(node_meta, LayoutNode):
                # A layout node cannot be propagated through another layout node
                self.add_copy_after(layout_node_meta, node)
return
else:
layout_node_meta.apply_to_input(node_meta)
inputs = self.dag_ir.get_all_inputs(node)
input_users = []
for input_node in inputs:
input_users.append(set(self.dag_ir.get_users(input_node)))
for input_node in inputs:
self.propagate_to_inputs(layout_node_meta, input_node)
if len(input_users) > 0:
input_users = set.union(*input_users)
input_users.remove(node)
for user in input_users:
self.propagate_to_users(layout_node_meta.get_inverse_node(), user)
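A condensed sketch of `get_propagation_direction`, assuming plain `preds`/`succs` adjacency dicts (the real pass computes the influenced sets by bouncing between both directions, which this simplification omits): the layout node is pushed toward whichever side does not contain `accum`.

```python
def propagation_direction(preds, succs, node, accum="accum"):
    """preds/succs: node -> list of inputs/users."""
    def closure(adj, start):
        # Transitive closure from `start` (exclusive) along `adj`.
        seen, stack = set(), list(adj.get(start, []))
        while stack:
            n = stack.pop()
            if n not in seen:
                seen.add(n)
                stack.extend(adj.get(n, []))
        return seen
    toward_inputs = closure(preds, node)
    toward_users = closure(succs, node)
    if accum in toward_users and accum not in toward_inputs:
        return "inputs"  # push the layout node toward the inputs
    if accum in toward_inputs and accum not in toward_users:
        return "users"   # push the layout node toward the users
    raise RuntimeError("Unresolved propagation direction")

# Toy chain: accum -> perm (layout node) -> D; the permute must move
# toward the users, away from the accumulator.
direction = propagation_direction(
    preds={"perm": ["accum"], "D": ["perm"]},
    succs={"accum": ["perm"], "perm": ["D"]},
    node="perm",
)
```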

@@ -0,0 +1,185 @@
"""
Pass manager for DAG IR.
"""
from typing import Any
import networkx as nx
from cutlass_api.fusion.ir import DAGIR
from cutlass_api.fusion.passes.util import cc_map
class EVTPassBase:
"""
Base class for EVT Passes
"""
dependencies = []
def __init__(self, dag_ir: DAGIR) -> None:
self.dag_ir = dag_ir
self.cc = self.dag_ir.cc
def requires(self) -> None:
"""
This function will be called before the pass is run.
"""
pass
def call(self) -> None:
"""
The pass that is run through the self.dag_ir
"""
        raise NotImplementedError(
            f"call is not overwritten in Pass {self.__class__.__name__}"
        )
def ensures(self) -> None:
"""
This function will be called after the pass is run.
"""
pass
def __call__(self) -> Any:
self.requires()
self.call()
self.ensures()
def cc_specific_method(self, func):
"""
This enables defining function that behaves differently under different cc
The simplest example of using this function is the following
.. highlight:: python
.. code-block:: python
            class ExamplePass(EVTPassBase):
                def call(self):
                    # This automatically selects the smXX_func based on the current cc
                    self.cc_specific_method(self.func)()
                # Interface func, can be empty
                def func(self):
                    pass
                # Sm90-specific func
                def sm90_func(self):
                    # sm90-specific method
                    return
                # Sm80-specific func
                def sm80_func(self):
                    # sm80-specific method
                    return
"""
func_name = f"sm{cc_map[self.cc]}_{func.__name__}"
if hasattr(self, func_name):
return getattr(self, func_name)
else:
raise NotImplementedError(
f"func {func.__name__} is not overwritten for Sm{self.cc}"
)
class EVTPassManager(nx.DiGraph):
"""
Topological-based Pass Manager.
    Each registered pass has a list of dependencies. The pass manager organizes
    the passes as a DAG and launches the compiler passes in topological order.
"""
def __new__(cls, *args, **kwargs):
# NetworkX 3.0+ changed DiGraph.__new__ to accept fewer arguments.
# Override to accept and ignore extra arguments from __init__.
return super().__new__(cls)
def __init__(self, dag_ir: DAGIR, pass_list, soft_dependencies=None):
super().__init__()
self.dag_ir = dag_ir
for pass_cls in pass_list:
self.add_pass(pass_cls)
self.passes = pass_list
self.soft_dependencies = soft_dependencies
if self.soft_dependencies is None:
self.soft_dependencies = []
self.sorted_passes = self.schedule()
def get_callable(self, pass_name):
"""
Return the callable of the pass
"""
return self.nodes[pass_name]["callable"]
def add_pass(self, pass_cls):
"""
Add a pass to the pass manager
:param pass_cls: the class of pass
:type pass_cls: derived class of EVTPassBase
"""
name = pass_cls.__name__
pass_callable = pass_cls(self.dag_ir)
self.add_node(name, callable=pass_callable)
def schedule(self):
"""
Schedule the added passes under topological order
"""
# Add edges
for pass_name in self.nodes:
callable = self.get_callable(pass_name)
for dependency_cls in callable.dependencies:
if dependency_cls not in self.passes:
if dependency_cls not in self.soft_dependencies:
raise ValueError(
f"Pass {pass_name} depends on {dependency_cls.__name__}, which is not in the pass list"
)
else:
continue
self.add_edge(dependency_cls.__name__, type(callable).__name__)
# Topological sort
return list(nx.topological_sort(self))
def __call__(self) -> Any:
"""
Launch the registered passes
"""
for pass_name in self.sorted_passes:
callable = self.get_callable(pass_name)
callable()
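The scheduling idea in `EVTPassManager` can be reproduced with the stdlib `graphlib` in place of NetworkX; `PassA`/`PassB`/`PassC` are hypothetical passes, not part of the real pass list.

```python
from graphlib import TopologicalSorter

class PassA:
    dependencies = []

class PassB:
    dependencies = [PassA]

class PassC:
    dependencies = [PassA, PassB]

def schedule(passes):
    # Each pass class maps to the set of passes it depends on; the
    # topological order guarantees dependencies run first.
    graph = {p: set(p.dependencies) for p in passes}
    return list(TopologicalSorter(graph).static_order())

order = schedule([PassC, PassA, PassB])
```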

@@ -0,0 +1,59 @@
"""
No-op elimination pass
"""
from typing import Any
from cutlass_api.fusion.ir import NoOpImpl
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
class PassNoOpElimination(EVTPassBase):
"""
    The no-op elimination pass removes nodes whose underlying impl is NoOpImpl from the DAG IR
"""
dependencies = []
def call(self) -> Any:
for node in self.dag_ir.nodes_topological_order():
node_meta = self.dag_ir.get_node_meta(node)
if isinstance(node_meta.underlying_impl, NoOpImpl):
inputs = self.dag_ir.get_all_inputs(node)
if len(inputs) != 1:
raise RuntimeError(
f"NoOp node {node} has {len(inputs)} inputs, expected 1"
)
self.dag_ir.replace_all_uses_with(node, inputs[0])
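A self-contained sketch of no-op elimination on a plain `node -> inputs` dict, mirroring what `replace_all_uses_with` achieves on the real `DAGIR`: every no-op is bypassed by rewiring its users to its single input.

```python
def eliminate_noops(preds, noops):
    """preds: node -> list of inputs. Bypass every node in `noops`
    (each must have exactly one input)."""
    def resolve(n):
        # Follow chains of no-ops down to a real node.
        while n in noops:
            (n,) = preds[n]
        return n
    return {node: [resolve(i) for i in inputs]
            for node, inputs in preds.items() if node not in noops}

# Toy graph: accum -> noop -> D collapses to accum -> D.
result = eliminate_noops({"accum": [], "noop": ["accum"], "D": ["noop"]}, {"noop"})
```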

@@ -0,0 +1,96 @@
"""
Preprocess the reduction nodes.
The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store()
This pass fuses these into a single store node, and then replaces all uses of the
current node with the new store node.
"""
from cutlass_api.fusion.ir import ComputeNode, StoreNode
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
class PassPreprocessRed(EVTPassBase):
"""
Preprocess red nodes
"""
def call(self):
# Step 1: find the compute nodes with op=red
red_compute_nodes = []
for node_meta in self.dag_ir.nodes_meta:
            if isinstance(node_meta, ComputeNode) and isinstance(node_meta.fn, tuple):
# To keep the frontend simple, the parser emits reduction
# nodes as compute nodes by default. The heuristic for
# telling them apart: a compute node holds a single function,
# while a reduction node holds a tuple of two functions, one
# for the in-register reduction and one for the atomic
# global memory reduction
red_compute_nodes.append(node_meta.name)
# Step 2: for each compute, merge it with the succeeding store
for node in red_compute_nodes:
# Verify
users = self.dag_ir.get_users(node)
inputs = self.dag_ir.get_all_inputs(node)
# Must have exactly one user and one input
assert len(users) == 1
assert len(inputs) == 1
user = users[0]
input_node = inputs[0]
user_meta = self.dag_ir.get_node_meta(user)
# Must be a store node
assert isinstance(user_meta, StoreNode)
# With output degree == 0
assert self.dag_ir.out_degree(user) == 0
# Register the reduce op
node_meta = self.dag_ir.get_node_meta(node)
user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
user_meta.element_compute = node_meta.element_compute
user_meta.round_style = node_meta.round_style
# Replace all uses
self.dag_ir.remove_edge(input_node, node)
input_users = self.dag_ir.get_users(input_node)
for iu in input_users:
weight = self.dag_ir.get_edge_weight(input_node, iu)
self.dag_ir.add_edge(user, iu, weight)
self.dag_ir.remove_edge(input_node, iu)
self.dag_ir.add_edge(input_node, user)
self.dag_ir.remove_node(node)
# Register the reduction name
self.dag_ir.reduction_names.append(user)

View File

@@ -0,0 +1,60 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Shape and type propagation pass
"""
from cutlass_api.fusion.ir.node import NodeBase
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
from cutlass_api.fusion.passes.pass_preprocess_red import PassPreprocessRed
class PassShapeTypePropagation(EVTPassBase):
"""
Propagate the shape and type of all nodes
"""
dependencies = [PassPreprocessRed]
def call(self):
# Propagate the node shape and type
for node in self.dag_ir.nodes_topological_order():
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
node_meta.type_propagation(input_node_metas)
node_meta.shape_propagation(input_node_metas)
for node in reversed(self.dag_ir.nodes_topological_order()):
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
node_meta.broadcast_propagation(input_node_metas)

View File

@@ -0,0 +1,363 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Compute the shared memory size in bytes
"""
from math import gcd
from cutlass_api.fusion.library import DataType, DataTypeSize, EpilogueScheduleType
from cutlass_api.fusion.ir import TopoVisitorNode, DAGIR
from cutlass_api.fusion.pycute import flatten, shape_div, product
class GetSmemSize:
"""
Get the size in bytes of the shared memory used by the kernel
"""
def __init__(self, dag_ir: DAGIR) -> None:
self.dag_ir = dag_ir
self.cc = self.dag_ir.cc
#
# Sm90 epilogue specific
#
def sm90_epilogue_tile(self, tile_description):
# Get the epilogue tile size
schedule = tile_description.epilogue_schedule
if schedule == EpilogueScheduleType.TmaWarpSpecialized:
element_d = self.dag_ir.get_node_meta("D").element
nperf = (
64
if (
DataTypeSize[element_d] == 8
and tile_description.threadblock_shape[1] % 64 == 0
)
else 32
)
epi_tile_m = min(64, tile_description.threadblock_shape[0])
epi_tile_n = gcd(
min(nperf, tile_description.threadblock_shape[1]),
tile_description.threadblock_shape[1],
)
epilogue_tile_mn = (epi_tile_m, epi_tile_n)
elif schedule == EpilogueScheduleType.TmaWarpSpecializedCooperative:
epi_tile_m = min(128, tile_description.threadblock_shape[0])
epi_tile_n = gcd(
min(32, tile_description.threadblock_shape[1]),
tile_description.threadblock_shape[1],
)
epilogue_tile_mn = (epi_tile_m, epi_tile_n)
else:
raise NotImplementedError(f"Unsupported schedule: {schedule}")
# Get the pipeline stages
stages_d = 2
epi_tiles = product(
shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)
)
if self.dag_ir.has_node("C"):
element_c = self.dag_ir.get_node_meta("C").element
else:
element_c = None
element_d = self.dag_ir.get_node_meta("D").element
if element_c == element_d:
reuse_smem_c = True
else:
reuse_smem_c = False
stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles
# Record the epilogue tile
self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
self.epilogue_tile_mn = epilogue_tile_mn
self.epi_tiles = epi_tiles
self.stages_c = stages_c
self.stages_d = stages_d
self.reuse_smem_c = reuse_smem_c
self.element_c = element_c
self.element_d = element_d
self.is_source_supported = element_c is not None
def sm90_or_sm100_epilogue_smem_size(self, tile_description):
# Get the Fusion Storage
nodes = self.dag_ir.nodes_topological_order()
self.smem_types = {}
for node in nodes:
meta = self.dag_ir.get_node_meta(node)
if not meta.disabled:
self.smem_types[node] = meta.underlying_impl.get_smem_size(
self.cta_tile_mnk,
self.epilogue_tile_mn,
self.stages_c,
self.stages_d,
self.epi_tiles,
)
if node == "D":
continue
if isinstance(meta, TopoVisitorNode):
self.get_dag_smem_type(node)
else:
self.get_evt_smem_type(node)
thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
# Get the Tensor Storage
tensors = []
if self.is_source_supported:
smem_C = (
DataTypeSize[self.element_c]
* product(self.epilogue_tile_mn)
* self.stages_c
// 8
)
tensors.append((smem_C, 128))
else:
tensors.append((0, 1))
if self.reuse_smem_c:
tensors.append((0, 128))
else:
smem_D = (
DataTypeSize[self.element_d]
* product(self.epilogue_tile_mn)
* self.stages_d
// 8
)
tensors.append((smem_D, 128))
tensors.append((thread_smem_size, 128))
tensor_smem_size = self.get_struct_size(tensors)
# Get pipeline storage size
# sizeof(uint64_t * stages_c * 2), alignment of uint64_t
# 2 is for FullBarrier and EmptyBarrier
pipeline_smem_size = (8 * self.stages_c * 2, 8)
# get SharedStorage size
smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
return smem_size[0]
def sm90_epilogue_smem_size(self, tile_description):
"""
Compute the shared memory size of sm90 collective epilogue
"""
self.sm90_epilogue_tile(tile_description)
return self.sm90_or_sm100_epilogue_smem_size(tile_description)
#
# Sm100 epilogue specific
#
def sm100_epilogue_tile(self, tile_description):
cta_tile = (
tile_description.blackwell_threadblock_shape[0],
tile_description.blackwell_threadblock_shape[1],
)
mma_tile = cta_tile
if tile_description.is_2sm:
cta_tile = (cta_tile[0] // 2, cta_tile[1])
if tile_description.is_2sm and mma_tile[0] == 128:
tmem_warps = (2, 2)
else:
tmem_warps = (4, 1)
if self.dag_ir.has_node("C"):
element_c = self.dag_ir.get_node_meta("C").element
element_c_size = DataTypeSize[element_c]
else:
element_c = None
element_c_size = 0
element_d = self.dag_ir.get_node_meta("D").element
DisableSource = (
element_c is None
or not self.dag_ir.has_node("C")
or self.dag_ir.get_node_meta("C").element == DataType.void
)
CtaM = cta_tile[0]
CtaN = cta_tile[1]
WarpM = tmem_warps[0]
WarpN = tmem_warps[1]
MaxBits = max(element_c_size, DataTypeSize[element_d])
DpFull = 32
M = min(CtaM, DpFull * WarpM)
if DisableSource:
# Epilogues w/o residual load are less sensitive to smem allocation
# Target a fixed amount of compute per epilogue iteration
if MaxBits == 4:
# Make epilogue tile larger to reduce the epilogue iterations.
# 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
ComputeElts = 8192
Nperf = ComputeElts // M
else:
ComputeElts = 4096
Nperf = ComputeElts // M
else:
# Epilogues w/ residual load are more sensitive to smem allocation
# Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
if MaxBits == 32:
Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32
elif MaxBits == 16:
Nperf = 32 if CtaN <= 128 else 64
else:
Nperf = 64
def is_m_major(layout):
return flatten(layout.stride[0]) == 1
if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout):
N_min_C = 8 * WarpN
elif element_c_size == 6:
N_min_C = 128 * WarpN
else:
N_min_C = (128 // element_c_size) * WarpN
if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout):
N_min_D = 8 * WarpN
elif DataTypeSize[element_d] == 6:
N_min_D = 128 * WarpN
else:
N_min_D = (128 // DataTypeSize[element_d]) * WarpN
N = min(CtaN, max(Nperf, N_min_C, N_min_D))
tile_m = M
tile_n_size = N // WarpN * WarpN
epilogue_tile_mn = (tile_m, tile_n_size)
epi_tiles = product(
shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)
)
stages_d = min(epi_tiles, 2)
reuse_smem_c = element_c_size > 8
if reuse_smem_c:
stages_c = max(min(epi_tiles, 4), stages_d + 1)
else:
stages_c = min(epi_tiles, 4)
# Record the epilogue tile
self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
self.epilogue_tile_mn = epilogue_tile_mn
self.epi_tiles = epi_tiles
self.stages_c = stages_c
self.stages_d = stages_d
self.reuse_smem_c = reuse_smem_c
self.element_c = element_c
self.element_d = element_d
self.is_source_supported = not DisableSource
def sm100_epilogue_smem_size(self, tile_description):
"""
Compute the shared memory size of sm100 collective epilogue
"""
self.sm100_epilogue_tile(tile_description)
return self.sm90_or_sm100_epilogue_smem_size(tile_description)
def __call__(self, tile_description):
return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)
#
# Helper functions
#
@staticmethod
def get_visitor_size(members: list, ebo: bool):
"""
Get the size and alignment, in bytes, of a struct built from the given (size, alignment) members
"""
offset = 0
max_alignment = 1
if len(members) > 0:
# Get alignment
for _, alignment in members:
max_alignment = max(max_alignment, alignment)
for type_size, _ in members:
if type_size != 0:
offset = (
(offset + max_alignment - 1) // max_alignment
) * max_alignment
if type_size == 0 and not ebo:
offset += 1
else:
offset += type_size
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
return (offset, max_alignment)
else:
# Struct size is at least 1
return (1, 1)
def get_struct_size(self, members: list):
"""
Get the size of a struct in bytes (without empty-base optimization)
"""
return self.get_visitor_size(members, False)
def get_evt_smem_type(self, node):
# Sort the input nodes by edge weight
input_types = [
self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)
]
input_types.append(self.smem_types[node])
if len(input_types) > 1:
ebo = len(input_types) > 4
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
def get_dag_smem_type(self, node):
meta = self.dag_ir.get_node_meta(node)
subgraph = meta.subgraph
subgraph_nodes = subgraph.nodes_topological_order()
# Visit the unvisited nodes in subgraph
for n in subgraph_nodes:
M = subgraph.get_node_meta(n)
if M.disabled:
continue
else:
self.smem_types[n] = M.underlying_impl.get_smem_size(
self.cta_tile_mnk,
self.epilogue_tile_mn,
self.stages_c,
self.stages_d,
self.epi_tiles,
)
input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]]
if len(input_types) > 0:
ebo = len(input_types) > 4
self.smem_types[node] = self.get_visitor_size(input_types, ebo)

View File

@@ -0,0 +1,46 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for passes
"""
# Map from the CC of the kernel to the EVT implementation that the CC targets
cc_map = {
80: 80,
86: 80,
89: 80,
90: 90,
100: 100,
101: 100,
103: 100,
}

View File

@@ -0,0 +1,36 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from .int_tuple import *
from .layout import *
from .swizzle import *
from .typing import *

View File

@@ -0,0 +1,229 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Functions for manipulating IntTuples
"""
from functools import reduce
from itertools import chain
from typing import Union
from .typing import Integer
def is_int(x):
return isinstance(x, Integer)
def is_tuple(x):
return isinstance(x, tuple)
def flatten(t):
if is_tuple(t):
if len(t) == 0:
return ()
else:
return tuple(i for a in t for i in flatten(a))
else:
return (t,)
def signum(a):
return bool(a > 0) - bool(a < 0)
def product(a):
if is_tuple(a):
return reduce(lambda val, elem: val * product(elem), a, 1)
else:
return a
def inner_product(a, b):
if is_tuple(a): # tuple tuple
assert len(a) == len(b)
return sum(inner_product(x, y) for x, y in zip(a, b))
else: # "int" "int"
assert not is_tuple(b)
return a * b
def tuple_max(a):
if is_tuple(a):
return max(tuple_max(x) for x in a)
else:
return a
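# Worked examples (illustrative; the values follow directly from the definitions above):
#   flatten((1, (2, (3, 4))))     == (1, 2, 3, 4)
#   product(((2, 3), 4))          == 24
#   inner_product((1, 2), (3, 4)) == 11
#   tuple_max((2, (5, 3)))        == 5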
def elem_scale(a, b):
if is_tuple(a):
if is_tuple(b): # tuple tuple
assert len(a) == len(b)
return tuple(elem_scale(x, y) for x, y in zip(a, b))
else: # tuple "int"
assert False # Error
else:
if is_tuple(b): # "int" tuple
return elem_scale(a, product(b))
else: # "int" "int"
return a * b
# Inclusive prefix ceil div with output congruent to input a
def shape_div(a, b):
if is_tuple(a):
if is_tuple(b): # tuple tuple
assert len(a) == len(b)
return tuple(shape_div(x, y) for x, y in zip(a, b))
else: # tuple "int"
# r = [shape_div(a[0],b)] + [shape_div(a[i],b := shape_div(b, product(a[i-1]))) for i in range(1,len(a))]
r = []
for v in a:
r.append(shape_div(v, b))
b = shape_div(b, product(v))
return tuple(r)
else:
if is_tuple(b): # "int" tuple
return shape_div(a, product(b))
else: # "int" "int"
assert a % b == 0 or b % a == 0
return (a + b - 1) // b
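# Worked examples (illustrative): division distributes across the modes of a,
# each mode absorbing whatever factor of b it can:
#   shape_div((6, 4), (2, 1)) == (3, 4)
#   shape_div((6, 4), 2)      == (3, 4)   # the 2 is absorbed by the first mode
#   shape_div(6, (2, 3))      == 1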
# Exclusive prefix product with output congruent to input a
def prefix_product(a, init=1):
if is_tuple(a):
if is_tuple(init): # tuple tuple
assert len(a) == len(init)
return tuple(prefix_product(x, i) for x, i in zip(a, init))
else: # tuple "int"
# r = [prefix_product(a[0],init)] + [prefix_product(a[i],init := init * product(a[i-1])) for i in range(1,len(a))]
r = []
for v in a:
r.append(prefix_product(v, init))
init = init * product(v)
return tuple(r)
else:
if is_tuple(init): # "int" tuple
assert False # Error
else: # "int" "int"
return init
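# Worked examples (illustrative): exclusive prefix products, i.e. the
# column-major strides implied by a shape:
#   prefix_product((2, 3, 4))   == (1, 2, 6)
#   prefix_product(((2, 3), 4)) == ((1, 2), 6)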
def idx2crd(idx, shape, stride=None):
if stride is None:
stride = prefix_product(shape)
if is_tuple(idx):
if is_tuple(shape): # tuple tuple tuple
assert len(idx) == len(shape) and len(idx) == len(stride)
return tuple(idx2crd(i, s, d) for i, s, d in zip(idx, shape, stride))
else: # tuple "int" "int"
assert False # Error
else:
if is_tuple(shape): # "int" tuple tuple
assert len(shape) == len(stride)
return tuple(idx2crd(idx, s, d) for s, d in zip(shape, stride))
else: # "int" "int" "int"
return (idx // stride) % shape
def crd2idx(crd, shape, stride=None):
if stride is None:
stride = prefix_product(shape)
if is_tuple(crd):
if is_tuple(shape): # tuple tuple tuple
assert len(crd) == len(shape) and len(crd) == len(stride)
return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride))
else: # tuple "int" "int"
assert False, f"crd={crd}, shape={shape}" # Error
else:
if crd is None:
crd = 0
if is_tuple(shape): # "int" tuple tuple
assert len(shape) == len(stride)
result = 0
for i in range(len(shape) - 1):
result += crd2idx(crd % product(shape[i]), shape[i], stride[i])
crd = crd // product(shape[i])
return result + crd2idx(crd, shape[-1], stride[-1])
else: # "int" "int" "int"
return crd * stride
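# Worked examples (illustrative): with the default column-major strides,
# idx2crd and crd2idx are inverses of each other:
#   idx2crd(5, (2, 3))      == (1, 2)
#   crd2idx((1, 2), (2, 3)) == 5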
# Transform crd into the dst_shape's iteration space
def crd2crd(crd, dst_shape, src_shape=None):
if is_tuple(crd):
if is_tuple(dst_shape): # tuple tuple
assert len(crd) == len(dst_shape)
return tuple(crd2crd(x, y) for x, y in zip(crd, dst_shape))
else: # tuple "int"
# Ambiguous unless we have src_shape
assert src_shape is not None
return crd2idx(crd, src_shape)
else:
if is_tuple(dst_shape): # "int" tuple
return idx2crd(crd, dst_shape)
else: # "int" "int"
assert crd < dst_shape
return crd
# Filter trg according to crd: keep only elements of trg that are paired with None
def slice_(crd: Union[None, tuple, int], trg: Union[tuple, int]):
if is_tuple(crd):
if is_tuple(trg): # tuple tuple
assert len(crd) == len(trg)
# match C++ behavior of `filter_tuple` using `tuple_cat(...)`
return tuple(
chain(
*filter(lambda x: x != (), [slice_(c, s) for c, s in zip(crd, trg)])
)
)
else:
assert False # tuple "int" : Error
elif crd is None:
# match C++ behavior `return cute::tuple<B>{b};`
return (trg,)
else:
return ()
# Determine if None appears at any of an int_tuple's terminals
def has_none(a: Union[None, tuple, int]):
if is_tuple(a):
return any(has_none(v) for v in a)
else:
return a is None

View File

@@ -0,0 +1,410 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Definition of CuTe Layouts and functions to manipulate them
"""
from itertools import chain
from .int_tuple import *
class LayoutBase:
pass
def is_layout(x):
return isinstance(x, LayoutBase)
class Layout(LayoutBase):
def __init__(self, _shape, _stride=None):
self.shape = _shape
if _stride is None:
self.stride = prefix_product(self.shape)
else:
self.stride = _stride
# operator ==
def __eq__(self, other):
return self.shape == other.shape and self.stride == other.stride
# operator len(L) (len [rank] like tuples)
def __len__(self):
if is_tuple(self.shape):
return len(self.shape)
else:
return 1
# operator () (map coord to idx)
def __call__(self, *args):
"""
Map a logical coordinate to a linear index (when the coordinate contains no None entries)
OR
Slice the layout and return the sublayout (None plays the role of cute's Underscore slice marker)
Follows the behavior of `Layout::operator()(Coord const&)` in cute C++
"""
if has_none(args):
if len(args) == 1:
return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride))
else:
return Layout(slice_(args, self.shape), slice_(args, self.stride))
else:
if len(args) == 1:
return crd2idx(args[0], self.shape, self.stride)
else:
return crd2idx(args, self.shape, self.stride)
# operator [] (get-i like tuples)
def __getitem__(self, i):
if is_tuple(self.shape):
return Layout(self.shape[i], self.stride[i])
else:
assert i == 0
return Layout(self.shape, self.stride)
# size(layout) Size of the domain
def size(self):
return product(self.shape)
# cosize(layout) Size of the codomain
def cosize(self):
return self(self.size() - 1) + 1
# print and str
def __str__(self):
return f"{self.shape}:{self.stride}"
# error msgs and representation
def __repr__(self):
return f"Layout({self.shape},{self.stride})"
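# Illustrative usage (values follow from crd2idx/prefix_product above):
#   l = Layout((4, 2), (2, 1))  # shape (4, 2) with strides (2, 1)
#   l((1, 1)) == 3              # coordinate -> linear index: 1*2 + 1*1
#   l.size() == 8               # size of the domain
#   l.cosize() == 8             # l(7) + 1; this layout is a bijection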
# Make a Layout from a list of layouts (each layout becomes its own mode in the result)
def make_layout(*layouts):
if len(layouts) == 1 and not is_layout(layouts[0]):
layouts = layouts[0]
shape, stride = zip(*((a.shape, a.stride) for a in layouts))
return Layout(shape, stride)
# Size of the domain
def size(layout):
if is_layout(layout):
return layout.size()
return product(layout)
# Size of the codomain
def cosize(layout):
return layout.cosize()
# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function
def coalesce(layout, profile=None):
if is_tuple(profile):
assert len(layout) >= len(profile)
return make_layout(
chain(
(coalesce(layout[i], profile[i]) for i in range(0, len(profile))),
(layout[i] for i in range(len(profile), len(layout))),
)
)
result_shape = [1]
result_stride = [0]
for shape, stride in zip(flatten(layout.shape), flatten(layout.stride)):
# skip their shape-1s
if shape == 1:
continue
# replace our shape-1 with anything
elif result_shape[-1] == 1:
result_shape[-1] = shape
result_stride[-1] = stride
# merge modes if the shape*stride match
elif result_shape[-1] * result_stride[-1] == stride:
result_shape[-1] = result_shape[-1] * shape
# append a new mode
else:
result_shape.append(shape)
result_stride.append(stride)
if len(result_shape) == 1:
return Layout(result_shape[0], result_stride[0])
else:
return Layout(tuple(result_shape), tuple(result_stride))
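# Worked example (illustrative): coalesce merges contiguous modes without
# changing the underlying int-to-int function:
#   coalesce(Layout((2, (1, 6)), (1, (6, 2)))) == Layout(12, 1)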
# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them
def filter(layout, profile=None):
if is_tuple(profile):
assert len(layout) >= len(profile)
return make_layout(
chain(
(filter(layout[i], profile[i]) for i in range(0, len(profile))),
(layout[i] for i in range(len(profile), len(layout))),
)
)
result_shape = []
result_stride = []
for shape, stride in zip(flatten(layout.shape), flatten(layout.stride)):
# skip their shape-1s and stride-0s
if not (shape == 1 or stride == 0):
result_shape.append(shape)
result_stride.append(stride)
if len(result_shape) == 0:
return Layout(1, 0)
else:
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
# Layout composition
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
def composition(layoutA, layoutB):
if layoutB is None:
return layoutA
elif is_int(layoutB):
return composition(layoutA, Layout(layoutB))
elif is_tuple(layoutB):
assert len(layoutA) >= len(layoutB)
return make_layout(
chain(
(composition(layoutA[i], layoutB[i]) for i in range(0, len(layoutB))),
(layoutA[i] for i in range(len(layoutB), len(layoutA))),
)
)
elif is_tuple(layoutB.shape):
return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB)
if layoutB.stride == 0:
return Layout(layoutB.shape, 0)
else:
result_shape = []
result_stride = []
rest_shape = layoutB.shape
rest_stride = layoutB.stride
flat_A = coalesce(layoutA)
for curr_shape, curr_stride in zip(
flatten(flat_A.shape)[:-1], flatten(flat_A.stride)[:-1]
):
assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0
new_shape = min(max(1, curr_shape // rest_stride), rest_shape)
if new_shape != 1:
result_shape.append(new_shape)
result_stride.append(rest_stride * curr_stride)
rest_shape = rest_shape // new_shape
rest_stride = -(
-rest_stride // curr_shape
) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride)
if rest_shape != 1 or len(result_shape) == 0:
result_shape.append(rest_shape)
result_stride.append(rest_stride * flatten(flat_A.stride)[-1])
if len(result_shape) == 1:
return Layout(result_shape[0], result_stride[0])
else:
return Layout(tuple(result_shape), tuple(result_stride))
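As a sanity check, here is a minimal standalone sketch (flat layouts only, no nested modes; names are illustrative) of the property composition computes: `(A o B)(i) == A(B(i))`.

```python
# Evaluate a flat layout as a function: decode idx column-major into
# coordinates, then take the inner product with the strides.
def layout_fn(shape, stride, idx):
    result = 0
    for s, d in zip(shape, stride):
        result += (idx % s) * d
        idx //= s
    return result

# composition(20:2, 4:5) = 4:10 per the algorithm above; verify that the
# composed layout agrees with applying B and then A.
A = ((20,), (2,))
B = ((4,), (5,))
R = ((4,), (10,))
print(all(layout_fn(*R, i) == layout_fn(*A, layout_fn(*B, i)) for i in range(4)))
```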
# Layout complement
def complement(layout, max_idx=1):
if is_int(layout):
return complement(Layout(layout))
result_shape = []
result_stride = []
current_idx = 1
sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape)))
for stride, shape in sorted_DS:
if stride == 0 or shape == 1:
continue
in_bound = current_idx <= shape * stride
# To support symbolic value which can't be evaluated now
assert (type(in_bound) is not bool) or in_bound
result_shape.append(stride // current_idx)
result_stride.append(current_idx)
current_idx = shape * stride
result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div
result_stride.append(current_idx)
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
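For intuition, a standalone sketch of the loop above, restricted to flat integer shapes/strides and omitting the final coalesce and the symbolic-bound assert (names are illustrative):

```python
def flat_complement(shape, stride, max_idx):
    # Walk the modes in order of increasing stride, emitting the shapes and
    # strides needed to enumerate the indices the layout does not touch.
    result_shape, result_stride, current_idx = [], [], 1
    for d, s in sorted(zip(stride, shape)):
        if d == 0 or s == 1:
            continue
        result_shape.append(d // current_idx)
        result_stride.append(current_idx)
        current_idx = s * d
    result_shape.append((max_idx + current_idx - 1) // current_idx)  # ceil_div
    result_stride.append(current_idx)
    return tuple(result_shape), tuple(result_stride)

# Layout 4:2 touches offsets {0, 2, 4, 6}; its complement under max_idx=16 is
# (2,2):(1,8), which enumerates the remaining offsets.
print(flat_complement((4,), (2,), 16))
```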
# Layout right inverse
def right_inverse(layout):
if layout is None:
return None
elif is_int(layout):
return Layout(layout)
result_shape = []
result_stride = []
current_idx = 1
flat_shape = flatten(layout.shape)
flat_stride = flatten(layout.stride)
sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape)))
for stride, shape, rstride in sorted_DSA:
if shape == 1:
continue
if current_idx != stride:
break
result_shape.append(shape)
result_stride.append(rstride)
current_idx = shape * stride
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
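A small standalone check (flat layouts; names are illustrative) of the defining property of right_inverse: composing the layout after its right inverse yields the identity on the codomain.

```python
def layout_fn(shape, stride, idx):
    # Evaluate a flat layout: decode idx column-major, dot with strides.
    out = 0
    for s, d in zip(shape, stride):
        out += (idx % s) * d
        idx //= s
    return out

L = ((2, 4), (4, 1))  # coordinate (i, j) maps to 4*i + j
R = ((4, 2), (2, 1))  # right_inverse(L) per the algorithm above
print(all(layout_fn(*L, layout_fn(*R, k)) == k for k in range(8)))
```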
# Layout left inverse
def left_inverse(layout):
if layout is None:
return None
elif is_int(layout):
return Layout(layout)
return right_inverse(make_layout(layout, complement(layout)))
# Split a layout by the composition of B and the "rest"
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
def logical_divide(layoutA, layoutB):
if layoutB is None:
return layoutA
elif is_int(layoutB):
return logical_divide(layoutA, Layout(layoutB))
elif is_tuple(layoutB):
assert len(layoutA) >= len(layoutB)
return make_layout(
chain(
(
logical_divide(layoutA[i], layoutB[i])
for i in range(0, len(layoutB))
),
(layoutA[i] for i in range(len(layoutB), len(layoutA))),
)
)
return composition(
layoutA, make_layout(layoutB, complement(layoutB, size(layoutA)))
)
# Reproduce a layoutA over a layoutB
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
def logical_product(layoutA, layoutB):
if layoutB is None:
return layoutA
elif is_int(layoutB):
        return logical_product(layoutA, Layout(layoutB))
elif is_tuple(layoutB):
assert len(layoutA) >= len(layoutB)
return make_layout(
chain(
(
logical_product(layoutA[i], layoutB[i])
for i in range(0, len(layoutB))
),
(layoutA[i] for i in range(len(layoutB), len(layoutA))),
)
)
return make_layout(
layoutA,
composition(complement(layoutA, size(layoutA) * cosize(layoutB)), layoutB),
)
# Gather the modes from a hierarchical logical_divide or logical_product
def hier_unzip(splitter, layoutA, layoutB):
if layoutB is None:
return make_layout(Layout(1, 0), layoutA)
elif is_tuple(layoutB):
assert len(layoutA) >= len(layoutB)
# A layout with shape ((A,a),(B,b),(C,c))
split = make_layout(
hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(0, len(layoutB))
)
# Gather to shape ((A,B,C,...),(a,b,c,...,y,z))
return make_layout(
make_layout(split[i][0] for i in range(0, len(layoutB))),
make_layout(
chain(
(split[i][1] for i in range(0, len(layoutB))),
(layoutA[i] for i in range(len(layoutB), len(layoutA))),
)
),
)
# splitter must return a rank-2 layout
return splitter(layoutA, layoutB)
# Apply logical divide hierarchically and gather the split modes into two modes
def zipped_divide(layoutA, layoutB):
return hier_unzip(logical_divide, layoutA, layoutB)
# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode
def tiled_divide(layoutA, layoutB):
result = zipped_divide(layoutA, layoutB)
return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))])
# Apply logical product hierarchically and gather the split modes into two modes
def zipped_product(layoutA, layoutB):
return hier_unzip(logical_product, layoutA, layoutB)
# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode
def tiled_product(layoutA, layoutB):
result = zipped_product(layoutA, layoutB)
return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))])
def slice_and_offset(crd: tuple, layout: Layout):
return (
Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)),
crd2idx(crd, layout.shape, layout.stride),
)


@@ -0,0 +1,133 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Methods for layout swizzling
"""
from .layout import *
def shiftr(a, s):
    # s == 0 is a no-op shift; without ">=", shiftr/shiftl would recurse forever
    return a >> s if s >= 0 else shiftl(a, -s)
def shiftl(a, s):
    return a << s if s >= 0 else shiftr(a, -s)
## A generic Swizzle functor
# 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
#                                ^--^  Base is the number of least-sig bits to keep constant
#                   ^-^       ^-^      Bits is the number of bits in the mask
#                     ^---------^      Shift is the distance to shift the YYY mask
#                                        (pos shifts YYY to the right, neg shifts YYY to the left)
#
# e.g. Given
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
# the result is
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
#
class Swizzle:
def __init__(self, bits, base, shift):
assert bits >= 0
assert base >= 0
assert abs(shift) >= bits
self.bits = bits
self.base = base
self.shift = shift
bit_msk = (1 << bits) - 1
self.yyy_msk = bit_msk << (base + max(0, shift))
self.zzz_msk = bit_msk << (base - min(0, shift))
# operator () (transform integer)
def __call__(self, offset):
return offset ^ shiftr(offset & self.yyy_msk, self.shift)
# Size of the domain
def size(self):
return 1 << (self.bits + self.base + abs(self.shift))
# Size of the codomain
def cosize(self):
return self.size()
# print and str
def __str__(self):
return f"SW_{self.bits}_{self.base}_{self.shift}"
# error msgs and representation
def __repr__(self):
return f"Swizzle({self.bits},{self.base},{self.shift})"
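Because `abs(shift) >= bits` guarantees the YYY and ZZZ masks never overlap, the swizzle is its own inverse. A standalone sketch (self-contained copy of the functor above) demonstrating this involution:

```python
class Swizzle:
    def __init__(self, bits, base, shift):
        assert bits >= 0 and base >= 0 and abs(shift) >= bits
        self.bits, self.base, self.shift = bits, base, shift
        bit_msk = (1 << bits) - 1
        self.yyy_msk = bit_msk << (base + max(0, shift))
        self.zzz_msk = bit_msk << (base - min(0, shift))

    def __call__(self, offset):
        # XOR the ZZZ bits with the YYY bits shifted down onto them
        yyy = offset & self.yyy_msk
        shifted = yyy >> self.shift if self.shift > 0 else yyy << -self.shift
        return offset ^ shifted

    def size(self):
        return 1 << (self.bits + self.base + abs(self.shift))

sw = Swizzle(2, 3, 3)
# Applying the swizzle twice restores the original offset, so a swizzled
# layout remains a bijection.
print(all(sw(sw(i)) == i for i in range(sw.size())))
```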
class ComposedLayout(LayoutBase):
def __init__(self, layoutB, offset, layoutA):
self.layoutB = layoutB
self.offset = offset
self.layoutA = layoutA
# operator ==
def __eq__(self, other):
return (
self.layoutB == other.layoutB
and self.offset == other.offset
and self.layoutA == other.layoutA
)
# operator len(L) (len [rank] like tuples)
def __len__(self):
return len(self.layoutA)
# operator () (map coord to idx)
def __call__(self, *args):
return self.layoutB(self.offset + self.layoutA(*args))
# operator [] (get-i like tuples)
def __getitem__(self, i):
return ComposedLayout(self.layoutB, self.offset, self.layoutA[i])
# size(layout) Size of the domain
def size(self):
return size(self.layoutA)
# cosize(layout) Size of the codomain
def cosize(self):
return cosize(self.layoutB)
# print and str
def __str__(self):
return f"{self.layoutB} o {self.offset} o {self.layoutA}"
# error msgs and representation
def __repr__(self):
return f"ComposedLayout({repr(self.layoutB)},{repr(self.offset)},{repr(self.layoutA)})"


@@ -0,0 +1,42 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from abc import ABC
class Integer(ABC):
@classmethod
def __subclasshook__(cls, c):
if c in [bool, float]:
return False
return issubclass(c, int)
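A quick standalone demonstration of the hook above: `bool` is explicitly rejected even though it subclasses `int` in Python, and floats never qualify.

```python
from abc import ABC

class Integer(ABC):
    @classmethod
    def __subclasshook__(cls, c):
        # bool subclasses int in Python, so it must be excluded explicitly
        if c in [bool, float]:
            return False
        return issubclass(c, int)

print(isinstance(5, Integer), isinstance(True, Integer), isinstance(1.0, Integer))
# -> True False False
```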


@@ -0,0 +1,196 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import final
import cuda.bindings.driver as cuda
from cutlass_api.arguments import EpilogueArguments, RuntimeArguments
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import KernelMetadata
from cutlass_api.status import Status
class Kernel(ABC):
"""
Base class for all kernels to be implemented in providers
"""
@final
def supports(self, args: RuntimeArguments) -> Status:
"""
Returns whether the kernel can be compiled or run with the provided arguments
:param args: arguments with which the kernel is to be compiled or run
:type args: RuntimeArguments
:return: Status indicating whether the kernel can be compiled or run with the provided arguments
:rtype: Status
"""
# Metadata should capture most common checks
if not (status := self.metadata.supports(args)):
return status
# Additional checks can be implemented in the subclass
return self._supports(args)
@abstractmethod
def compile(
self, args: RuntimeArguments, cc: int | None = None
) -> CompiledArtifact:
"""
Compiles the kernel.
For just-in-time compilation, the CompiledArtifact can be used to execute the kernel using run().
For ahead-of-time compilation, the CompiledArtifact can be saved and restored later to use with run().
:param args: Arguments to compile the kernel with. These need not be the same arguments as those passed to the run() method.
:type args: RuntimeArguments
:param cc: Compute capability of device for which the kernel is to be compiled. For example, if running on H100, this should be set to 90.
:type cc: int
"""
raise NotImplementedError
@final
def run(
self,
args: RuntimeArguments,
compiled_artifact: CompiledArtifact | None = None,
stream=None,
workspace=None,
assume_supported_args: bool = False,
) -> None:
"""
        Executes the kernel end-to-end with the provided arguments: checks that the args are supported, compiles if needed, creates a default stream if needed, and launches the kernel.
:param args: Arguments to run the kernel with
:type args: RuntimeArguments
:param compiled_artifact: Compiled kernel object returned from the compile() method above. If None, the kernel will first be compiled.
:type compiled_artifact: CompiledArtifact | None
:param stream: Stream to execute the kernel on. If not provided, the default/null stream cuda.CUstream(0) is used.
:type stream: cuda.CUstream, torch.cuda.Stream, or other stream-like object, or None
:param workspace: Allocation of workspace at least as large as the workspace size returned from the get_workspace_size() method. If the kernel does not require workspace, this can be None.
If a workspace of inappropriate size is provided, the behavior is undefined and the kernel may crash.
:type workspace: any | None
:param assume_supported_args: By default, kernel.supports(args) is called to check if the arguments are supported. If True, this check is skipped.
:type assume_supported_args: bool
"""
if not assume_supported_args and not (supports := self.supports(args)):
raise ValueError(
f"Kernel does not support the provided arguments: {supports.error}"
)
compiled_artifact = (
self.compile(args) if not compiled_artifact else compiled_artifact
)
if not stream:
stream = cuda.CUstream(0)
return self._run(args, compiled_artifact, stream, workspace)
def get_workspace_size(self, args: RuntimeArguments) -> int:
"""
Returns the size of the workspace required by the kernel in bytes.
:param args: arguments of the kernel
:type args: RuntimeArguments
:return: size of the workspace required by the kernel in bytes
:rtype: int
"""
return 0
def initialize_workspace(self, args: RuntimeArguments, workspace) -> None:
"""
Initializes the workspace for the kernel.
:param args: Arguments to initialize the workspace for
:type args: RuntimeArguments
:param workspace: Workspace to initialize
:type workspace: any
"""
return
@staticmethod
@abstractmethod
    def generate_kernels(
        metadata_filter: Callable[[KernelMetadata], bool],
        epilogue_args: EpilogueArguments | None = None,
        cc: int | None = None,
    ) -> list["Kernel"]:
        """
        Returns a list of all supported kernel configurations
        for the given compute capability and arguments.
:param metadata_filter: Optional function that takes KernelMetadata and returns True for kernels to include
:type metadata_filter: Callable[[KernelMetadata], bool]
:param epilogue_args: Optional arguments to pass to kernel epilogue
:type epilogue_args: EpilogueArguments | None
:param cc: Optional compute capability to target; e.g., 90 for H100
:type cc: int | None
:return: list of all supported kernel configurations
:rtype: list[Kernel]
"""
raise NotImplementedError
def _supports(self, args: RuntimeArguments) -> Status:
"""
Classes may override this method to perform any additional checks that are not captured in metadata.
Ideally, all/most checks should be captured in the metadata.
By default, no such checks are performed and the method trivially returns Status.success().
:param args: Arguments to check support for
:type args: RuntimeArguments
:return: Status indicating success, or reason why the kernel does not support the provided arguments
:rtype: Status
"""
return Status.success()
@abstractmethod
def _run(
self,
args: RuntimeArguments,
compiled_artifact: CompiledArtifact,
stream,
workspace=None,
) -> None:
"""
        A minimal version of the run() method that assumes the args are supported and that a valid compiled artifact and a valid stream are provided.
It is intended to be overridden by subclasses to implement the actual kernel execution, and be a minimal wrapper to launching the kernel.
:param args: Arguments to run the kernel with. It is assumed that kernel.supports(args) is True.
:type args: RuntimeArguments
:param compiled_artifact: Compiled kernel object returned from the compile() method
:type compiled_artifact: CompiledArtifact
:param stream: Stream to execute the kernel on.
:type stream: cuda.CUstream, torch.cuda.Stream, or other stream-like object
:param workspace: Allocation of workspace at least as large as the workspace size returned from the get_workspace_size() method.
If the kernel does not require workspace, this can be None.
If a workspace of inappropriate size is provided, the behavior is undefined and the kernel may crash.
:type workspace: any
"""
raise NotImplementedError
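The supports()/_supports() and run()/_run() pairs follow the template-method pattern: the final public method performs the shared checks, and subclasses override only the private hook. A toy sketch (all names hypothetical) of that structure:

```python
from abc import ABC

class Status:
    def __init__(self, ok, error=""):
        self.ok, self.error = ok, error

    def __bool__(self):
        return self.ok

class KernelBase(ABC):
    def supports(self, args):
        # Shared check (stands in for the metadata check above)
        if args.get("dtype") != "f16":
            return Status(False, "dtype must be f16")
        # Subclass-specific hook
        return self._supports(args)

    def _supports(self, args):
        return Status(True)

class AlignedKernel(KernelBase):
    def _supports(self, args):
        if args.get("m", 0) % 16 != 0:
            return Status(False, "m must be a multiple of 16")
        return Status(True)

k = AlignedKernel()
print(bool(k.supports({"dtype": "f16", "m": 32})))  # True
print(k.supports({"dtype": "f16", "m": 30}).error)  # m must be a multiple of 16
```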


@@ -0,0 +1,138 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import logging
from collections.abc import Callable
from cutlass_api.arguments import RuntimeArguments
from cutlass_api.kernel import Kernel
from cutlass_api.metadata import KernelMetadata
from cutlass_api.providers import available_providers
_logger = logging.getLogger(__name__)
class Manifest:
def __init__(self) -> None:
self._candidate_kernels = []
@staticmethod
    def get_kernels(
        args: RuntimeArguments | None = None,
        metadata_filter: Callable[[KernelMetadata], bool] | None = None,
        cc: int | None = None,
        providers: list[str] | None = None,
    ) -> list[Kernel]:
"""
Get the kernels that match the given arguments, metadata filter, and compute capability.
:param args: the arguments of the kernel
:type args: RuntimeArguments
:param metadata_filter: a boolean function that takes in KernelMetadata and returns whether
a Kernel from this metadata should be included
:type metadata_filter: Callable[[KernelMetadata], bool]
:param cc: the compute capability
:type cc: int
:param providers: the providers to use
:type providers: list[str]
:return: the kernels that match the given arguments, metadata filter, and compute capability
:rtype: list[Kernel]
"""
# Setup providers to use
providers_to_use = []
if providers is None:
providers_to_use = list(available_providers.values())
else:
for provider in providers:
if provider not in available_providers:
raise ValueError(f"Provider {provider} is not available")
providers_to_use.append(available_providers[provider])
# Setup filter function
filter_fn = (lambda x: True) if metadata_filter is None else metadata_filter
if args is None:
full_filter_fn = filter_fn
else:
def full_filter_fn(metadata: KernelMetadata) -> bool:
if not filter_fn(metadata):
return False
if not (supports := metadata.supports(args)):
_logger.debug(
f"Rejecting kernel {metadata.kernel_name}. Reason: {supports.error}"
)
return False
return True
epilogue_args = None if args is None else args.epilogue
kernels = [
k
# Generate kernels from all providers
for provider in providers_to_use
# Filter kernels by metadata
for k in provider.generate_kernels(
full_filter_fn, cc=cc, epilogue_args=epilogue_args
)
# Do any additional checks on args that are not captured by metadata
if not args or k._supports(args)
]
return kernels
    def add_kernels(
        self,
        args: RuntimeArguments | None = None,
        metadata_filter: Callable[[KernelMetadata], bool] | None = None,
        cc: int | None = None,
        providers: list[str] | None = None,
    ) -> list[Kernel]:
"""
Get the kernels that match the given arguments, metadata filter, and compute capability,
and add them to the Manifest's set of discovered kernels.
:param args: the arguments of the kernel
:type args: RuntimeArguments
:param metadata_filter: a boolean function that takes in KernelMetadata and returns whether
a Kernel from this metadata should be included
:type metadata_filter: Callable[[KernelMetadata], bool]
:param cc: the compute capability
:type cc: int
:param providers: the providers to use
:type providers: list[str]
:return: the kernels that match the given arguments, metadata filter, and compute capability
:rtype: list[Kernel]
"""
matched_kernels = Manifest.get_kernels(args, metadata_filter, cc, providers)
self._candidate_kernels.extend(matched_kernels)
return matched_kernels
@property
def kernels(self) -> list[Kernel]:
return self._candidate_kernels
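get_kernels() effectively chains two predicates: the user-supplied metadata_filter runs first, then metadata.supports(args) rejects incompatible kernels. A toy sketch (hypothetical Meta stand-in for KernelMetadata) of that chaining:

```python
class Meta:
    def __init__(self, name, dtype):
        self.kernel_name, self.dtype = name, dtype

    def supports(self, args):
        return self.dtype == args["dtype"]

def make_filter(metadata_filter, args):
    def full_filter(meta):
        # User filter first, then the args-support check
        if not metadata_filter(meta):
            return False
        return meta.supports(args)
    return full_filter

metas = [Meta("k_f16", "f16"), Meta("k_f32", "f32"), Meta("x_f16", "f16")]
f = make_filter(lambda m: m.kernel_name.startswith("k_"), {"dtype": "f16"})
print([m.kernel_name for m in metas if f(m)])  # ['k_f16']
```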


@@ -0,0 +1,504 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Self
if TYPE_CHECKING:
import cutlass
import cutlass.cute as cute
from cutlass_api.arguments import (
ElementwiseArguments,
EpilogueArguments,
GemmArguments,
RuntimeArguments,
)
from cutlass_api.status import Status
from cutlass_api.utils import TensorWrapper
def _convert_stride(shape: tuple[int, ...], stride: tuple[int, ...]) -> tuple[int, ...]:
"""
Zeros out modes of stride that have a shape of 1.
:param shape: the shape of the tensor
:type shape: tuple[int, ...]
:param stride: the stride of the tensor
:type stride: tuple[int, ...]
:return: the converted stride
:rtype: tuple[int, ...]
"""
new_stride = []
for i in range(len(shape)):
if shape[i] == 1:
new_stride.append(0)
else:
new_stride.append(stride[i])
    return tuple(new_stride)
def _get_max_pow2_alignment(
shape: tuple[int, ...], stride: tuple[int, ...], dtype: cutlass.Numeric
) -> int:
"""
Get the maximum power of 2 alignment for a given data type
:param shape: the shape of the tensor
:type shape: tuple[int, ...]
:param stride: the stride of the tensor
:type stride: tuple[int, ...]
:param dtype: the data type
:type dtype: cutlass.Numeric
:return: the maximum power of 2 alignment
:rtype: int
"""
if 1 not in stride:
return 1
major_mode_idx = stride.index(1)
num_major_elements = shape[major_mode_idx]
for alignment in [128, 64, 32, 16, 8, 4, 2]:
num_contiguous_elements = alignment * 8 // dtype.width
if num_major_elements % num_contiguous_elements == 0:
return alignment
return 1
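A standalone copy of the alignment search above (dtype width passed in bits so the sketch is self-contained): it finds the largest power-of-2 byte alignment, up to 128, such that the tensor's contiguous major mode spans a whole number of aligned chunks.

```python
def get_max_pow2_alignment(shape, stride, dtype_width_bits):
    # No unit-stride mode means there is no contiguous run to align on
    if 1 not in stride:
        return 1
    num_major = shape[stride.index(1)]
    for alignment in [128, 64, 32, 16, 8, 4, 2]:
        # Number of elements spanning `alignment` bytes
        num_contiguous = alignment * 8 // dtype_width_bits
        if num_major % num_contiguous == 0:
            return alignment
    return 1

# Row-major (128, 96) fp16: each row is 96 * 2 = 192 bytes, a multiple of
# 64 bytes but not of 128.
print(get_max_pow2_alignment((128, 96), (96, 1), 16))  # 64
```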
@dataclass
class TensorAttributes:
"""
Description of a single tensor. This includes the data type, stride, and alignment.
:param dtype: The data type of the tensor.
:type dtype: cutlass.Numeric
:param stride: The stride of the tensor.
:type stride: tuple[int, ...]
:param alignment: The alignment of the tensor
:type alignment: int
"""
dtype: cutlass.Numeric # F32, F16, etc.
stride: tuple[int, ...]
alignment: int
def supports(self, operand: TensorWrapper | Self) -> Status:
"""
Checks whether the provided args satisfy the properties described by
these TensorAttributes.
:param operand: The operand to check support for.
:type operand: TensorWrapper | Self
:return: Whether the provided operand satisfies the properties described by
these TensorAttributes.
:rtype: Status
"""
if isinstance(operand, TensorWrapper):
if operand.element_type != self.dtype:
return Status.fail(
f"Expected element type {self.dtype}, got {operand.element_type}"
)
elif operand.dtype != self.dtype:
return Status.fail(
f"Expected element type {self.dtype}, got {operand.dtype}"
)
# Normalize the operand stride to have 0's where operand.shape is 1
normalized_operand_stride = _convert_stride(operand.shape, operand.stride)
# If the operand is batched, the expected stride is the same as the self.stride
if len(self.stride) == len(normalized_operand_stride):
expected_stride = self.stride
        # Otherwise, check whether the strides match without the batch mode
elif len(self.stride) - 1 == len(normalized_operand_stride):
expected_stride = self.stride[1:]
else:
return Status.fail(
f"Expected tensor with strides of rank {len(self.stride)} (batched) or {len(self.stride) - 1} (unbatched), got {len(normalized_operand_stride)} ({normalized_operand_stride})"
)
# Strides are considered compatible if:
# 1. They have the same rank (checked above)
# 2. The leading mode is the same (or both are all 0)
all_zeros = all(x == 0 for x in expected_stride) and all(
x == 0 for x in normalized_operand_stride
)
# When setting stride from args, any modes of stride 1 and shape 1
# are changed to have stride 0. Thus, there will only be one mode
# with stride 1.
if not all_zeros and normalized_operand_stride[expected_stride.index(1)] != 1:
return Status.fail(
f"Expected stride[{expected_stride.index(1)}] to be 1, got {normalized_operand_stride[expected_stride.index(1)]} (strides: {normalized_operand_stride})"
)
# Alignment of operand should be divisible by this metadata's alignment
if isinstance(operand, TensorWrapper):
operand_alignment = _get_max_pow2_alignment(
operand.shape, normalized_operand_stride, operand.element_type
)
else:
operand_alignment = operand.alignment
if operand_alignment % self.alignment != 0:
return Status.fail(
f"Expected operand alignment {operand_alignment} (strides: {normalized_operand_stride}) to be a multiple of {self.alignment}"
)
return Status.success()
@staticmethod
def from_tensor(tensor) -> Self:
"""
Create a TensorAttributes from a tensor.
:param tensor: The tensor to create a TensorAttributes from.
:type tensor: cute.Tensor
:return: The TensorAttributes corresponding to the provided tensor.
:rtype: TensorAttributes
"""
stride = _convert_stride(tensor.shape, tensor.stride)
max_alignment = _get_max_pow2_alignment(
tensor.shape, stride, tensor.element_type
)
return TensorAttributes(
dtype=tensor.element_type, stride=stride, alignment=max_alignment
)
@dataclass
class OperandsMetadata(ABC):
"""
Base metadata class for descriptions of operands (e.g., GEMM A, B, out).
"""
@abstractmethod
def supports(self, args: RuntimeArguments) -> Status:
"""
Checks whether the provided args satisfy the properties described by
the operands in this metadata.
:param args: The arguments to check support for.
:type args: RuntimeArguments
:return: Whether the provided args satisfy the properties described by
the operands in this metadata.
:rtype: Status
"""
@dataclass
class DesignMetadata(ABC):
"""
Base metadata class for descriptions of design parameters for an operation
(e.g., tile shape, cluster shape, etc.).
"""
@abstractmethod
def supports(self, args: RuntimeArguments) -> Status:
"""
Checks whether the provided args satisfy the properties described by
the design in this metadata.
:param args: The arguments to check support for.
:type args: RuntimeArguments
:return: Whether the provided args satisfy the properties described by
the design in this metadata.
:rtype: Status
"""
@dataclass
class BLASDesignMetadata(DesignMetadata):
"""
    Design metadata for a basic linear algebra subprogram (BLAS) operation.
These include fields for tiling-related parameters (e.g., tile shape and cluster shape).
"""
tile_shape: tuple[int, ...]
cluster_shape: tuple[int, ...]
def supports(self, args: RuntimeArguments) -> Status:
"""
Checks whether the provided args satisfy the properties described by
the design in this metadata.
:param args: The arguments to check support for.
:type args: RuntimeArguments
:return: Whether the provided args satisfy the properties described by
the design in this metadata.
:rtype: Status
"""
if args.performance is not None:
return Status.fail(
"BLASDesignMetadata does not yet support performance controls"
)
return Status.success()
@dataclass
class Sm100DesignMetadata(BLASDesignMetadata):
"""
Design metadata for kernels in the SM100 architecture family.
"""
# Whether to use a 2CTA MMA instruction
use_2cta_mma: bool
# Whether to use TMA to store the results of the operation
use_tma_store: bool
@dataclass
class GemmOperandsMetadata(OperandsMetadata):
"""
Metadata for the operands of a GEMM operation.
:param A: Metadata for the input tensor A.
:type A: TensorAttributes
:param B: Metadata for the input tensor B.
:type B: TensorAttributes
:param out: Metadata for the output tensor.
:type out: TensorAttributes
:param accumulator_type: The data type of the accumulator tensor.
:type accumulator_type: cutlass.Numeric
"""
A: TensorAttributes
B: TensorAttributes
out: TensorAttributes
accumulator_type: cutlass.Numeric
def supports(self, other: GemmArguments | Self) -> Status:
"""
Checks whether the provided args satisfy the properties described by
the operands in this metadata.
:param other: The arguments to check support for.
:type other: GemmArguments | Self
:return: Whether the provided args satisfy the properties described by
the operands in this metadata.
:rtype: Status
"""
if isinstance(other, RuntimeArguments) and not isinstance(other, GemmArguments):
return Status.fail(f"Expected GemmArguments, got {type(other)}")
if not (status := self.A.supports(other.A)):
return Status.fail(f"Operand `A` is unsupported: {status.error}")
if not (status := self.B.supports(other.B)):
return Status.fail(f"Operand `B` is unsupported: {status.error}")
if not (status := self.out.supports(other.out)):
return Status.fail(f"Operand `out` is unsupported: {status.error}")
if self.accumulator_type != other.accumulator_type:
return Status.fail(
f"Expected accumulator type {self.accumulator_type}, got {other.accumulator_type}"
)
return Status.success()
@staticmethod
def from_args(args: GemmArguments) -> "GemmOperandsMetadata":
"""
Create a GemmOperandsMetadata from a GemmArguments.
:param args: The GemmArguments to create a GemmOperandsMetadata from.
:type args: GemmArguments
:return: The GemmOperandsMetadata corresponding to the provided GemmArguments.
:rtype: GemmOperandsMetadata
"""
return GemmOperandsMetadata(
A=TensorAttributes.from_tensor(args.A),
B=TensorAttributes.from_tensor(args.B),
out=TensorAttributes.from_tensor(args.out),
accumulator_type=args.accumulator_type,
)
@dataclass
class ElementwiseOperandsMetadata(OperandsMetadata):
"""
Metadata for the operands of an elementwise operation.
:param A: Metadata for the input tensor A.
:type A: TensorAttributes
:param B: Metadata for the input tensor B.
:type B: TensorAttributes
:param out: Metadata for the output tensor.
:type out: TensorAttributes
"""
A: TensorAttributes
B: TensorAttributes
out: TensorAttributes
def supports(self, other: ElementwiseArguments | Self) -> Status:
"""
Checks whether the provided args satisfy the properties described by
the operands in this metadata.
:param other: The arguments to check support for.
:type other: ElementwiseArguments | Self
:return: Whether the provided args satisfy the properties described by
the operands in this metadata.
:rtype: Status
"""
if isinstance(other, RuntimeArguments) and not isinstance(
other, ElementwiseArguments
):
return Status.fail(f"Expected ElementwiseArguments, got {type(other)}")
if not (status := self.A.supports(other.A)):
return Status.fail(f"Operand `A` is unsupported: {status.error}")
if not (status := self.B.supports(other.B)):
return Status.fail(f"Operand `B` is unsupported: {status.error}")
if not (status := self.out.supports(other.out)):
return Status.fail(f"Operand `out` is unsupported: {status.error}")
return Status.success()
@staticmethod
def from_args(args: ElementwiseArguments) -> "ElementwiseOperandsMetadata":
"""
Create an ElementwiseOperandsMetadata from an ElementwiseArguments.
:param args: The ElementwiseArguments to create an ElementwiseOperandsMetadata from.
:type args: ElementwiseArguments
:return: The ElementwiseOperandsMetadata corresponding to the provided ElementwiseArguments.
:rtype: ElementwiseOperandsMetadata
"""
return ElementwiseOperandsMetadata(
A=TensorAttributes.from_tensor(args.A),
B=TensorAttributes.from_tensor(args.B),
out=TensorAttributes.from_tensor(args.out),
)
class EpilogueMetadata:
def __init__(self, epilogue_args: EpilogueArguments):
self.traced_epilogue = epilogue_args.traced_epilogue
self.tensors = epilogue_args.tensors
self.epilogue_fn = epilogue_args.epilogue_fn
@staticmethod
def from_args(args: EpilogueArguments) -> "EpilogueMetadata":
# For now, EpilogueArguments and EpilogueMetadata are the same
return EpilogueMetadata(args)
@property
def parameters(self) -> list[cute.Tensor | cutlass.Numeric]:
return list(self.tensors.values())
@property
def parameter_names(self) -> list[str]:
return list(self.tensors.keys())
def supports(self, args: RuntimeArguments) -> Status:
return Status.success()
@dataclass
class KernelMetadata:
"""
Metadata describing the operands and design of a kernel.
In addition to information about operands, this metadata also contains
the properties of the kernel that are independent of the specific arguments
that are passed to the kernel (e.g., tile shape, cluster shape, etc.).
:param kernel_name: The name of the kernel.
:type kernel_name: str
:param kernel_class: The class of the kernel.
:type kernel_class: type["Kernel"]
:param min_cc: The minimum compute capability of the kernel.
:type min_cc: int
:param operands: Metadata for the operands of the kernel.
:type operands: OperandsT
:param design: Metadata for the design of the kernel.
:type design: DesignT | None
:param epilogue: Metadata for the epilogue of the kernel.
:type epilogue: EpilogueT | None
"""
kernel_name: str
kernel_class: type["Kernel"]
min_cc: int
operands: OperandsMetadata
design: DesignMetadata | None = None
epilogue: EpilogueMetadata | None = None
def supports(self, args: RuntimeArguments) -> Status:
"""
Checks whether the provided args satisfy the properties described by
the operands, design, and epilogue metadata.
:param args: The arguments to check support for.
:type args: RuntimeArguments
:return: Whether the provided args satisfy the properties described by
the operands, design, and epilogue metadata.
:rtype: Status
"""
def supports_or_none(member, corresponding_arg, name: str) -> Status:
# If metadata is absent, accept only when the corresponding argument is also absent.
if member is None:
if corresponding_arg is None:
return Status.success()
return Status.fail(
f"{name} metadata is absent but argument is provided"
)
return member.supports(args)
if not (status := self.operands.supports(args)):
return status
if not (status := supports_or_none(self.design, args.performance, "design")):
return status
if not (status := supports_or_none(self.epilogue, args.epilogue, "epilogue")):
return status
return Status.success()
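The `supports` methods above share one convention: every check returns a `Status` that is falsy on failure and carries an error string, so callers can chain checks and propagate the first failure. A minimal standalone sketch of that convention (the `Status`, `ToyOperands`, and `ToyArgs` classes here are illustrative stand-ins, not the actual `cutlass_api` types):

```python
from dataclasses import dataclass


@dataclass
class Status:
    # Falsy when the check failed; `error` explains why.
    ok: bool
    error: str = ""

    def __bool__(self) -> bool:
        return self.ok

    @staticmethod
    def success() -> "Status":
        return Status(True)

    @staticmethod
    def fail(error: str) -> "Status":
        return Status(False, error)


@dataclass
class ToyArgs:
    dtype: str


@dataclass
class ToyOperands:
    dtype: str

    def supports(self, args: ToyArgs) -> Status:
        # Chainable check in the same style as OperandsMetadata.supports
        if args.dtype != self.dtype:
            return Status.fail(f"Expected dtype {self.dtype}, got {args.dtype}")
        return Status.success()


meta = ToyOperands(dtype="float16")
ok = meta.supports(ToyArgs("float16"))
bad = meta.supports(ToyArgs("float32"))
```

The walrus pattern `if not (status := meta.supports(args)): return status` then short-circuits on the first failing check without losing its error message.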

View File

@@ -0,0 +1,62 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections.abc import Callable
from cutlass_api.providers.provider import ProviderBase
available_providers: dict[str, type[ProviderBase]] = {}
def register_provider(name: str) -> Callable[[type[ProviderBase]], type[ProviderBase]]:
"""
Decorator used to register a provider class with the given name.
:param name: the name of the provider
:type name: str
:return: the wrapper function
:rtype: Callable[[Type[ProviderBase]], Type[ProviderBase]]
"""
def wrapper(provider_class: type[ProviderBase]) -> type[ProviderBase]:
"""
Wrapper function to register a provider class with the given name.
:param provider_class: the provider class to register
:type provider_class: Type[ProviderBase]
"""
global available_providers
available_providers[name] = provider_class
return provider_class
return wrapper
# Import for side effects (provider registration)
import cutlass_api.providers.cutedsl # noqa: F401, E402
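`register_provider` is a standard parameterized-decorator registry. A self-contained sketch of the same pattern (the generic `registry`/`register` names are ours, not the module's actual globals):

```python
from collections.abc import Callable

# Name -> class mapping, filled in as classes are defined.
registry: dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    def wrapper(cls: type) -> type:
        registry[name] = cls
        return cls  # return the class unchanged so the decorator is transparent

    return wrapper


@register("demo")
class DemoProvider:
    pass
```

Because registration happens at class-definition time, importing a module is enough to populate the registry, which is why the package ends with an import performed purely for its side effects.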

View File

@@ -0,0 +1,88 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections.abc import Callable
import logging
from cutlass_api.arguments import EpilogueArguments
from cutlass_api.kernel import Kernel
from cutlass_api.metadata import KernelMetadata
from cutlass_api.providers import register_provider
def try_import() -> bool:
import importlib.util
if importlib.util.find_spec("cutlass") is None:
logging.warning(
"'cutlass' could not be imported. The cutedsl provider will not be available."
)
return False
return True
available = try_import()
if available:
@register_provider("cutedsl")
class CuTeDSLProvider:
# Kernel classes currently registered with this provider
_kernel_classes = []
@classmethod
def generate_kernels(
cls,
metadata_filter: Callable[[KernelMetadata], bool] | None,
epilogue_args: EpilogueArguments | None = None,
cc: int | None = None,
) -> list[Kernel]:
kernels_for_provider = []
for kernel_cls in cls._kernel_classes:
kernels_for_provider.extend(
kernel_cls.generate_kernels(
metadata_filter, epilogue_args=epilogue_args, cc=cc
)
)
return kernels_for_provider
@classmethod
def register(cls, kernel_class: type[Kernel]) -> type[Kernel]:
cls._kernel_classes.append(kernel_class)
return kernel_class
# Imports for side effects (kernel registration)
import cutlass_api.providers.cutedsl.elementwise # noqa: F401
import cutlass_api.providers.cutedsl.gemm # noqa: F401
__all__ = ["CuTeDSLProvider"]
else:
__all__ = []
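`try_import` probes for the `cutlass` package with `importlib.util.find_spec`, which checks availability without actually importing it. The same guard in isolation (the `probe` helper name is ours):

```python
import importlib.util
import logging


def probe(module_name: str) -> bool:
    # find_spec returns None for an absent top-level module instead of raising,
    # so the probe is side-effect free apart from the warning.
    if importlib.util.find_spec(module_name) is None:
        logging.warning("'%s' could not be imported; provider disabled.", module_name)
        return False
    return True


have_stdlib = probe("importlib.util")            # stdlib: always present
have_missing = probe("definitely_not_a_module_z")  # absent: logs and returns False
```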

View File

@@ -0,0 +1,32 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Importing is done for the side effect of registering the kernel classes with the provider.
# Suppress linter warnings about unused import.
# ruff: noqa: F401
import cutlass_api.providers.cutedsl.elementwise.elementwise_add

View File

@@ -0,0 +1,240 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections.abc import Callable
from itertools import product
from typing import ClassVar
import cuda.bindings.driver as cuda
from cutlass_api.arguments import ElementwiseArguments, EpilogueArguments
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
ElementwiseOperandsMetadata,
KernelMetadata,
TensorAttributes,
)
from cutlass_api.providers.cutedsl import CuTeDSLProvider
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
from cutlass_api.utils import to_cuda_stream
import cutlass
import cutlass.cute as cute
"""
An Elementwise Addition kernel using CuTe DSL:
out = A + B
"""
@CuTeDSLProvider.register
class ElementwiseAddKernel(CuteDslKernel):
_supported_dtypes: ClassVar[list[type[cutlass.Numeric]]] = [
cutlass.Float32,
cutlass.Float16,
]
def __init__(self, metadata: KernelMetadata):
self.metadata = metadata
self.impl = ElementwiseAddKernelImpl()
def compile(self, args: ElementwiseArguments, cc: int | None = None) -> CompiledArtifact:
stream = cutlass.cute.runtime.make_fake_stream()
compiled_kernel = self.cute_compile(self.impl, args.A, args.B, args.out, stream)
return CompiledArtifact(compiled_kernel, self)
def _run(
self,
args: ElementwiseArguments,
compiled_artifact: CompiledArtifact,
stream,
workspace=None,
) -> None:
stream = to_cuda_stream(stream)
compiled_kernel = compiled_artifact.compiled_obj
self.cute_run(compiled_kernel, args.A, args.B, args.out, stream)
@staticmethod
def generate_kernels(
metadata_filter: Callable[[KernelMetadata], bool],
epilogue_args: EpilogueArguments | None = None,
cc: int | None = None,
) -> list["ElementwiseAddKernel"]:
if epilogue_args is not None:
return []
min_cc = 80
if cc is not None and cc < min_cc:
return []
stride_names = {
(0, 1): "t", # row major
(1, 0): "n", # column major
}
kernel_list = []
for dtype in ElementwiseAddKernel._supported_dtypes:
alignment = 128 // dtype.width
for stride_A, stride_B, stride_out in product(
stride_names.keys(), repeat=3
):
kernel_name = (
"cutedsl.ElementwiseAddKernel"
+ f"_A{dtype}_{stride_names[stride_A]}"
+ f"_B{dtype}_{stride_names[stride_B]}"
+ f"_out{dtype}_{stride_names[stride_out]}"
)
operands = ElementwiseOperandsMetadata(
A=TensorAttributes(
dtype=dtype, stride=stride_A, alignment=alignment
),
B=TensorAttributes(
dtype=dtype, stride=stride_B, alignment=alignment
),
out=TensorAttributes(
dtype=dtype, stride=stride_out, alignment=alignment
),
)
metadata = KernelMetadata(
    operands=operands,
    kernel_name=kernel_name,
    kernel_class=ElementwiseAddKernel,
    min_cc=min_cc,
)
if metadata_filter(metadata):
kernel_list.append(ElementwiseAddKernel(metadata))
return kernel_list
class ElementwiseAddKernelImpl:
@cute.jit
def __call__(
self, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, stream: cuda.CUstream
):
copy_bits: cutlass.Constexpr = 128
dtype = mA.element_type
vector_size = copy_bits // dtype.width
thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0))
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
gA = cute.zipped_divide(mA, tiler_mn) # ((TileM,TileN),(RestM,RestN))
gB = cute.zipped_divide(mB, tiler_mn) # ((TileM,TileN),(RestM,RestN))
gC = cute.zipped_divide(mC, tiler_mn) # ((TileM,TileN),(RestM,RestN))
idC = cute.make_identity_tensor(mC.shape)
cC = cute.zipped_divide(idC, tiler=tiler_mn)
self.kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch(
grid=[cute.size(gC, mode=[1]), 1, 1],
block=[cute.size(tv_layout, mode=[0]), 1, 1],
stream=stream,
)
@cute.kernel
def kernel(
self,
gA: cute.Tensor,
gB: cute.Tensor,
gC: cute.Tensor,
cC: cute.Tensor, # coordinate tensor
shape: cute.Shape,
thr_layout: cute.Layout,
val_layout: cute.Layout,
):
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
# slice for CTAs
# logical id -> address
blk_coord = ((None, None), bidx)
blkA = gA[blk_coord] # (TileM,TileN)
blkB = gB[blk_coord] # (TileM,TileN)
blkC = gC[blk_coord] # (TileM,TileN)
blkCrd = cC[blk_coord] # (TileM, TileN)
# declare the atoms which will be used later for memory copy
copy_atom_load = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), gA.element_type
)
copy_atom_store = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), gC.element_type
)
tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)
thr_copy_A = tiled_copy_A.get_slice(tidx)
thr_copy_B = tiled_copy_B.get_slice(tidx)
thr_copy_C = tiled_copy_C.get_slice(tidx)
thrA = thr_copy_A.partition_S(blkA)
thrB = thr_copy_B.partition_S(blkB)
thrC = thr_copy_C.partition_S(blkC)
# allocate fragments for gmem->rmem
frgA = cute.make_fragment_like(thrA)
frgB = cute.make_fragment_like(thrB)
frgC = cute.make_fragment_like(thrC)
thrCrd = thr_copy_C.partition_S(blkCrd)
frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean)
for i in range(cute.size(frgPred)):
val = cute.elem_less(thrCrd[i], shape)
frgPred[i] = val
# Print per thread predicate mask
# if tidx == 0 and bidx == 0:
# cute.printf("block_dim = {}", cute.arch.grid_dim())
# cute.printf("shape = {}", shape)
# cute.print_tensor(thrA)
# cute.print_tensor(thrB)
# cute.print_tensor(frgPred)
##########################################################
# Move data to reg address space
##########################################################
cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
# if tidx == 0 and bidx == 0:
# cute.print_tensor(frgA)
# cute.print_tensor(frgB)
# Load data before use. The compiler will optimize the copy and load
# operations to convert some memory ld/st into register uses.
result = frgA.load() + frgB.load()
# Store the result into the output fragment's registers.
frgC.store(result)
# Copy the results back to c
cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)
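`generate_kernels` above emits one kernel variant per (dtype, layout) combination of A, B, and out. The naming scheme can be reproduced in isolation (plain strings stand in for the `cutlass` dtype classes):

```python
from itertools import product

stride_names = {
    (0, 1): "t",  # row major
    (1, 0): "n",  # column major
}
dtypes = ["Float32", "Float16"]

names = [
    "cutedsl.ElementwiseAddKernel"
    + f"_A{dtype}_{stride_names[sa]}"
    + f"_B{dtype}_{stride_names[sb]}"
    + f"_out{dtype}_{stride_names[so]}"
    for dtype in dtypes
    for sa, sb, so in product(stride_names, repeat=3)
]
# 2 dtypes x 2**3 layout combinations = 16 candidate kernels
```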

File diff suppressed because it is too large

View File

@@ -0,0 +1,333 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
from collections.abc import Callable
from cutlass_api.fusion.ir import DAGIR, LoadNode, StoreNode, TopoVisitorNode
from cutlass_api.fusion.ir.load_nodes import (
ColumnBroadcastImpl,
RowBroadcastImpl,
ScalarBroadcastImpl,
)
from cutlass_api.fusion.ir.store_nodes import ReductionImplBase
from cutlass_api.fusion.library import ActivationOp, FunctionalOp
from cutlass_api.providers.cutedsl.evt import common_efc
from cutlass_api.status import Status
OpToCuteImplStr = {
FunctionalOp.Exp: lambda x: f"exp({x})",
ActivationOp.ReLU: lambda x: f"relu({x})",
ActivationOp.Sigmoid: lambda x: f"sigmoid({x})",
ActivationOp.Tanh: lambda x: f"tanh({x})",
FunctionalOp.Divides: lambda x, y: f"({x} / {y})",
FunctionalOp.Multiplies: lambda x, y: f"({x} * {y})",
FunctionalOp.Plus: lambda x, y: f"({x} + {y})",
FunctionalOp.Minus: lambda x, y: f"({x} - {y})",
}
# Functions to eventually be run as part of the EFC function.
# Every function takes in `efc_config` as the first argument (even if it
# is not used). This is necessary for running analysis
# passes on the EFC function absent an MLIR context (which would be
# needed if we used `cute.exp` directly).
OpToCuteImpl = {
FunctionalOp.Exp: lambda efc_config, x: efc_config.exp(x),
ActivationOp.ReLU: lambda efc_config, x: efc_config.where(
x > 0, x, efc_config.full_like(x, 0)
),
ActivationOp.Sigmoid: lambda efc_config, x: 1.0 / (1.0 + efc_config.exp(-x)),
ActivationOp.Tanh: lambda efc_config, x: efc_config.tanh(x),
FunctionalOp.Divides: lambda efc_config, x, y: x / y,
FunctionalOp.Multiplies: lambda efc_config, x, y: x * y,
FunctionalOp.Plus: lambda efc_config, x, y: x + y,
FunctionalOp.Minus: lambda efc_config, x, y: x - y,
}
def store(efc_config, x, y):
x.store(y)
def load(efc_config, x):
return x.load()
def get_val(x):
return lambda efc_config: x
class EFCConverter:
"""
Helper class to translate from DAGIR to the CuTe DSL epilogue visitor tree (EVT) structure
The CuTe DSL EVT structure is as follows for a (alpha * accum + beta * C) epilogue:
def epi(self, C, alpha, beta, D):
C_val = C.load()
alpha_val = alpha.load()
beta_val = beta.load()
compute_0_val = alpha_val * self.accum()
compute_1_val = beta_val * C_val
compute_2_val = compute_0_val + compute_1_val
D.store(compute_2_val)
"""
@staticmethod
def convert(dag_ir: DAGIR, parameter_names: list[str]) -> Callable:
"""
Converts the DAGIR to a callable epilogue function supported by CuTe DSL EFC.
The simplest way to do this would be to convert the DAGIR into the equivalent
string representation of the EFC function and call `exec` on it to get the callable.
However, this is generally not considered safe as it can potentially open avenues
for allowing arbitrary code execution.
Instead, we define a generic configurable epilogue that we specialize based
on the DAGIR itself. The generic epilogue takes in a list of parameters
and executes a predefined sequence of operations. Each operation is executed
in order and places its result on a stack.
Outside this generic epilogue, we define the sequence of operations that
are to be executed and determine where the sources for such operations live
(either their index in the parameter list or on the stack).
Example
=======
Suppose we have the following epilogue:
```
def epi(accum, alpha):
D = accum * alpha
return D
```
A simple string representation of the epilogue in EFC format would be:
```
def efc_epi(efc_config, alpha, D):
accum_val = efc_config.accum()
alpha_val = alpha.load()
temp = accum_val * alpha_val
D.store(temp)
```
We would like to generalize this to a function with signature:
```
def efc_epi(efc_config, *parameters):
```
that can perform arbitrary operations.
We would know from `parameter_names` that the list of parameters provided to the EFC epilogue
will be [alpha, D]. We then traverse the DAGIR and see the following operations in order:
LoadNode(accum)
LoadNode(alpha)
ComputeNode(mul, alpha, accum) -> temp
StoreNode(temp, D)
We can convert this into a series of corresponding operations that operate on variables
corresponding to indices. Indices represent either a position in the parameter list
or a position on the stack (index < len(parameters) implying a position in the parameter list).
In this example, [alpha, D] correspond to indices 0 and 1, respectively.
Traversing each DAGIR node above, we get:
LoadNode(accum) -> efc_config.accum(), result is in stack[0] (index 2)
LoadNode(alpha) -> alpha.load()
-> parameter[0].load(), result is in stack[1] (index 3)
ComputeNode(mul, alpha, accum) -> stack[1] * stack[0] (index 3 * index 2), result is in stack[2] (index 4)
StoreNode(temp, D) -> D.store(stack[2])
-> parameter[1].store(stack[2])
This is encoded as the following tuples:
```
ops = [
# Load of accum omitted because it is performed automatically in the generic epilogue
(load, 0),
(mul, 3, 2),
(store, 1, 4),
]
```
The generic epilogue then simply performs:
```
def efc_epi(efc_config, *parameters):
stack = [efc_config.accum()]
for op in ops:
fn = op[0]
inputs = [get(idx) for idx in op[1:]]
stack.append(fn(efc_config, *inputs))
```
Where the `get()` function is a helper that returns either a value from the parameter list or the stack,
depending on the index.
"""
# Provide a unique identifier for cases in which `accum` is also being written out
# We use an integer since parameter_names is a list of strings -- -1 is guaranteed
# not to be in the parameter_names list, but can still be used as a key in the dictionary.
accum_out_name = -1
# If 'accum' is in parameter_names, we know that it must be because
# the accumulator is also being written out
accum_out_loc = -1
if "accum" in parameter_names:
# Rename the output version temporarily so as not to confuse with the
# input accumulator
accum_out_loc = parameter_names.index("accum")
parameter_names[accum_out_loc] = accum_out_name
name_to_idx = {}
for idx, name in enumerate(parameter_names):
name_to_idx[name] = idx
cur_idx = len(name_to_idx)
def add_name(name: str):
nonlocal cur_idx
name_to_idx[name] = cur_idx
cur_idx += 1
def idx(name: str):
val = name_to_idx[name]
if isinstance(val, str):
return idx(val)
return val
add_name("accum")
# Each entry is a tuple containing a load/store/compute op and operands needed for it.
ops = []
debug_string_ops = []
for meta in dag_ir.node_metas_topological_order():
if isinstance(meta, LoadNode) and getattr(meta, "is_output", False):
assert meta.name == "accum"
# Handle the special case where the accumulator is also being written out
# This occurs in DAG IR as a load node with is_output = True
ops.append((store, idx(accum_out_name), idx("accum")))
debug_string_ops.append("accum_out_name.store(accum)")
cur_idx += 1
if isinstance(meta, LoadNode):
# Add new values to the stack for any operations that need to be .load'ed.
# This includes any non-scalar input parameters
is_scalar = isinstance(meta.underlying_impl, ScalarBroadcastImpl)
is_param_scalar = is_scalar and meta.name in parameter_names
if meta.name != "accum" and not is_param_scalar:
if is_scalar:
ops.append((get_val(meta.tensor.value),))
add_name(meta.name)
else:
ops.append((load, idx(meta.name)))
debug_string_ops.append(f"{meta.name} = {meta.name}.load()")
# Update the entry in name_to_idx to the index of the loaded value
add_name(meta.name)
elif isinstance(meta, StoreNode):
children = dag_ir.get_all_inputs(meta.name)
if len(children) != 1:
raise ValueError(
f"Store node {meta.name} has {len(children)} children, but only one is supported"
)
child = children[0]
ops.append((store, idx(meta.name), idx(child)))
debug_string_ops.append(f"{meta.name}.store({child})")
# We want to map the following sequence:
# def epi(accum, x, y):
# out0 = accum * x
# out1 = out0 * y
# return out0, out1
# To:
# x_val = x.load()
# y_val = y.load()
# c0 = self.Acc() * x_val
# out0.store(c0)
# c1 = c0 * y_val
# out1.store(c1)
# To support computation on top of out0, we need to remap it to c0 so that the next computation can use it.
name_to_idx[meta.name] = child
cur_idx += 1
elif isinstance(meta, TopoVisitorNode):
raise ValueError(f"TopoVisitorNode {meta.name} is not supported")
else:
if dag_ir.in_degree(meta.name) == 0:
continue
sorted_children = [
idx(child_name) for child_name in dag_ir.get_all_inputs(meta.name)
]
entry = (OpToCuteImpl[meta.fn], *sorted_children)
ops.append(entry)
debug_string_ops.append(
f"{meta.name} = {OpToCuteImplStr[meta.fn](*dag_ir.get_all_inputs(meta.name))}"
)
add_name(meta.name)
def epi(efc_config, *parameters):
stack = [efc_config.accum()]
def get(idx: int):
if idx < len(parameters):
return parameters[idx]
else:
return stack[idx - len(parameters)]
for op in ops:
fn = op[0]
inputs = [get(idx) for idx in op[1:]]
stack.append(fn(efc_config, *inputs))
if accum_out_loc != -1:
# Restore the original name of the accumulator
parameter_names[accum_out_loc] = "accum"
named_epi = common_efc.create_named_epilogue(["self", *parameter_names], epi)
return named_epi
@staticmethod
def supports(dag_ir: DAGIR) -> Status:
"""
Checks if the DAGIR is supported by CuTe DSL EFC.
"""
# Currently does not support TopoVisitorNode, row/col broadcasts, or reductions
for meta in dag_ir.node_metas_topological_order():
if isinstance(meta, TopoVisitorNode):
return Status.fail("TopoVisitorNode is not supported")
if isinstance(meta.underlying_impl, RowBroadcastImpl):
return Status.fail("RowBroadcastImpl is not supported")
if isinstance(meta.underlying_impl, ColumnBroadcastImpl):
return Status.fail("ColumnBroadcastImpl is not supported")
if isinstance(meta.underlying_impl, ReductionImplBase):
return Status.fail("ReductionImplBase is not supported")
return Status.success()
@staticmethod
def identity_efc(self, D):
D.store(self.accum())
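The comment block above already sketches the mapping; as a self-contained illustration (plain Python ints and hypothetical names in place of CuTe fragments, not the actual EFC runtime), the stack-machine encoding built by the converter can be interpreted like this:

```python
import operator

# Each entry in `ops` is (callable, *operand_indices): indices below
# len(parameters) read input parameters, the rest read earlier stack entries.
def run_ops(ops, parameters, accum):
    stack = [accum]  # stack[0] plays the role of efc_config.accum()

    def get(idx):
        if idx < len(parameters):
            return parameters[idx]
        return stack[idx - len(parameters)]

    for fn, *arg_idxs in ops:
        stack.append(fn(*(get(i) for i in arg_idxs)))
    return stack

# Encodes: out0 = accum * x; out1 = out0 * y, with parameters (x, y).
# Index 2 is the accumulator (len(parameters) + 0); index 3 is out0.
ops = [(operator.mul, 2, 0), (operator.mul, 3, 1)]
stack = run_ops(ops, parameters=(3, 5), accum=2)
# stack == [2, 6, 30]: out0 = 2 * 3, out1 = 6 * 5
```

The remapping of `out0` to its computed value mirrors the `name_to_idx[meta.name] = child` line above: later operations read the stored intermediate instead of re-loading the output tensor.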

@@ -0,0 +1,33 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Importing is done for the side effect of registering the kernel classes with the provider.
# Suppress linter warnings about unused import.
# ruff: noqa: F401
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent_efc

@@ -0,0 +1,414 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
import itertools
from collections.abc import Callable, Generator
import cutlass
from cutlass_api.arguments import (
EpilogueArguments,
GemmArguments,
)
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
GemmOperandsMetadata,
KernelMetadata,
Sm100DesignMetadata,
TensorAttributes,
)
from cutlass_api.providers.cutedsl import CuTeDSLProvider
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
from cutlass_api.providers.cutedsl.utils import get_max_active_clusters
from cutlass_api.status import Status
from cutlass_api.utils import strides_to_layout_string, to_cuda_stream, tuple_to_string
from .implementations.sm100_static_persistent_impl import PersistentDenseGemmKernelImpl
@CuTeDSLProvider.register
class PersistentDenseGemmKernel(CuteDslKernel):
"""This class implements batched matrix multiplication (C = A @ B) with support for various data types
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
:note: In the current version, A and B tensors must have the same data type
- i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
:note: Supported A/B data types:
- TFloat32
- Float16/BFloat16
- Int8/Uint8
- Float8E4M3FN/Float8E5M2
:note: Supported accumulator data types:
- Float32 (for all floating point A/B data types)
- Float16 (only for fp16 and fp8 A/B data types)
- Int32 (only for uint8/int8 A/B data types)
:note: Supported C data types:
- Float32 (for float32 and int32 accumulator data types)
- Int32 (for float32 and int32 accumulator data types)
- Float16/BFloat16 (for fp16 accumulator data type)
- Int8/Uint8 (for int32 accumulator data type)
- Float8E4M3FN/Float8E5M2 (for float32 accumulator data types)
:note: Constraints:
- MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
- MMA tiler N must be 32-256, step 32
- Cluster shape M must be multiple of 2 if use_2cta_instrs=True
- Cluster shape M/N must be positive and power of 2, total cluster size <= 16
"""
def __init__(self, metadata: KernelMetadata):
self.metadata = metadata
def epilogue_op(x):
return x
mma_tiler_mn = (metadata.design.tile_shape[0], metadata.design.tile_shape[1])
cluster_shape_mn = (
metadata.design.cluster_shape[0],
metadata.design.cluster_shape[1],
)
self.impl = PersistentDenseGemmKernelImpl(
metadata.operands.accumulator_type,
metadata.design.use_2cta_mma,
mma_tiler_mn,
cluster_shape_mn,
metadata.design.use_tma_store,
epilogue_op,
)
def _supports(self, args: GemmArguments) -> Status:
if args.epilogue is not None:
return Status.fail("This kernel does not support any epilogue fusion.")
return Status.success()
def compile(self, args: GemmArguments, cc: int | None = None) -> CompiledArtifact:
stream = cutlass.cute.runtime.make_fake_stream()
max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)
compiled_kernel = self.cute_compile(
self.impl,
args.A,
args.B,
args.out,
max_active_clusters,
stream,
self.impl.epilogue_op,
)
return CompiledArtifact(compiled_kernel, self)
def _run(
self,
args: GemmArguments,
compiled_artifact: CompiledArtifact,
stream,
workspace=None,
) -> None:
stream = to_cuda_stream(stream)
compiled_gemm = compiled_artifact.compiled_obj
self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)
@staticmethod
def _valid_operands(operands: GemmOperandsMetadata) -> bool:
if not isinstance(operands, GemmOperandsMetadata):
return False
# In the current version, A and B tensors must have the same data type
# i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
if operands.A.dtype != operands.B.dtype:
return False
abtype = operands.A.dtype
# Supported A/B data types:
# - TFloat32
# - Float16/BFloat16
# - Int8/Uint8
# - Float8E4M3FN/Float8E5M2
if abtype not in [
cutlass.Float32,
cutlass.Float16,
cutlass.BFloat16,
cutlass.Int8,
cutlass.Uint8,
cutlass.Float8E4M3FN,
cutlass.Float8E5M2,
]:
return False
# Supported accumulator data types:
# - Float32 (for all floating point A/B data types)
# - Float16 (only for fp16 and fp8 A/B data types)
# - Int32 (only for uint8/int8 A/B data types)
if operands.accumulator_type == cutlass.Float32:
if not abtype.is_float:
return False
elif operands.accumulator_type == cutlass.Float16:
if abtype not in [
cutlass.Float16,
cutlass.Float8E4M3FN,
cutlass.Float8E5M2,
]:
return False
elif operands.accumulator_type == cutlass.Int32:
if abtype not in [cutlass.Uint8, cutlass.Int8]:
return False
else:
return False
# Supported out data types:
# - Float32 (for float32 and int32 accumulator data types)
# - Int32 (for float32 and int32 accumulator data types)
# - Float16/BFloat16 (for fp16 and fp8 accumulator data types)
# - Int8/Uint8 (for uint8/int8 accumulator data types)
# - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types)
if operands.out.dtype == cutlass.Float32 or operands.out.dtype == cutlass.Int32:
if operands.accumulator_type not in [cutlass.Float32, cutlass.Int32]:
return False
elif (
operands.out.dtype == cutlass.Float16
or operands.out.dtype == cutlass.BFloat16
):
if operands.accumulator_type not in [cutlass.Float16, cutlass.BFloat16]:
return False
elif operands.out.dtype == cutlass.Int8 or operands.out.dtype == cutlass.Uint8:
if operands.accumulator_type not in [cutlass.Int32]:
return False
elif (
operands.out.dtype == cutlass.Float8E4M3FN
or operands.out.dtype == cutlass.Float8E5M2
):
if operands.accumulator_type not in [cutlass.Float32]:
return False
else:
return False
return True
@staticmethod
def _metadata_operand_combinations() -> Generator[GemmOperandsMetadata, None, None]:
"""
Generator that yields all valid GemmOperandsMetadata combinations
based on the validation rules in _valid_operands.
"""
# Supported A/B data types (must be the same)
ab_dtypes = [
cutlass.Float32,
cutlass.Float16,
cutlass.BFloat16,
cutlass.Int8,
cutlass.Uint8,
cutlass.Float8E4M3FN,
cutlass.Float8E5M2,
]
row_major_stride = (0, 0, 1)
col_major_stride = (0, 1, 0)
alignment = 16
for ab_dtype in ab_dtypes:
# Determine valid accumulator types for this A/B dtype
valid_acc_dtypes = []
if (
ab_dtype.is_float
): # Float32, Float16, BFloat16, Float8E4M3FN, Float8E5M2
valid_acc_dtypes.append(cutlass.Float32)
if ab_dtype in [
cutlass.Float16,
cutlass.Float8E4M3FN,
cutlass.Float8E5M2,
]:
valid_acc_dtypes.append(cutlass.Float16)
else: # Int8, Uint8
valid_acc_dtypes.append(cutlass.Int32)
for acc_dtype in valid_acc_dtypes:
# Determine valid output types for this accumulator type
valid_out_dtypes = []
if acc_dtype == cutlass.Float32:
valid_out_dtypes.extend(
[
cutlass.Float8E4M3FN,
cutlass.Float8E5M2,
cutlass.Float16,
cutlass.BFloat16,
cutlass.Float32,
cutlass.Int32,
]
)
elif acc_dtype == cutlass.Int32:
valid_out_dtypes.extend(
[cutlass.Int8, cutlass.Uint8, cutlass.Float32, cutlass.Int32]
)
elif acc_dtype == cutlass.Float16:
valid_out_dtypes.extend([cutlass.Float16])
for out_dtype in valid_out_dtypes:
for stride_A, stride_B, stride_out in itertools.product(
[row_major_stride, col_major_stride], repeat=3
):
# Create TensorAttributes for A, B, and out tensors
a_attrs = TensorAttributes(
dtype=ab_dtype, stride=stride_A, alignment=alignment
)
b_attrs = TensorAttributes(
dtype=ab_dtype, stride=stride_B, alignment=alignment
)
out_attrs = TensorAttributes(
dtype=out_dtype, stride=stride_out, alignment=alignment
)
# Create and yield the GemmOperandsMetadata
operands = GemmOperandsMetadata(
A=a_attrs,
B=b_attrs,
out=out_attrs,
accumulator_type=acc_dtype,
)
yield operands
@staticmethod
def _valid_metadata(metadata: KernelMetadata) -> bool:
if not PersistentDenseGemmKernel._valid_operands(metadata.operands):
return False
design = metadata.design
if not isinstance(design, Sm100DesignMetadata):
return False
cluster_size_m, cluster_size_n, _ = design.cluster_shape
# Cluster shape M/N must be positive powers of 2 (per the documented constraints)
if cluster_size_m <= 0 or (cluster_size_m & (cluster_size_m - 1)) != 0:
return False
if cluster_size_n <= 0 or (cluster_size_n & (cluster_size_n - 1)) != 0:
return False
if cluster_size_m * cluster_size_n > 16:
return False
# Constraints based on whether 2CTA instructions are used
if design.use_2cta_mma is not None:
if design.use_2cta_mma:
if cluster_size_m % 2 != 0:
return False
if design.tile_shape is not None and design.tile_shape[0] not in [
128,
256,
]:
return False
else:
if design.tile_shape is not None and design.tile_shape[0] not in [
64,
128,
]:
return False
if design.tile_shape is not None and design.tile_shape[1] not in range(
32, 257, 32
):
return False
if metadata.epilogue is not None:
return False
return True
@staticmethod
def generate_kernels(
metadata_filter: Callable[[KernelMetadata], bool],
epilogue_args: EpilogueArguments | None = None,
cc: int | None = None,
) -> list["PersistentDenseGemmKernel"]:
"""
Returns a list of all possible configurations of PersistentDenseGemmKernel that
adhere to the provided metadata filter, epilogue arguments, and compute capability.
"""
if cc is not None and cc not in [100, 101, 103]:
return []
design_params = {
"use_2cta_mma": [True, False],
"tile_shape": [
(M, N, 256) for M in [64, 128, 256] for N in [32, 64, 128, 256]
],
"cluster_shape": [
(M, N, 1) for M in [1, 2, 4, 8, 16] for N in [1, 2, 4, 8, 16]
],
"use_tma_store": [True],
}
if epilogue_args is not None:
return []
from itertools import product
# Get the list of tunable parameter names and their possible values
param_names = list(design_params.keys())
param_values = [design_params[name] for name in param_names]
kernel_list = []
for operands in PersistentDenseGemmKernel._metadata_operand_combinations():
for values in product(*param_values):
design = Sm100DesignMetadata(**dict(zip(param_names, values)))
kernel_name = "cutedsl.PersistentDenseGemmKernel_sm100_{layout}_A{A}_B{B}_out{out}_acc{acc}_{num_cta}cta_cluster{cluster}_tile{tile}{_tma_store}".format(
layout=strides_to_layout_string(
operands.A.stride, operands.B.stride, operands.out.stride
),
A=operands.A.dtype,
B=operands.B.dtype,
out=operands.out.dtype,
acc=operands.accumulator_type,
num_cta=("2" if design.use_2cta_mma else "1"),
cluster=tuple_to_string(design.cluster_shape),
tile=tuple_to_string(design.tile_shape),
_tma_store="_tma_store" if design.use_tma_store else "",
)
metadata = KernelMetadata(
operands=operands,
design=design,
kernel_name=kernel_name,
kernel_class=PersistentDenseGemmKernel,
min_cc=100,
epilogue=None,
)
if PersistentDenseGemmKernel._valid_metadata(
metadata
) and metadata_filter(metadata):
kernel_list.append(
PersistentDenseGemmKernel(metadata)
)
return kernel_list
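The dtype rules enforced by `_valid_operands` above can be condensed into a standalone sketch. String dtype names stand in for the cutlass types here; the helper is purely illustrative, not part of the API:

```python
# A/B-vs-accumulator compatibility, mirroring _valid_operands:
# f32 accumulation for any floating-point A/B, f16 accumulation only for
# fp16/fp8 A/B, and i32 accumulation only for integer A/B.
FLOAT_AB = {"f32", "f16", "bf16", "e4m3", "e5m2"}

def valid_acc(ab_dtype: str, acc_dtype: str) -> bool:
    if acc_dtype == "f32":
        return ab_dtype in FLOAT_AB
    if acc_dtype == "f16":
        return ab_dtype in {"f16", "e4m3", "e5m2"}
    if acc_dtype == "i32":
        return ab_dtype in {"i8", "u8"}
    return False

# e.g. fp8 inputs may accumulate in f16 or f32, but int8 may not:
assert valid_acc("e4m3", "f16") and valid_acc("e4m3", "f32")
assert not valid_acc("i8", "f32") and valid_acc("i8", "i32")
```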

@@ -0,0 +1,443 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
import itertools
from collections.abc import Callable, Generator
import cutlass
from cutlass_api.arguments import (
EpilogueArguments,
GemmArguments,
)
from cutlass_api.artifact import CompiledArtifact
from cutlass_api.metadata import (
EpilogueMetadata,
GemmOperandsMetadata,
KernelMetadata,
Sm100DesignMetadata,
TensorAttributes,
)
from cutlass_api.providers.cutedsl import CuTeDSLProvider
from cutlass_api.providers.cutedsl.evt.converter import EFCConverter
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
from cutlass_api.providers.cutedsl.utils import get_max_active_clusters
from cutlass_api.status import Status
from cutlass_api.utils import (
TensorWrapper,
strides_to_layout_string,
to_cuda_stream,
tuple_to_string,
)
from .implementations.sm100_static_persistent_efc_impl import (
PersistentDenseGemmEFCKernelImpl,
)
@CuTeDSLProvider.register
class PersistentDenseGemmEFCKernel(CuteDslKernel):
"""Base class for batched GEMM with custom epilogue fusion using EFC.
This class provides the core infrastructure for persistent batched GEMM operations
with customizable epilogue fusion. Subclasses define specific epilogue behaviors
by providing an epilogue configuration function that describes operations on the
accumulator and supplemental tensors.
The class handles:
- GEMM mainloop (A * B computation)
- TMA-based memory operations
- Warp specialization
- Persistent tile scheduling
- EFC (Epilogue Fusion Configuration) integration
- CLI argument parsing (extensible via CLIParser.more_parsing())
- Tensor creation and validation
:param acc_dtype: Data type for accumulation during MMA computation
:type acc_dtype: type[cutlass.Numeric]
:param epi_dtype: Data type for epilogue operation
:type epi_dtype: type[cutlass.Numeric]
:param use_2cta_instrs: Whether to use CTA group 2 for 2CTA MMA instructions
:type use_2cta_instrs: bool
:param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
:type mma_tiler_mn: tuple[int, int]
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
:type cluster_shape_mn: tuple[int, int]
:param epilogue_configuration_function: Function defining the epilogue behavior via EFC
:type epilogue_configuration_function: Callable
:note: Supported A/B data types:
- TFloat32
- Float16/BFloat16
- Int8/Uint8
- Float8E4M3FN/Float8E5M2
(A and B must have the same data type)
:note: Supported accumulator data types:
- Float32 (for all floating point A/B data types)
- Float16 (only for fp16 and fp8 A/B data types)
- Int32 (only for uint8/int8 A/B data types)
:note: Supported supplemental tensor data types (epilogue-dependent):
- Float32 (for float32 and int32 accumulator data types)
- Int32 (for float32 and int32 accumulator data types)
- Float16/BFloat16 (for fp16 and fp32 accumulator data types)
- Int8/Uint8 (for int32 accumulator data type)
- Float8E4M3FN/Float8E5M2 (for float32 accumulator data types)
:note: Constraints:
- MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
- MMA tiler N must be 32-256, step 32
- Cluster shape M must be multiple of 2 if use_2cta_instrs=True
- Cluster shape M/N must be positive and power of 2, total cluster size <= 16
"""
_valid_ab_acc_combinations = {
cutlass.Float16: {cutlass.Float16, cutlass.Float32},
cutlass.BFloat16: {cutlass.Float32},
cutlass.TFloat32: {cutlass.Float32},
cutlass.Uint8: {cutlass.Int32},
cutlass.Int8: {cutlass.Int32},
cutlass.Float8E4M3FN: {cutlass.Float16, cutlass.Float32},
cutlass.Float8E5M2: {cutlass.Float16, cutlass.Float32},
}
def __init__(self, metadata: KernelMetadata):
self.metadata = metadata
if metadata.epilogue is not None:
epilogue_op = EFCConverter.convert(
metadata.epilogue.traced_epilogue.dag_ir,
metadata.epilogue.parameter_names,
)
else:
epilogue_op = EFCConverter.identity_efc
mma_tiler_mn = (metadata.design.tile_shape[0], metadata.design.tile_shape[1])
cluster_shape_mn = (
metadata.design.cluster_shape[0],
metadata.design.cluster_shape[1],
)
self.impl = PersistentDenseGemmEFCKernelImpl(
metadata.operands.accumulator_type,
metadata.operands.out.dtype,
metadata.design.use_2cta_mma,
mma_tiler_mn,
cluster_shape_mn,
epilogue_op,
)
@staticmethod
def _valid_fusion(fusion: EpilogueMetadata) -> Status:
if not isinstance(fusion, EpilogueMetadata):
return Status.fail("Unsupported epilogue argument type.")
if not (status := EFCConverter.supports(fusion.traced_epilogue.dag_ir)):
return status
return Status.success()
@staticmethod
def _valid_operands(operands: GemmOperandsMetadata) -> bool:
if not isinstance(operands, GemmOperandsMetadata):
return False
if operands.A.dtype != operands.B.dtype:
return False
ab_dtype = operands.A.dtype
acc_dtype = operands.accumulator_type
if ab_dtype not in PersistentDenseGemmEFCKernel._valid_ab_acc_combinations:
return False
# Check compatibility between accumulator type and AB type
if (
acc_dtype
not in PersistentDenseGemmEFCKernel._valid_ab_acc_combinations[ab_dtype]
):
return False
return True
@staticmethod
def _valid_design(design: Sm100DesignMetadata) -> bool:
"""
Check if the design metadata is valid.
:param design: The design metadata
:type design: Sm100DesignMetadata
:return: True if the design is valid, False otherwise
:rtype: bool
"""
if not isinstance(design, Sm100DesignMetadata):
return False
use_2cta_instrs = design.use_2cta_mma
mma_tiler_mn = design.tile_shape
cluster_shape_mn = design.cluster_shape
# Check invalid mma tile shape M dimension
if not (
(not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
):
return False
# Check invalid mma tile shape N dimension
if mma_tiler_mn[1] not in range(32, 257, 32):
return False
# Check illegal cluster shape M dimension
if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
return False
def is_power_of_2(x):
return x > 0 and (x & (x - 1)) == 0
# Check invalid cluster shape constraints
if cluster_shape_mn[0] * cluster_shape_mn[1] > 16:
return False
if cluster_shape_mn[0] <= 0 or cluster_shape_mn[1] <= 0:
return False
if not is_power_of_2(cluster_shape_mn[0]) or not is_power_of_2(
cluster_shape_mn[1]
):
return False
return True
@staticmethod
def _valid_metadata(metadata: KernelMetadata) -> bool:
if not PersistentDenseGemmEFCKernel._valid_operands(metadata.operands):
return False
if not PersistentDenseGemmEFCKernel._valid_design(metadata.design):
return False
return True
@staticmethod
def _metadata_operand_combinations() -> Generator[GemmOperandsMetadata, None, None]:
"""
Generator that yields all valid GemmOperandsMetadata combinations
based on the validation rules in _valid_operands.
"""
row_major_stride = (0, 0, 1)
col_major_stride = (0, 1, 0)
alignment = 16
for (
ab_dtype,
valid_acc_dtypes,
) in PersistentDenseGemmEFCKernel._valid_ab_acc_combinations.items():
for acc_dtype in valid_acc_dtypes:
# Determine valid output types for this accumulator type
valid_out_dtypes = []
if acc_dtype == cutlass.Float32:
valid_out_dtypes.extend(
[
cutlass.Float32,
cutlass.Int32,
cutlass.Float16,
cutlass.BFloat16,
]
)
# Float8 output types only valid with Float32 accumulator
valid_out_dtypes.extend([cutlass.Float8E4M3FN, cutlass.Float8E5M2])
elif acc_dtype == cutlass.Int32:
valid_out_dtypes.extend([cutlass.Float32, cutlass.Int32])
# Integer output types only valid with Int32 accumulator
valid_out_dtypes.extend([cutlass.Int8, cutlass.Uint8])
elif acc_dtype == cutlass.Float16:
valid_out_dtypes.extend([cutlass.Float16, cutlass.BFloat16])
for out_dtype in valid_out_dtypes:
for stride_A, stride_B, stride_out in itertools.product(
[row_major_stride, col_major_stride], repeat=3
):
# Create TensorAttributes for A, B, and out tensors
a_attrs = TensorAttributes(
dtype=ab_dtype, stride=stride_A, alignment=alignment
)
b_attrs = TensorAttributes(
dtype=ab_dtype, stride=stride_B, alignment=alignment
)
out_attrs = TensorAttributes(
dtype=out_dtype, stride=stride_out, alignment=alignment
)
# Create and yield the GemmOperandsMetadata
operands = GemmOperandsMetadata(
A=a_attrs,
B=b_attrs,
out=out_attrs,
accumulator_type=acc_dtype,
)
yield operands
def _supports(self, args: GemmArguments) -> Status:
if args.epilogue is not None:
fusion_metadata = EpilogueMetadata.from_args(args.epilogue)
if not self._valid_fusion(fusion_metadata):
return Status.fail("Provided epilogue fusion is not supported by this kernel")
return Status.success()
def compile(self, args: GemmArguments, cc: int | None = None) -> CompiledArtifact:
stream = cutlass.cute.runtime.make_fake_stream()
max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)
if args.epilogue is not None:
epilogue_params = args.epilogue.parameters
else:
epilogue_params = [args.out]
epilogue_params = [
e.compile_time_tensor if isinstance(e, TensorWrapper) else e
for e in epilogue_params
]
# EFC needs special handling for supplemental arguments
self.impl.efc.compile(epilogue_params)
compiled_gemm = self.cute_compile(
self.impl,
args.A,
args.B,
max_active_clusters,
stream,
self.impl.efc.jit.pack_arguments(*epilogue_params),
)
# Wrap the compiled kernel to handle supplemental argument packing at launch time
def wrapped_launch(a_tensor, b_tensor, stream, *supplemental_args):
runtime_args = [
e.runtime_tensor if isinstance(e, TensorWrapper) else e
for e in supplemental_args
]
return compiled_gemm(
a_tensor,
b_tensor,
stream,
self.impl.efc.jit.pack_arguments(*runtime_args),
)
return CompiledArtifact(wrapped_launch, self)
def _run(
self,
args: GemmArguments,
compiled_artifact: CompiledArtifact,
stream,
workspace=None,
) -> None:
stream = to_cuda_stream(stream)
if args.epilogue is not None:
epilogue_params = args.epilogue.parameters
else:
epilogue_params = [args.out]
compiled_gemm = compiled_artifact.compiled_obj
self.cute_run(compiled_gemm, args.A, args.B, stream, *epilogue_params)
@staticmethod
def generate_kernels(
metadata_filter: Callable[[KernelMetadata], bool],
epilogue_args: EpilogueArguments | None = None,
cc: int | None = None,
) -> list["PersistentDenseGemmEFCKernel"]:
"""
Returns a list of all possible configurations of PersistentDenseGemmEFCKernel that
adhere to the provided metadata filter, epilogue arguments, and compute capability.
"""
if cc is not None and cc not in [100, 101, 103]:
return []
design_params = {
"use_2cta_mma": [True, False],
"tile_shape": [
(M, N, 256) for M in [64, 128, 256] for N in [32, 64, 128, 256]
],
"cluster_shape": [
(M, N, 1) for M in [1, 2, 4, 8, 16] for N in [1, 2, 4, 8, 16]
],
"use_tma_store": [True],
}
if epilogue_args is not None:
if not isinstance(epilogue_args, EpilogueArguments):
return []
epilogue_metadata = EpilogueMetadata.from_args(epilogue_args)
if not PersistentDenseGemmEFCKernel._valid_fusion(epilogue_metadata):
return []
else:
epilogue_metadata = None
from itertools import product
param_names = list(design_params.keys())
param_values = [design_params[name] for name in param_names]
kernel_list = []
for operands in PersistentDenseGemmEFCKernel._metadata_operand_combinations():
for values in product(*param_values):
design = Sm100DesignMetadata(**dict(zip(param_names, values)))
kernel_name = "cutedsl.PersistentDenseGemmEFCKernel_sm100_{layout}_A{A}_B{B}_out{out}_acc{acc}_{num_cta}cta_cluster{cluster}_tile{tile}{_tma_store}".format(
layout=strides_to_layout_string(
operands.A.stride, operands.B.stride, operands.out.stride
),
A=operands.A.dtype,
B=operands.B.dtype,
out=operands.out.dtype,
acc=operands.accumulator_type,
num_cta=("2" if design.use_2cta_mma else "1"),
cluster=tuple_to_string(design.cluster_shape),
tile=tuple_to_string(design.tile_shape),
_tma_store="_tma_store" if design.use_tma_store else "",
)
metadata = KernelMetadata(
operands=operands,
design=design,
kernel_name=kernel_name,
kernel_class=PersistentDenseGemmEFCKernel,
min_cc=100,
epilogue=epilogue_metadata,
)
if PersistentDenseGemmEFCKernel._valid_metadata(
metadata
) and metadata_filter(metadata):
kernel_list.append(
PersistentDenseGemmEFCKernel(metadata)
)
return kernel_list
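The design-space expansion in `generate_kernels` follows a common pattern: a dict of candidate parameter lists is expanded into every combination with `itertools.product`. A minimal, self-contained illustration (toy parameter values, not the real design space):

```python
from itertools import product

# Expand a dict of candidate parameter lists into every combination,
# as generate_kernels does before filtering with _valid_metadata.
design_params = {
    "use_2cta_mma": [True, False],
    "cluster_shape": [(1, 1, 1), (2, 1, 1)],
}
names = list(design_params)
configs = [dict(zip(names, values)) for values in product(*design_params.values())]
# 2 x 2 = 4 candidate designs; each would then be validated and filtered
assert len(configs) == 4
assert configs[0] == {"use_2cta_mma": True, "cluster_shape": (1, 1, 1)}
```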

@@ -0,0 +1,73 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from cutlass_api.config import GlobalOptions
from cutlass_api.kernel import Kernel
from cutlass_api.utils import TensorWrapper
import cutlass.cute as cute
class CuteDslKernel(Kernel):
"""
Base class for all CuTe DSL kernels
"""
def cute_compile(self, entry_point_fn, *fn_args):
"""
Compiles a kernel using CuTe DSL compile.
This method is intended to provide a provider-wide consistent implementation of compile() for
all CuTe DSL kernels, respecting GlobalOptions.
:param entry_point_fn: The kernel function to compile
:param fn_args: All arguments to pass to the kernel compilation
:return: The compiled kernel object
"""
options = None
if GlobalOptions().use_tvm_ffi:
options = "--enable-tvm-ffi"
compile_args = [
x.compile_time_tensor if isinstance(x, TensorWrapper) else x
for x in fn_args
]
return cute.compile(entry_point_fn, *compile_args, options=options)
def cute_run(self, entry_point_fn, *fn_args):
"""
Extracts runtime tensors from TensorWrappers and runs the kernel.
:param entry_point_fn: The kernel function to run
:param fn_args: All arguments to pass to the kernel run
:return: The result of the kernel run
"""
run_args = [
x.runtime_tensor if isinstance(x, TensorWrapper) else x for x in fn_args
]
return entry_point_fn(*run_args)
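The unwrap-then-call pattern used by `cute_run` can be shown in isolation; `Wrapper` below is a hypothetical stand-in for `TensorWrapper`, not the actual class:

```python
# Arguments that carry a runtime handle are unwrapped before the call;
# everything else passes through unchanged.
class Wrapper:
    def __init__(self, runtime_tensor):
        self.runtime_tensor = runtime_tensor

def run_unwrapped(fn, *fn_args):
    run_args = [a.runtime_tensor if isinstance(a, Wrapper) else a for a in fn_args]
    return fn(*run_args)

assert run_unwrapped(lambda x, y: x + y, Wrapper(2), 3) == 5
```

The same shape applies to `cute_compile`, which substitutes `compile_time_tensor` instead of `runtime_tensor`.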

@@ -0,0 +1,43 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
import cutlass.utils as utils
def get_max_active_clusters(cluster_shape: tuple[int, int]) -> int:
"""
:param cluster_shape: The shape of the cluster.
:type cluster_shape: tuple[int, int]
:returns: The maximum number of clusters of the provided shape that can fit on the current GPU
:rtype: int
"""
return utils.HardwareInfo().get_max_active_clusters(
cluster_shape[0] * cluster_shape[1]
)


@@ -0,0 +1,72 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections.abc import Callable
from typing import Protocol, runtime_checkable
from cutlass_api.arguments import EpilogueArguments
from cutlass_api.kernel import Kernel
from cutlass_api.metadata import KernelMetadata
@runtime_checkable
class ProviderBase(Protocol):
@classmethod
def generate_kernels(
cls,
metadata_filter: Callable[[KernelMetadata], bool] | None = None,
epilogue_args: EpilogueArguments | None = None,
cc: int | None = None,
) -> list[Kernel]:
"""
Return a list of kernels that support the type of the provided metadata.
:param metadata_filter: boolean filter function to apply to the metadata
:type metadata_filter: Callable[[KernelMetadata], bool]
:param epilogue_args: the epilogue arguments
:type epilogue_args: EpilogueArguments
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
:type cc: int
:return: list of all supported kernel configurations
:rtype: list[Kernel]
"""
raise NotImplementedError
@classmethod
def register(cls, kernel_class: type[Kernel]) -> type[Kernel]:
"""
Registers a Kernel class so that it can be discovered through the given provider.
:param kernel_class: the Kernel class to register
:type kernel_class: Type[Kernel]
:return: the registered kernel class
:rtype: Type[Kernel]
"""
raise NotImplementedError


@@ -0,0 +1,61 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
from dataclasses import dataclass
@dataclass
class Status:
"""
A simple status class to wrap an optional exception.
"""
error: Exception | None = None
def __bool__(self) -> bool:
return self.error is None
@classmethod
def success(cls) -> Status:
"""Create a successful status."""
return cls()
@classmethod
def fail(cls, error: str | Exception) -> Status:
"""Create a failed status with an error."""
if isinstance(error, str):
error = ValueError(error)
return cls(error=error)
def raise_on_error(self) -> None:
"""Raise the stored exception if this status represents a failure."""
if self.error is not None:
raise self.error
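The `Status` pattern above can be exercised as follows. This is a standalone usage sketch: the class is re-declared inline (mirroring the definition above) so the snippet runs without the package installed:

```python
from __future__ import annotations
from dataclasses import dataclass

# Inline re-declaration of the Status class above, so this snippet is standalone.
@dataclass
class Status:
    error: Exception | None = None

    def __bool__(self) -> bool:
        return self.error is None

    @classmethod
    def success(cls) -> Status:
        return cls()

    @classmethod
    def fail(cls, error: str | Exception) -> Status:
        # Strings are wrapped in a ValueError so raise_on_error() always has
        # a real exception to raise.
        if isinstance(error, str):
            error = ValueError(error)
        return cls(error=error)

    def raise_on_error(self) -> None:
        if self.error is not None:
            raise self.error

ok = Status.success()
bad = Status.fail("device not supported")
print(bool(ok))         # True
print(bool(bad))        # False
print(type(bad.error))  # <class 'ValueError'>
```

Because `__bool__` maps success to `True`, callers can use the walrus idiom shown in the notebook below, e.g. `if not (status := kernel.supports(args)): print(status.error)`.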


@@ -0,0 +1,60 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Type markers for CUTLASS API field annotations."""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class TensorLike(Protocol):
"""Protocol for Tensor-like objects that support the DLPack protocol.
CUTLASS API supports tensor-like objects (e.g., torch.Tensor, cute.Tensor) that
implement __dlpack__ and __dlpack_device__.
"""
def __dlpack__(self, *, stream: Any = None) -> Any:
"""Return a DLPack capsule representing the tensor data."""
...
def __dlpack_device__(self) -> tuple[int, int]:
"""Return the device type and device id as a tuple."""
...
class NumericLike(Protocol):
"""Type marker for fields that accept numeric-like types.
Fields annotated with NumericLike accept the following:
- cutlass.Numeric
- torch.dtype
"""


@@ -0,0 +1,425 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Utilities for CUTLASS API.
"""
from __future__ import annotations
import importlib.util
from typing import TYPE_CHECKING, Any
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from cutlass_api.config import GlobalOptions
from cutlass_api.status import Status
if TYPE_CHECKING:
import cuda
import torch
def is_numpy_available() -> bool:
"""Check if numpy is available."""
return importlib.util.find_spec("numpy") is not None
def is_torch_available() -> bool:
"""Check if torch is available."""
return importlib.util.find_spec("torch") is not None
def is_numpy_tensor(inp) -> bool:
"""Check if the input is a numpy tensor."""
if is_numpy_available():
import numpy as np
return isinstance(inp, np.ndarray)
return False
def is_torch_tensor(inp) -> bool:
"""Check if the input is a torch tensor."""
if is_torch_available():
import torch
return isinstance(inp, torch.Tensor)
return False
def _lazy_import(mod_name: str) -> Any:
"""Internal utility to lazily import a module only when needed."""
class Lazy:
def __getattr__(self, name: str) -> Any:
module = importlib.import_module(mod_name)
return getattr(module, name)
return Lazy()
def check_cuda_errors(result: list):
"""
Checks whether `result` contains a CUDA error and, if so, raises it as an exception. Otherwise,
returns the result contained in the remaining fields of `result`.
:param result: the results of the `cuda` method, consisting of an error code and any method results
:type result: list
:return: non-error-code results from the `results` parameter
"""
# `result` is of the format: (cudaError_t, result...)
err = result[0]
if err.value:
_lazy_import("cuda.cuda")
cuda_bindings_runtime = _lazy_import("cuda.bindings.runtime")
raise RuntimeError(
f"CUDA error: {cuda_bindings_runtime.cudaGetErrorString(err)[1].decode('utf-8')}"
)
if len(result) == 1:
return None
elif len(result) == 2:
return result[1]
else:
return result[1:]
def device_cc(device: int = 0) -> int:
"""
Returns the compute capability of the device with ID `device`.
:param device: ID of the device to query
:type device: int
:return: compute capability of the queried device (e.g., 80 for SM80)
:rtype: int
"""
_lazy_import("cuda.cuda")
cuda_bindings_runtime = _lazy_import("cuda.bindings.runtime")
deviceProp = check_cuda_errors(
cuda_bindings_runtime.cudaGetDeviceProperties(device)
)
major = str(deviceProp.major)
minor = str(deviceProp.minor)
return int(major + minor)
def is_device_cc_supported(supported_ccs: set[int]) -> Status:
"""
Fetch the device compute capability and check whether it is in the set of supported CCs.
:return: Status indicating success if device CC is in supported_ccs
:rtype: Status
"""
try:
cc = device_cc()
if cc not in supported_ccs:
return Status.fail(
f"Compute capability {cc} is not in supported set {supported_ccs}."
)
return Status.success()
except Exception as e:
return Status.fail(e)
def cutlass_type_from_torch_type(dtype) -> type[cutlass.Numeric]:
"""
Convert a torch dtype to a cutlass dtype.
:param dtype: The torch dtype to convert.
:return: The cutlass dtype.
:rtype: type[cutlass.Numeric]
"""
import torch
torch_dtype_map = {
torch.float64: cutlass.Float64,
torch.float32: cutlass.Float32,
torch.float16: cutlass.Float16,
torch.bfloat16: cutlass.BFloat16,
torch.int32: cutlass.Int32,
torch.int8: cutlass.Int8,
torch.uint8: cutlass.Uint8,
torch.float8_e5m2: cutlass.Float8E5M2,
torch.float8_e4m3fn: cutlass.Float8E4M3FN,
torch.float8_e4m3fnuz: cutlass.Float8E4M3B11FNUZ,
}
try:
return torch_dtype_map[dtype]
except KeyError:
raise KeyError(f"Unsupported dtype: {dtype}")
def to_cutlass_type(dtype) -> type[cutlass.Numeric]:
"""
Convert a dtype to a cutlass dtype.
:param dtype: The dtype to convert (e.g., torch.float32)
:return: The cutlass dtype.
:rtype: type[cutlass.Numeric]
"""
if isinstance(dtype, type) and issubclass(dtype, cutlass.Numeric):
return dtype
converters = []
if is_torch_available():
converters.append(cutlass_type_from_torch_type)
# Iterate through the available converters and return the first one that succeeds.
for converter in converters:
try:
return converter(dtype)
except KeyError:
continue
raise KeyError(f"Unsupported dtype: {dtype}")
def to_cuda_stream(
stream: cuda.bindings.driver.CUstream | torch.cuda.Stream,
skip_if_ffi: bool = True,
) -> cuda.bindings.driver.CUstream:
"""
Convert provided stream to a cuda.CUstream.
:param stream: The stream to convert.
:type stream: Union[cuda.bindings.driver.CUstream, torch.cuda.Stream]
:param skip_if_ffi: Skip the conversion if True and if TVM-FFI is enabled in GlobalOptions().
:type skip_if_ffi: bool
:return: The converted stream.
:rtype: cuda.bindings.driver.CUstream
"""
if skip_if_ffi and GlobalOptions().use_tvm_ffi:
# TVM-FFI can directly handle streams of various types, including raw int handles
return stream
cuda = _lazy_import("cuda.bindings.driver")
if isinstance(stream, cuda.CUstream):
return stream
if is_torch_available():
import torch
if isinstance(stream, torch.cuda.Stream):
return cuda.CUstream(stream.cuda_stream)
raise ValueError(f"Unsupported stream type: {type(stream)}")
def add_batch_mode(
tensor: cute.Tensor | torch.Tensor,
) -> cute.Tensor | torch.Tensor:
"""
Adds a batch mode to the tensor.
If the tensor is a torch.Tensor and has rank 2,
it will be unsqueezed along the first dimension.
:param tensor: The tensor to add batch mode to.
:type tensor: Union[cute.Tensor, "torch.Tensor"]
:return: The tensor with batch mode added.
:rtype: Union[cute.Tensor, "torch.Tensor"]
"""
if is_torch_tensor(tensor):
if tensor.dim() == 2:
return tensor.unsqueeze(0)
elif tensor.dim() < 2 or tensor.dim() > 3:
raise ValueError(f"Expected 2-3 dimensions, got {tensor.dim()}")
return tensor
return tensor
def permute_batch_mode(
tensor: cute.Tensor | torch.Tensor,
) -> cute.Tensor | torch.Tensor:
"""
Permute the batch mode of the tensor.
If the tensor is a torch.Tensor and has rank 3,
it will be permuted along the first dimension.
:param tensor: The tensor to permute.
:type tensor: Union[cute.Tensor, "torch.Tensor"]
:return: The tensor with batch mode permuted.
:rtype: Union[cute.Tensor, "torch.Tensor"]
"""
if is_torch_tensor(tensor):
if tensor.dim() != 3:
raise ValueError(f"Expected 3 dimensions, got {tensor.dim()}")
return tensor.permute([1, 2, 0])
else:
raise ValueError(f"Unsupported type: {type(tensor)}")
def leading_dim(tensor) -> int:
"""
Get the leading dimension of a tensor. This is the first mode with a
stride of 1 and a non-unit shape. If the tensor has a single element, the leading
dimension is 0. Modes with both a stride of 1 and a shape of 1 are treated
as though they have a stride of 0.
:param tensor: The tensor to get the leading dimension of.
:type tensor: Union[cute.Tensor, "torch.Tensor"]
:return: The leading dimension of the tensor.
:rtype: int
"""
if is_torch_tensor(tensor):
if tensor.numel() == 1:
return 0
updated_stride = [
s if sz != 1 else 0 for s, sz in zip(tensor.stride(), tensor.shape)
]
return updated_stride.index(1)
else:
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
def get_stride_order(stride: tuple[int, ...]) -> tuple[int, ...]:
"""
Returns the order of the stride of a tensor. For a stride of rank N,
the dimension with the smallest stride will have stride order 0 and the
dimension with the largest stride will have stride order N-1.
:param stride: The stride of the tensor.
:type stride: tuple[int, ...]
:return: The order of the stride of the tensor.
:rtype: tuple[int, ...]
"""
# The code below performs an argsort on the stride:
# indices = range(len(stride))
# Sort indices using comparison between stride[i] and stride[j] when
# sorting indices i and j.
return tuple(sorted(range(len(stride)), key=stride.__getitem__))
class TensorWrapper:
"""
Wrapper class for supporting compilation and execution both with
and without TVM-FFI.
When using TVM-FFI, one can pass a framework-level tensor (e.g., torch.Tensor)
to the JIT function at run time, but not at compile time. At compile time, one
must use a `_FakeTensor` to specify the tensor.
When not using TVM-FFI, one passes in a cute.Tensor at both compile and run time.
This class contains two key members:
- runtime_tensor: The tensor to use at run time.
- compile_time_tensor: The tensor to use at compile time.
Users of this class should access these underlying members for execution and
compilation, respectively. Users of this class do not need to know whether TVM-FFI
is enabled.
"""
def __init__(self, tensor: Any):
if isinstance(tensor, cute.Tensor):
# Regardless of whether TVM-FFI is enabled, if the tensor passed in is a cute.Tensor,
# it can be used as the runtime tensor and compile time tensor.
self.runtime_tensor = tensor
self.compile_time_tensor = tensor
self._shape = tensor.shape
self._stride = tensor.stride
elif GlobalOptions().use_tvm_ffi:
# If TVM-FFI is enabled, runtime tensor is set simply as the tensor passed in, but
# we must make a fake tensor for compilation.
self.runtime_tensor = tensor
if is_torch_tensor(self.runtime_tensor):
dtype = cutlass_type_from_torch_type(self.runtime_tensor.dtype)
rank = self.runtime_tensor.dim()
self._stride = self.runtime_tensor.stride()
stride_order = get_stride_order(self._stride)
shape = [cute.SymInt() for _ in range(rank)]
shape[stride_order.index(0)] = cute.SymInt(divisibility=16)
self._shape = tuple(self.runtime_tensor.shape)
else:
raise ValueError(
f"Unsupported tensor type: {type(self.runtime_tensor)}"
)
self.compile_time_tensor = cute.runtime.make_fake_compact_tensor(
dtype, shape, stride_order=stride_order, assumed_align=16
)
else:
# TVM-FFI is disabled and the tensor passed in is not a cute.Tensor,
# We must convert it to a cute.Tensor
self.runtime_tensor = from_dlpack(
tensor, assumed_align=16
).mark_layout_dynamic(leading_dim(tensor))
self._shape = self.runtime_tensor.shape
self._stride = self.runtime_tensor.stride
# Since the runtime tensor is now a cute.Tensor, we can use it at
# compile time as well
self.compile_time_tensor = self.runtime_tensor
@property
def element_type(self) -> type[cutlass.Numeric]:
return self.compile_time_tensor.element_type
@property
def shape(self) -> tuple[int, ...]:
return self._shape
@property
def stride(self) -> tuple[int, ...]:
return self._stride
def strides_to_layout_string(*strides: list[tuple[int, ...]]) -> str:
"""
Convert list[stride tuples] to layout string ('t' for row-major, 'n' for col-major).
Example: ((0, 1, 0), (0, 0, 1), (0, 0, 1)) -> "ntt"
"""
row_major_stride = (0, 0, 1)
col_major_stride = (0, 1, 0)
layout_string_map = {
row_major_stride: "t",
col_major_stride: "n",
}
return "".join(layout_string_map[s] for s in strides)
def tuple_to_string(values: tuple[int, ...]) -> str:
"""
Convert a tuple of integers to an 'x'-separated string (e.g., (2, 3, 4) -> '2x3x4').
"""
return "x".join(str(x) for x in values) if len(values) > 1 else str(values[0])


@@ -0,0 +1,558 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "3dd45ef2",
"metadata": {},
"source": [
"# Basic GEMM using CUTLASS Python API"
]
},
{
"cell_type": "markdown",
"id": "4709aa60",
"metadata": {},
"source": [
"The CUTLASS API provides a consistent, uniform interface for discovering, compiling, and running GPU kernels from various DSL sources.\n",
"\n",
"This notebook walks through a minimal GEMM (Generalized Matrix-Matrix Multiplication) example, and introduces the core concepts of the API."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f878d960-d175-4d84-b978-88afbd318850",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import cutlass\n",
"\n",
"import cutlass_api\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
" print(\n",
" f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\"\n",
" )\n",
" import sys\n",
"\n",
" sys.exit(0)"
]
},
{
"cell_type": "markdown",
"id": "db91dab6",
"metadata": {},
"source": [
"## Running your first kernel"
]
},
{
"cell_type": "markdown",
"id": "7b7b87b0",
"metadata": {},
"source": [
"### Setting up arguments\n",
"\n",
"CUTLASS API has first-class support for PyTorch tensors. We start by creating torch tensors that will be operands to a matrix multiplication."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f550c4ea",
"metadata": {},
"outputs": [],
"source": [
"M, N, K, L = 128, 256, 64, 2\n",
"ab_type = torch.float16\n",
"out_type = torch.float32\n",
"acc_type = torch.float32\n",
"\n",
"A = torch.randint(-1, 2, (L, M, K), device=\"cuda\", dtype=ab_type)\n",
"B = torch.randint(-1, 2, (L, K, N), device=\"cuda\", dtype=ab_type)\n",
"out = torch.empty((L, M, N), device=\"cuda\", dtype=out_type)\n",
"\n",
"reference = (A @ B).to(out.dtype)"
]
},
{
"cell_type": "markdown",
"id": "6b6cb805",
"metadata": {},
"source": [
"We then create a `GemmArguments` object. This object specifies:\n",
"1. what logical operation do we want to perform (a GEMM)\n",
"2. on which operands we want to perform that operation (`A`, `B`, `out` as declared above)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b57690df",
"metadata": {},
"outputs": [],
"source": [
"args = cutlass_api.arguments.GemmArguments(A=A, B=B, out=out, accumulator_type=acc_type)"
]
},
{
"cell_type": "markdown",
"id": "67e5ddcf",
"metadata": {},
"source": [
"### Kernel discovery\n",
"\n",
"We now need to find kernels that can perform the operation we expressed in `args`.\n",
"\n",
"The simplest way to do so is to use `get_kernels(args)`. It searches a set of kernels pre-registered in the library, and returns the subset of those kernels which can successfully run our given `args`.\n",
"\n",
"Any of these kernels will be functionally equivalent -- they may have different design or performance characteristics. We arbitrarily pick the first of the returned kernels to execute here"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9872ad66",
"metadata": {},
"outputs": [],
"source": [
"kernels = cutlass_api.get_kernels(args)\n",
"assert kernels, \"No kernels found for the given arguments!\"\n",
"\n",
"kernel = kernels[0]"
]
},
{
"cell_type": "markdown",
"id": "4c17693e",
"metadata": {},
"source": [
"#### Run the kernel\n",
"\n",
"Running the kernel is as simple as `kernel.run(args)`.\n",
"\n",
"This implicitly JIT-compiles the kernel, and launches it on the GPU device using our given arguments."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "baf4588a",
"metadata": {},
"outputs": [],
"source": [
"kernel.run(args)\n",
"\n",
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "4d7ad85b",
"metadata": {},
"source": [
"One can also explicitly compile the kernel and pass this in to `kernel.run` to avoid\n",
"JIT compilation on future invocations. Additional details related to this will be\n",
"described below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06f9f844",
"metadata": {},
"outputs": [],
"source": [
"artifact = kernel.compile(args)\n",
"kernel.run(args, compiled_artifact=artifact)\n",
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "630e9e4b",
"metadata": {},
"source": [
"---\n",
"\n",
"---\n",
"\n",
"### Understanding the core interfaces"
]
},
{
"cell_type": "markdown",
"id": "2d8b8e94",
"metadata": {},
"source": [
"#### 1. `RuntimeArguments` / `GemmArguments`\n",
"\n",
"`RuntimeArguments` describe the operation a user wants to perform, and all the runtime operands or other runtime parameters needed for it. \n",
"This includes primary runtime operands to the operation, as well as any custom epilogue fusions and runtime performance knobs.\n",
"\n",
"We provide builtin subtypes of `RuntimeArguments` for common operations (e.g. GEMM, Elementwise ops; more later).\n",
"\n",
"For instance, `GemmArguments` is a type of `RuntimeArguments`:\n",
"\n",
"```python\n",
"@dataclass\n",
"class GemmArguments(RuntimeArguments):\n",
" A: TensorLike\n",
" B: TensorLike\n",
" out: TensorLike\n",
" accumulator_type: NumericLike\n",
"```\n",
"\n",
"`GemmArguments` conveys:\n",
"* We want to perform a dense GEMM operation (`out = A @ B`)\n",
"* We want to perform it for operands in `A, B, out`, with intermediate results stored as `accumulator_type`\n",
"* We can optionally set a custom epilogue that is fused on top of the base GEMM. Some kernels also support some runtime performance controls which can be specified here. These will be discussed in detail in other tutorials.\n",
"\n",
"It is a kernel-agnostic way to specify the desired functionality.\n",
"\n",
"`RuntimeArguments` can be constructed from any `TensorLike` object. This includes `torch.Tensor`, `cute.Tensor`, or any other DLPack-compatible tensors."
]
},
{
"cell_type": "markdown",
"id": "e7eda0dd",
"metadata": {},
"source": [
"#### 2. Kernel Discovery\n",
"\n",
"There are several kernels available in CUTLASS DSLs that are registered with, and discoverable via, the CUTLASS API.\n",
"\n",
"This includes kernels for various operations (GEMM, Elementwise operations, ...), which implement various algorithms & architecture features. Within the same implementation, there are several instances or configurations of it with different combinations of operand types, layouts, tile sizes, etc.\n",
"\n",
"In the previous step, we used `GemmArguments` to specify our desired GEMM in a kernel-agnostic way. Now we find kernels that can fulfill that functionality. A subset of the available kernels will perform GEMM, and a subset of _those_ will support the properties of specific operands we are currently using."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3b737131",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A total of 107616 kernel instances are available.\n",
"Of these, 350 support the given arguments.\n",
"Picked kernel with name: cutedsl.PersistentDenseGemmKernel_sm100_ttt_AFloat16_BFloat16_outFloat32_accFloat32_2cta_cluster2x1x1_tile128x32x256_tma_store\n"
]
}
],
"source": [
"# get_kernels() fetches all kernels when called without args\n",
"all_kernels = cutlass_api.get_kernels()\n",
"print(f\"A total of {len(all_kernels)} kernel instances are available.\")\n",
"\n",
"# we can limit the search to kernels supporting given args\n",
"kernels = cutlass_api.get_kernels(args)\n",
"print(f\"Of these, {len(kernels)} support the given arguments.\")\n",
"\n",
"kernel = kernels[0]\n",
"print(f\"Picked kernel with name: {kernel.metadata.kernel_name}\")"
]
},
{
"cell_type": "markdown",
"id": "252a4d38",
"metadata": {},
"source": [
"#### 3. `Kernel` execution"
]
},
{
"cell_type": "markdown",
"id": "574d004b",
"metadata": {},
"source": [
"Once we have selected a kernel, we are now ready to execute it. We previously showed the simplest way to do this is `kernel.run(args)`.\n",
"\n",
"This method does the following:\n",
"* verify that the kernel supports the given `args`\n",
"* JIT-compile the kernel\n",
"* launch the compiled kernel function\n",
"\n",
"Users can do these steps individually for more control:"
]
},
{
"cell_type": "markdown",
"id": "e8945aa6",
"metadata": {},
"source": [
"* `kernel.supports(args)` checks if the kernel supports the given `args`\n",
" * this is relevant if the kernel was not picked just for these `args`"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "159fd610",
"metadata": {},
"outputs": [],
"source": [
"supported = kernel.supports(args)\n",
"assert supported"
]
},
{
"cell_type": "markdown",
"id": "948689a8",
"metadata": {},
"source": [
"If the arguments are not supported by this kernel, `supports` returns a `Status` object explaining the error."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2cfc9ea7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Operand `A` is unsupported: Expected element type Float16, got BFloat16\n"
]
}
],
"source": [
"unsupported_args = cutlass_api.arguments.GemmArguments(\n",
" A=A.to(torch.bfloat16), B=B, out=out, accumulator_type=acc_type\n",
")\n",
"if not (status := kernel.supports(unsupported_args)):\n",
" print(status.error)\n",
"\n",
"assert not status"
]
},
{
"cell_type": "markdown",
"id": "c2db8f20",
"metadata": {},
"source": [
"* `kernel.compile(args)` compiles the kernel, and returns a `CompiledArtifact`\n",
"\n",
"This compiled artifact is a lightweight wrapper over the result of compiling a kernel (e.g., via `cute.compile()`).\n",
"\n",
"For just-in-time compilation, we can use the compiled artifact straightaway.\n",
"In the future, we will support optionally serializing it for ahead-of-time compilation and deserialized in a different context.\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "02e79eb8",
"metadata": {},
"outputs": [],
"source": [
"compiled_artifact = kernel.compile(args)"
]
},
{
"cell_type": "markdown",
"id": "4dfb8d51",
"metadata": {},
"source": [
"* `kernel.run(args)` launches the compiled kernel function. This example uses:\n",
" * the precompiled artifact\n",
" * a custom stream to launch to\n",
" * bypasses the supports check already performed above (`assume_supported_args=True`)."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "02398bf0",
"metadata": {},
"outputs": [],
"source": [
"# zero the output to avoid testing stale output\n",
"out.zero_()\n",
"\n",
"kernel.run(\n",
" args,\n",
" compiled_artifact,\n",
" stream=torch.cuda.Stream(),\n",
" assume_supported_args=True,\n",
")\n",
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "f67eeb8f",
"metadata": {},
"source": [
"Some kernels may also require a device \"workspace\". This is an additional buffer needed by some kernels for book-keeping, temporary results, etc.\n",
"Its size can be queried using `kernel.get_workspace_size(args)`. Most kernels will have a workspace size of 0.\n",
"If a kernel does have a non-zero workspace size, an additional buffer of at least that size must be provided. Without it, the kernel behavior is undefined."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5b2e2d56",
"metadata": {},
"outputs": [],
"source": [
"workspace_size = kernel.get_workspace_size(args)\n",
"workspace = torch.empty(workspace_size, device=\"cuda\", dtype=torch.int8)\n",
"\n",
"out.zero_()\n",
"kernel.run(args, compiled_artifact, stream=torch.cuda.Stream(), workspace=workspace)\n",
"torch.testing.assert_close(out, reference)"
]
},
{
"cell_type": "markdown",
"id": "baffaf12",
"metadata": {},
"source": [
"### Advanced: Filtering on Metadata"
]
},
{
"cell_type": "markdown",
"id": "86dd521a",
"metadata": {},
"source": [
"Using `RuntimeArguments` to search for supporting kernels is a convenient way to discover kernels: users directly specify their desired functionality, and `get_kernels()` finds the supporting kernels.\n",
"It covers all logical operands of any operation, as well as (in later examples) epilogue fusions, and performance controls.\n",
"\n",
"However, there may be cases where users want more advanced ways to query kernels. These could be:\n",
"* when the desired properties may not be expressed in runtime controls\n",
" * the simplest scenario may be if you're searching searching for a kernel with a specific name, a specific class, etc.\n",
" * searching for kernel's static properties such as tile size, cluster size, etc.\n",
"* when the `RuntimeArguments` are not available or you want to generate & pre-compile a broader set of kernels\n",
"\n",
"For such cases, we provide a more advanced filtering based on `KernelMetadata`"
]
},
{
"cell_type": "markdown",
"id": "ef54ae50",
"metadata": {},
"source": [
"`KernelMetadata` captures a wide variety of properties of a `Kernel`.\n",
"\n",
"These are properties of a kernel's functional support (like operand types, layouts, alignments), as well as architectural/design choices & performance characteristics (like tilze size, scheduling characteristics).\n",
"\n",
"Different kernels may use different sub-classes of `metadata.operands`, `metadata.design`, `metadata.epilogue` for flexibility, which can also identify their characteristics.\n",
"\n",
"```python\n",
"@dataclass\n",
"class KernelMetadata:\n",
" kernel_name: str\n",
" kernel_class: type[\"Kernel\"]\n",
" min_cc: int\n",
" operands: OperandsMetadata\n",
" design: DesignMetadata | None = None\n",
" epilogue: EpilogueMetadata | None = None\n",
"```"
]
},
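{
"cell_type": "markdown",
"id": "3f7a1c22",
"metadata": {},
"source": [
"As a mental model (an illustration only, not the real `cutlass_api` implementation), metadata-based filtering amounts to applying a user-supplied predicate over a catalog of metadata records. The sketch below uses hypothetical `ToyMetadata` and `toy_get_kernels` names purely for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e8b4d10",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only: ToyMetadata and toy_get_kernels are hypothetical\n",
"# stand-ins for KernelMetadata and get_kernels(metadata_filter=...).\n",
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class ToyMetadata:\n",
"    kernel_name: str\n",
"    min_cc: int\n",
"    tile_shape: tuple\n",
"\n",
"catalog = [\n",
"    ToyMetadata(\"gemm_a\", 90, (128, 128, 64)),\n",
"    ToyMetadata(\"gemm_b\", 100, (128, 256, 64)),\n",
"    ToyMetadata(\"gemm_c\", 100, (256, 128, 64)),\n",
"]\n",
"\n",
"def toy_get_kernels(metadata_filter):\n",
"    # Like get_kernels(), apply the predicate to each kernel's metadata\n",
"    return [m for m in catalog if metadata_filter(m)]\n",
"\n",
"selected = toy_get_kernels(lambda m: m.tile_shape[0] == 128)\n",
"print([m.kernel_name for m in selected])  # ['gemm_a', 'gemm_b']"
]
},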
{
"cell_type": "markdown",
"id": "f9943888",
"metadata": {},
"source": [
"Every unique kernel instance can be distinguished by its metadata.\n",
"It can be used in filtering for kernels in addition to the `RuntimeArguments`, by providing a custom `metadata_filter`.\n",
"\n",
"Here, we get all kernels that support `args`, and have `metadata.design` of type `Sm100DesignMetadata`.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8717ac89",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 350 kernels which support args & have Sm100DesignMetadata\n"
]
}
],
"source": [
"kernels = cutlass_api.get_kernels(\n",
" args,\n",
" metadata_filter=lambda metadata: isinstance(\n",
" metadata.design, cutlass_api.metadata.Sm100DesignMetadata\n",
" ),\n",
")\n",
"print(f\"Found {len(kernels)} kernels which support args & have Sm100DesignMetadata\")"
]
},
{
"cell_type": "markdown",
"id": "1d1f9124",
"metadata": {},
"source": [
"We can construct more advanced filters by leveraging duck-typing.\n",
"Additionally, we can get all the kernels that match our filter, rather than supporting a fully-defined set of arguments.\n",
"This could be useful to pre-generate large set of kernels not targeted to any one problem."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a76ec20f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 9400 matching kernels\n"
]
}
],
"source": [
"def a_more_complex_filter(metadata: cutlass_api.metadata.KernelMetadata) -> bool:\n",
" \"\"\"\n",
" Find all GEMM kernels that support Float16 A and 2-CTA MMA\n",
" \"\"\"\n",
" # Only look at GEMM kernels\n",
" if not isinstance(metadata.operands, cutlass_api.metadata.GemmOperandsMetadata):\n",
" return False\n",
" # Only look at kernels with A-type F16\n",
" if metadata.operands.A.dtype != cutlass.Float16:\n",
" return False\n",
" # Only look at kernels with tile_shape[0] == 128\n",
" if getattr(metadata.design, \"tile_shape\", [None])[0] != 128:\n",
" return False\n",
" return True\n",
"\n",
"\n",
"# Look ma, no args! Fetch all kernels that match the filter,\n",
"# instead of supporting a complete set of args\n",
"kernels = cutlass_api.get_kernels(\n",
" args=None,\n",
" metadata_filter=a_more_complex_filter,\n",
")\n",
"print(f\"Found {len(kernels)} matching kernels\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -0,0 +1,518 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a31330d3",
"metadata": {},
"source": [
"# Custom epilogue fusions for GEMMs\n",
"\n",
"Note: this notebook requires a GPU with compute capability 100 or 103:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb450878",
"metadata": {},
"outputs": [],
"source": [
"import cutlass_api\n",
"\n",
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
" import sys\n",
"\n",
" sys.exit(0)"
]
},
{
"cell_type": "markdown",
"id": "154e9d59",
"metadata": {},
"source": [
"The CUTLASS API provides flexible epilogue fusion support by allowing for the specification of an epilogue via high-level tensor operations that one would like to compose with an operation.\n",
"\n",
"For those familiar with the legacy CUTLASS Python API's [epilogue visitor tree frontend](https://github.com/NVIDIA/cutlass/blob/a2439551c765c5393aebe557ee75d3a0412d2211/examples/python/deprecated/04_epilogue_visitor.ipynb), much of the interface is shared.\n",
"\n",
"The CUTLASS API enables one to express an epilogue using a function operating at the `torch.Tensor`-level, and has tooling to automatically add this to kernels supporting the provided function. \n",
"\n",
"For example, in PyTorch one might write the following to compute a GEMM + epilogue:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6d77d53",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(2025)\n",
"\n",
"L, M, N, K = 1, 1024, 1024, 1024\n",
"A = torch.randn(L, M, K, device=\"cuda\", dtype=torch.float16)\n",
"B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float16)\n",
"C = torch.randn(L, M, N, device=\"cuda\", dtype=torch.float16)\n",
"\n",
"def my_epilogue(accum, C, alpha, beta, extra_scalar):\n",
" Aux = (alpha * accum) + (beta * C)\n",
" D = extra_scalar * Aux\n",
" return D, Aux\n",
"\n",
"alpha, beta, extra_scalar = 1.0, 2.0, 0.5\n",
"D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)\n"
]
},
{
"cell_type": "markdown",
"id": "66ee4dd1",
"metadata": {},
"source": [
"The CUTLASS API allows the same epilogue function `my_epilogue` to be used in GEMMs provided by the API.\n",
"\n",
"To do so, one defines `EpilogueArguments` consisting of the epilogue function to compute (or a string representation of it) along with arguments corresponding to each input and output of the function (except for `accum`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f079d9d6",
"metadata": {},
"outputs": [],
"source": [
"import cutlass_api\n",
"from cutlass_api.arguments import GemmArguments, EpilogueArguments\n",
"\n",
"# Allocate buffers for D and Aux\n",
"D_, Aux_ = [torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16) for _ in range(2)]\n",
"\n",
"epi_args = EpilogueArguments(my_epilogue, C=C, alpha=alpha, beta=beta, extra_scalar=extra_scalar, D=D_, Aux=Aux_)\n"
]
},
{
"cell_type": "markdown",
"id": "97ef8e8a",
"metadata": {},
"source": [
"These arguments can be added to `GemmArguments` and passed in to `get_kernels()` for use when retrieving compatible kernels:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60215c4e",
"metadata": {},
"outputs": [],
"source": [
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args)\n",
"assert len(kernels) > 0\n"
]
},
{
"cell_type": "markdown",
"id": "b0a7f9a2",
"metadata": {},
"source": [
"Each of the kernels returned by `get_kernels` can be compiled and executed just the same with these new arguments, as it was in examples without\n",
"epilogue fusion. For example, using the first kernel:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "150f3296",
"metadata": {},
"outputs": [],
"source": [
"kernels[0].run(args)\n",
"\n",
"torch.testing.assert_close(D, D_)\n",
"torch.testing.assert_close(Aux, Aux_)\n"
]
},
{
"cell_type": "markdown",
"id": "f1a826e3",
"metadata": {},
"source": [
"## How the epilogue fusion API works\n",
"To support specifying an epilogue via a Python function, a kernel needs some mechanism to:\n",
"1. Detect the operations in the epilogue function\n",
"2. Determine if the kernel can support the operations\n",
"3. Emit code to perform these operations within the kernel\n",
"\n",
"Step 1 listed above does not depend on the kernel and its implementation (e.g., DSL), while steps 2 and 3 depend on the kernel and/or its implementation.\n",
"\n",
"Thus, the CUTLASS API separates these components so that step 1 takes place at the API level and steps 2 and 3 take place in the kernel. This process is visualized below. We will walk through each step in greater detail.\n",
"\n",
"```python\n",
" +------------------------------------+\n",
" | def epi(accum, alpha, beta, C): |\n",
" | D = (accum * alpha) + (beta * C) | 1. Define epilogue via a Python function\n",
" | return D |\n",
" +------------------------------------+\n",
" |\n",
" |\n",
" |\n",
" GemmArguments(..., 2. Pass epilogue function, operands, and outputs\n",
" epilogue=EpilogueArguments( to EpilogueArguments constructor,\n",
" epi, alpha=alpha, beta=beta, C=C)) and add this to the GemmArguments. Under the\n",
" | hood, this parses the Python AST of the\n",
" | epilogue function to produce a DAG of load,\n",
" | store, and compute nodes.\n",
" V\n",
" +-----------------------------------------+ \n",
" | Intermediate DAG representation |\n",
" | =============================== |\n",
" | |\n",
" | Store() |\n",
" | | |\n",
" | Add() |\n",
" | / \\ |\n",
" | / \\ |\n",
" | / \\ |\n",
" | Mul() Mul() |\n",
" | / \\ / \\ |\n",
" | AccFetch() | Load(C) \\ |\n",
" | | \\ |\n",
" | Load(alpha) Load(beta) |\n",
" | |\n",
" +-----------------------------------------+\n",
" / | \\\n",
" / | \\ 3. Individual kernel classes use the DAG representation\n",
" / | \\ to determine if the kernel class supports the DAG.\n",
" Kernel 0 Kernel 1 Kernel 2 If so, the kernel class emits DSL-level operations\n",
" epilogue epilogue epilogue needed to compute the epilogue DAG alongside the\n",
" emitter emitter emitter basic operation of the kernel (e.g., GEMM).\n",
" | | |\n",
" | | |\n",
" V V V\n",
"```\n",
"\n",
"### Defining an epilogue via a Python function\n",
"Epilogue fusion patterns are defined by users in Python functions that perform Tensor-level operations -- using a `torch.Tensor` (for example) resulting from matrix multiplication, the function must be able to compute the desired results of the epilogue.\n",
"\n",
"The structure of these functions is as follows:\n",
"```python\n",
"def custom_epi_name(accum, *args) -> Union[TensorType, tuple[TensorType]]:\n",
" \"\"\"\n",
" :param accum: result of matrix multiplication, convolution, etc. before the epilogue\n",
" :type accum: TensorType\n",
" :param args: additional arguments to be used in the epilogue (e.g., aux tensors)\n",
" :type args: list[Union[TensorType, ScalarType]]\n",
"\n",
" :returns: at least one tensor resulting from the operation of the epilogue\n",
" :rtype: Union[TensorType, tuple[TensorType]]\n",
" \"\"\"\n",
" # Do some compute\n",
" return D # and potentially other values\n",
"```\n",
"\n",
"The user defines a custom epilogue via a Python function that **must** do at least the following:\n",
"1. Take in a first positional argument named `accum` that represents the result of operation just before the epilogue is to be performed. For example, in a GEMM, `accum = A @ B`.\n",
"2. Return at least one tensor that results from computing the epilogue. Currently, the return list must contain at least one output named `D`, though this constraint may be loosened in the future.\n",
"\n",
"Each additional argument following `accum` in the function definition is expected to be either a Tensor or scalar to be loaded. Each variable in the return statement represents a Tensor or scalar to be stored. The underlying implementation of the epilogue in the kernel will determine how operands are loaded and stored.\n",
"\n",
"Compute operations are represented in static single assignment (SSA) form.\n",
"This means that each variable can be assigned exactly once.\n",
"Operations currently supported ares:\n",
"* Tensor-tensor elementwise addition, subtraction, multiplication, and division\n",
"* Scalar broadcasts via addition, subtraction, multiplication, and division\n",
"* Predefined elementwise activation functions (e.g., ReLU, sigmoid, tanh)\n",
"\n",
"Operations that are not yet supported include:\n",
"* Row/column broadcasts (planned to be added soon)\n",
"* Reductions (planned to be added soon)\n",
"* Binary minimum and maximum functions (planned to be added soon)\n",
"If attempting to use these operations will result in no kernels being found in the call to `get_kernels`.\n",
"\n",
"Violations to SSA or use of unexpected operators will be flagged with an exception when parsing the AST of the custom epilogue.\n",
"\n",
"Examples of epilogues fitting these patterns are given below. We will show full, runnable examples at the end of this notebook.\n",
"```python\n",
"def relu_aux_store(accum, alpha, C):\n",
" # Note that the function definition itself does not indicate the types and\n",
" # ranks of alpha and C. Thus, one cannot tell whether the epilogue is performing\n",
" # broadcasts or elementwise operations until actual arguments or metadata are\n",
" # provided to the epilogue. See below for details.\n",
" F = (accum * alpha) + (C * 2.0) # Constant beta of 2.0\n",
" D = relu(F)\n",
" return D, F\n",
"\n",
"def aux_normalize(accum, aux):\n",
" D = accum / aux\n",
" return D\n",
"```\n",
"\n",
"Additional information about each operand and output must be provided by the user when constructing `EpilogueArguments`, as we will discuss below. This additional information is necessary for fully defining the operations being performed -- without knowledge of whether `alpha` is a scalar or a Tensor, we cannot determine whether multiplication by `alpha` is a broadcasted or elementwise operation.\n",
"\n",
"### Constructing epilogue arguments\n",
"`EpilogueArguments` encapsulate the arguments needed to determine the functional operation of a fused epilogue.\n",
"\n",
"A user must provide in the construction of `EpilogueArguments` tensors for all operands and outputs of the epilogue. However, unlike arguments for basic operations (e.g., GEMM), the full set of operands needed to be specified for an epilogue pattern depends upon the custom epilogue defined by the user.\n",
"\n",
"Therefore, `EpilogueArguments` is defined generically as taking in an `epilogue_fn` and additional `kwargs`. Under the hood, the AST for `epilogue_fn` is parsed to determine the operands and outputs of the epilogue. The user is required to provide in `kwargs` Tensors or scalars for all operands and outputs in the provided epilogue.\n",
"\n",
"For example, with an epilogue of:\n",
"```python\n",
"def my_epi(accum, alpha, C, beta):\n",
" F = (accum * alpha) + (C * beta)\n",
" D = relu(F)\n",
" return D, F\n",
"```\n",
"A user would need to construct epilogue arguments as follows:\n",
"```python\n",
"epi_args = EpilogueArguments(my_epi, alpha=..., C=..., beta=..., D=..., F=...)\n",
"```\n",
"\n",
"After verifying that all required operands and outputs are present, the constructor to `EpilogueArguments` will perform additional passes on the AST of `epilogue_fn` using the provided inputs to generate an internal DAG representing the epilogue. This DAG structure is attached to `EpilogueArguments` for use as they are passed through a call to `get_kernels`.\n",
"\n",
"### Discovering kernels that support the epilogue pattern\n",
"\n",
"The call to `get_kernels(args)` will return any kernels that support the provided `GemmArguments`.\n",
"Since the `GemmArguments` constructed above now include `EpilogueArguments`, returned kernels must support the provided epilogue.\n",
"\n",
"Under the hood of `get_kernels()`, each `Kernel` class will determine in its `generate_kernels()` method whether it supports the provided `EpilogueArguments`.\n",
"It can do so by traversing the DAG that resulted from the construction of `EpilogueArguments` to find the operations that compose the epilogue.\n",
"Assuming that the `Kernel` can support the DAG, it must then add to the source for the kernel any operations needed to support the DAG.\n",
"An example of how this is done generically for an SM100 CuTe DSL GEMM is provided in `sm100_static_persistent_efc.py`.\n",
"\n",
"## Example epilogues\n",
"We now provide various examples of adding custom epilogues to GEMM kernels targeting SM100. A broader set of epilogue examples are available in `test_gemm_epilogue_fusion.py`.\n",
"\n",
"### Auxiliary input and output tensors"
]
},
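{
"cell_type": "markdown",
"id": "7c5a2f91",
"metadata": {},
"source": [
"To make the SSA constraint concrete, the following self-contained sketch (an illustration only, not the `cutlass_api` parser) shows how a single-assignment violation can be detected by walking the Python AST of an epilogue body:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b9d6e43",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only: a minimal SSA check over assignment statements.\n",
"import ast\n",
"\n",
"def check_ssa(src: str) -> None:\n",
"    \"\"\"Raise if any name is assigned more than once.\"\"\"\n",
"    assigned = set()\n",
"    for node in ast.walk(ast.parse(src)):\n",
"        if isinstance(node, ast.Assign):\n",
"            for target in node.targets:\n",
"                if isinstance(target, ast.Name):\n",
"                    if target.id in assigned:\n",
"                        raise ValueError(f\"Variable '{target.id}' cannot be defined twice.\")\n",
"                    assigned.add(target.id)\n",
"\n",
"check_ssa(\"F = accum * alpha\\nD = F + C\\n\")  # OK: each name assigned once\n",
"try:\n",
"    check_ssa(\"tmp = accum * 2.0\\ntmp = tmp - 1.0\\n\")  # redefines tmp\n",
"except ValueError as e:\n",
"    print(e)"
]
},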
{
"cell_type": "code",
"execution_count": null,
"id": "171ac178",
"metadata": {},
"outputs": [],
"source": [
"from cutlass_api.fusion.activation import relu\n",
"\n",
"def relu_aux_store(accum, alpha, C):\n",
" F = (accum * alpha) + (C * 2.0) # Constant beta\n",
" D = relu(F)\n",
" return D, F\n",
"\n",
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"alpha = 3.0\n",
"D = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"\n",
"epi_args = EpilogueArguments(relu_aux_store, alpha=alpha, C=C, D=D, F=F)\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
"D_ref, F_ref = relu_aux_store(A @ B, alpha, C)\n",
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n"
]
},
{
"cell_type": "markdown",
"id": "f947b403",
"metadata": {},
"source": [
"### Keyword functions and returning accumulator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62c2b49b",
"metadata": {},
"outputs": [],
"source": [
"def relu_scale_return_acc(accum, alpha, beta, C, scale):\n",
" F = relu((accum * alpha) + (C * beta))\n",
" D = F * scale\n",
" return D, F, accum\n",
"\n",
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"alpha = 1.0\n",
"beta = 2.0\n",
"scale = 0.5\n",
"D = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"accum = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float32)\n",
"\n",
"epi_args = EpilogueArguments(relu_scale_return_acc, alpha=alpha, beta=beta, C=C, scale=scale, D=D, F=F, accum=accum)\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
"D_ref, F_ref, accum_ref = relu_scale_return_acc(A @ B, alpha, beta, C, scale)\n",
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n",
"torch.testing.assert_close(accum, accum_ref.to(accum.dtype))\n"
]
},
{
"cell_type": "markdown",
"id": "c641911f",
"metadata": {},
"source": [
"### Passing a string representation of the function\n",
"`EpilogueArguments` can additionally be constructed using a string representation of the epilogue function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5987bf44",
"metadata": {},
"outputs": [],
"source": [
"epi_str = \"def epi(accum, alpha, beta, C): F = (accum * alpha) + (C * beta); D = relu(F); return D, F\"\n",
"\n",
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"alpha = 1.0\n",
"beta = 0.5\n",
"D = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
"\n",
"epi_args = EpilogueArguments(epi_str, alpha=alpha, beta=beta, C=C, D=D, F=F)\n",
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
"assert len(kernels) > 0\n",
"kernels[0].run(args)\n",
"\n",
"F_ref = (A @ B) * alpha + (C * beta)\n",
"D_ref = torch.relu(F_ref)\n",
"\n",
"torch.testing.assert_close(D, D_ref)\n",
"torch.testing.assert_close(F, F_ref)\n"
]
},
{
"cell_type": "markdown",
"id": "e26a58a2",
"metadata": {},
"source": [
"### Failure examples\n",
"The following are examples of constructing `EpilogueArguments` that are expected to fail."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e3d0c89",
"metadata": {},
"outputs": [],
"source": [
"####################################################\n",
"# Epilogues must take in an accumulator\n",
"####################################################\n",
"def fail_missing_accum(alpha, beta, C):\n",
" D = (C * beta)\n",
" return D\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_missing_accum, alpha=alpha, beta=beta, C=C, D=D)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"except Exception as e:\n",
" # \"accum must be an input to the epilogue function\"\n",
" print(e)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48a359f7",
"metadata": {},
"outputs": [],
"source": [
"####################################################\n",
"# Epilogues must return an output named D\n",
"####################################################\n",
"def fail_missing_D(accum, alpha, beta, C):\n",
" F = (accum * alpha) + (C * beta)\n",
" return F\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"except Exception as e:\n",
" # \"On SM90 or higher, D is expected to be a output node with 0 users to enable smem reuse between C and D, but got []\"\n",
" print(e)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49d9ee94",
"metadata": {},
"outputs": [],
"source": [
"####################################################\n",
"# Epilogues must use single-static assignment (SSA)\n",
"####################################################\n",
"def fail_ssa(accum):\n",
" tmp = accum * 2.0\n",
" # Redefine tmp, which violates SSA form.\n",
" tmp = tmp - 1.0\n",
" D = tmp / 4.0\n",
" return D, tmp\n",
"\n",
"try:\n",
" epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"except Exception as e:\n",
" # \"Variable 'tmp' cannot be defined twice.\"\n",
" print(e)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "871bb727",
"metadata": {},
"outputs": [],
"source": [
"####################################################\n",
"# Must provide all operands and outputs to\n",
"# EpilogueArguments\n",
"####################################################\n",
"def my_epi(accum, alpha, beta, C):\n",
" F = (accum * alpha) + (C * beta)\n",
" D = relu(F)\n",
" return D\n",
"\n",
"try:\n",
" # Missing D\n",
" epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"except Exception as e:\n",
" # \"Argument D is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)\n",
"\n",
"try:\n",
" # Missing alpha\n",
" epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)\n",
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
"except Exception as e:\n",
" # \"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\"\n",
" print(e)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -0,0 +1,548 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "578f2730",
"metadata": {},
"source": [
"# Adding a kernel to the CUTLASS API\n",
"The CUTLASS API is designed to make it easy for users to add their own kernel\n",
"so that it can be discovered and run under the uniform API. We welcome contributions\n",
"toward the API by \"bringing your own kernel.\"\n",
"\n",
"This example shows how to add a CuTe DSL kernel to the CUTLASS API.\n",
"\n",
"## Bring your own implementation\n",
"Individuals wishing to add a CuTe DSL kernel to the CUTLASS API likely already\n",
"have the kernel written in CuTe DSL, but have not yet implemented the API's needed\n",
"interface. Within the API, we separate these components into the \"implementation\" --\n",
"the kernel written in CuTe DSL -- and the \"interface\" -- the definition of methods\n",
"a kernel needs to be used within the CUTLASS API.\n",
"\n",
"For example, consider the following implementation of a simple FP64 GEMM kernel implementation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a64b0be",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"\n",
"import cuda.bindings.driver as cuda\n",
"\n",
"import cutlass\n",
"import cutlass.cute as cute\n",
"\n",
"\n",
"class F64GemmKernelImplementation:\n",
" def __init__(self, cta_tile_shape_mn: tuple[int, int]):\n",
" self.cta_tile_shape_mn = cta_tile_shape_mn\n",
"\n",
" @cute.jit\n",
" def __call__(\n",
" self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor, stream: cuda.CUstream\n",
" ):\n",
" l, m, n = out.shape\n",
" m_tiles = (m + self.cta_tile_shape_mn[0] - 1) // self.cta_tile_shape_mn[0]\n",
" n_tiles = (n + self.cta_tile_shape_mn[1] - 1) // self.cta_tile_shape_mn[1]\n",
"\n",
" grid = (m_tiles, n_tiles, l)\n",
" block = [self.cta_tile_shape_mn[0], self.cta_tile_shape_mn[1], 1]\n",
" self.kernel(a, b, out).launch(grid=grid, block=block, stream=stream)\n",
"\n",
" @cute.kernel\n",
" def kernel(self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor):\n",
" l, m, n = out.shape\n",
" k = a.shape[-1]\n",
" m_tile, n_tile, l_idx = cute.arch.block_idx()\n",
" tidx, tidy, _ = cute.arch.thread_idx()\n",
"\n",
" m_idx = m_tile * self.cta_tile_shape_mn[0] + tidx\n",
" n_idx = n_tile * self.cta_tile_shape_mn[1] + tidy\n",
"\n",
" if m_idx < m and n_idx < n:\n",
" out[l_idx, m_idx, n_idx] = cutlass.Float64(0)\n",
" for k_idx in range(k):\n",
" out[l_idx, m_idx, n_idx] += (\n",
" a[l_idx, m_idx, k_idx] * b[l_idx, k_idx, n_idx]\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "36a08d4b",
"metadata": {},
"source": [
"The implementation is configurable via a `cta_tile_shape_mn` argument, which\n",
"controls the size of blocks and tiles in the M and N modes. A simple `cute.jit` function\n",
"computes the grid and block size for the input problem based on `cta_tile_shape_mn`,\n",
"and launches the kernel. The `cute.kernel` itself simply has each thread compute a single\n",
"output element of the matrix by taking a dot product.\n",
"\n",
"This implementation is not performant, but is kept simple for illustrative purposes."
]
},
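{
"cell_type": "markdown",
"id": "a4e2c8f0",
"metadata": {},
"source": [
"The grid computation above is plain ceil-division: partial tiles are rounded up so every output element is covered. Extracted as standalone Python (for illustration only, no GPU required):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d3b7a19",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: the launcher's grid-size arithmetic.\n",
"def ceil_div(a: int, b: int) -> int:\n",
"    return (a + b - 1) // b\n",
"\n",
"def grid_for(m: int, n: int, l: int, cta_tile_shape_mn: tuple) -> tuple:\n",
"    return (ceil_div(m, cta_tile_shape_mn[0]), ceil_div(n, cta_tile_shape_mn[1]), l)\n",
"\n",
"# A 1024 x 1000 output with 128 x 128 tiles needs an 8 x 8 grid per batch\n",
"print(grid_for(1024, 1000, 2, (128, 128)))  # (8, 8, 2)"
]
},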
{
"cell_type": "markdown",
"id": "a5d0e661",
"metadata": {},
"source": [
"## Defining interface methods\n",
"As it currently stands, this GEMM kernel implementation cannot be used via the\n",
"CUTLASS API because it does not implement interface methods. Specifically, kernels\n",
"within the CUTLASS API must inherit from and implement the `cutlass_api.Kernel`\n",
"abstract class. This class has methods needed for many common operations\n",
"performed when compiling and executing DSL kernels.\n",
"\n",
"Certain providers (i.e., DSLs), such as CuTe DSL, provide an additional layer atop the\n",
"`cutlass_api.Kernel` class to add utilities for kernels being written\n",
"via that provider. For example, the CuTe DSL provider in the CUTLASS API\n",
"defines `cutlass_api.providers.cutedsl.kernel.CuteDslKernel`, which adds utilities surrounding\n",
"`cute.compile()` to add compile-time arguments needed for using TVM-FFI when\n",
"it is enabled.\n",
"\n",
"We will next walk through the steps in defining interface methods for this\n",
"implementation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2da869",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"import cutlass_api\n",
"from cutlass_api.arguments import GemmArguments\n",
"from cutlass_api.metadata import KernelMetadata\n",
"from cutlass_api.status import Status"
]
},
{
"cell_type": "markdown",
"id": "86ae75cc",
"metadata": {},
"source": [
"We begin by defining a class to represent the kernel's interface.\n",
"As mentioned above, since this is a CuTe DSL kernel, our interface must\n",
"inherit from and implement `cutlass_api.providers.cutedsl.kernel.CuteDslKernel`.\n",
"\n",
"The class must additionally be registered with the CuTe DSL provider\n",
"via the `@CuTeDSLProvider.register` decorator so that the class\n",
"can be considered when discovering kernels."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a86d138",
"metadata": {},
"outputs": [],
"source": [
"@cutlass_api.providers.cutedsl.CuTeDSLProvider.register\n",
"class F64GemmKernel(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):\n",
" # Empty versions of interface methods. These will be implemented later, interspersed\n",
" # with notebook markdown. Normally, one would define them inline with the class definition.\n",
" def __init__(self, metadata: KernelMetadata): pass\n",
"\n",
" def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None): pass\n",
"\n",
" def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact: pass\n",
"\n",
" @staticmethod\n",
" def generate_kernels(metadata_filter, epilogue_args=None, cc=None) -> list[\"F64GemmKernel\"]: pass\n",
"\n",
" def _supports(self, args: GemmArguments) -> Status: pass\n",
"\n",
" def get_workspace_size(self, args: GemmArguments) -> int: pass"
]
},
{
"cell_type": "markdown",
"id": "327e9e7c",
"metadata": {},
"source": [
"The `__init__` method of the class takes in a `KernelMetadata` object\n",
"from which it extracts the `cta_tile_shape_mn`. This is used to construct\n",
"the kernel implementation object. We will discuss later how the `KernelMetadata`\n",
"object passed in here is constructed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "785d1882",
"metadata": {},
"outputs": [],
"source": [
"def __init__(self, metadata: KernelMetadata):\n",
" self.metadata = metadata\n",
" cta_tile_shape_mn = metadata.design.tile_shape[:2]\n",
" self.impl = F64GemmKernelImplementation(cta_tile_shape_mn)"
]
},
{
"cell_type": "markdown",
"id": "500a0030",
"metadata": {},
"source": [
"### Defining interfaces for compilation and execution\n",
"The interfaces needed for compilation and execution are simple.\n",
"\n",
"The `compile` method simply constructs a placeholder stream object\n",
" "and passes it, along with the relevant arguments, to `self.cute_compile`. This\n",
" "is a utility defined in the `CuteDslKernel` abstract class that\n",
" "passes to `cute.compile` the compilation flags needed for certain options\n",
" "(e.g., TVM-FFI). The result is wrapped as a `CompiledArtifact`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63b4a129",
"metadata": {},
"outputs": [],
"source": [
"def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact:\n",
" stream = cutlass.cute.runtime.make_fake_stream()\n",
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
]
},
{
"cell_type": "markdown",
"id": "023127fd",
"metadata": {},
"source": [
"Users define the `_run` method rather than the top-level `run` method\n",
"(no leading underscore) used when interacting with kernels. `_run` (1) extracts from `args`\n",
"the arguments needed to run the JIT function, and (2) calls the JIT function\n",
"passed in via `artifact` with these arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ae7c009",
"metadata": {},
"outputs": [],
"source": [
"def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None):\n",
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
" compiled_gemm = artifact.compiled_obj\n",
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
]
},
{
"cell_type": "markdown",
"id": "4052e5a0",
"metadata": {},
"source": [
"Finally, since this kernel does not require any device workspace,\n",
"we give it a simple `get_workspace_size` method that always returns 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "968906ea",
"metadata": {},
"outputs": [],
"source": [
"def get_workspace_size(self, args: GemmArguments) -> int:\n",
" return 0"
]
},
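{
"cell_type": "markdown",
"id": "3f2a1c9d",
"metadata": {},
"source": [
"For contrast, a kernel that does require device scratch memory would size its\n",
"workspace from the problem shape. The helper below is a hypothetical sketch\n",
"(the per-tile accumulator scheme and the 8-byte `float64` element size are\n",
"assumptions for illustration, not part of this kernel):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b7e4d2a",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def workspace_bytes(out_shape, cta_tile_shape_mn, elem_bytes=8):\n",
"    # Hypothetical split-K-style scratch: one float64 accumulator copy per output tile\n",
"    L, M, N = out_shape\n",
"    tile_m, tile_n = cta_tile_shape_mn\n",
"    num_tiles = L * math.ceil(M / tile_m) * math.ceil(N / tile_n)\n",
"    return num_tiles * tile_m * tile_n * elem_bytes"
]
},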
{
"cell_type": "markdown",
"id": "e245a319",
"metadata": {},
"source": [
"### Defining interfaces for kernel generation\n",
"We have implemented the interfaces needed to construct, compile, and\n",
"run the kernel. We now must implement methods for\n",
"generating the possible configurations of this kernel that the kernel\n",
"class itself supports. These will be used in kernel discovery (e.g., via\n",
"`cutlass_api.get_kernels()`).\n",
"\n",
"To do so, we write the `generate_kernels` method. This takes in a\n",
"binary function `metadata_filter`, epilogue arguments `epilogue_args`,\n",
"and a compute capability `cc`. It returns a list of all instances\n",
"of the kernel interface that support the `epilogue_args`, are compatible\n",
"with the given `cc`, and which pass the `metadata_filter`.\n",
"\n",
"The `Kernel` class is responsible for defining which configurations (instances) of it can validly exist.\n",
"In this example, the valid configurations are the cross-product of row/column-major strides and two preset tile shapes.\n",
"We loop over these knobs and create a `KernelMetadata` for each unique configuration.\n",
"\n",
"The `generate_kernels` method must additionally filter the generated kernels through the `metadata_filter`.\n",
"This is a user-provided predicate over generated metadata combinations. More information on `metadata_filter` is provided in other examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47dc2f20",
"metadata": {},
"outputs": [],
"source": [
"@staticmethod\n",
"def generate_kernels(\n",
" metadata_filter: Callable[[KernelMetadata], bool],\n",
" epilogue_args: cutlass_api.arguments.EpilogueArguments = None,\n",
" cc: int = None,\n",
") -> list[\"F64GemmKernel\"]:\n",
"\n",
" # The tile shapes this kernel supports/exposes\n",
" supported_tile_shapes = [(32, 32, 1), (16, 16, 1)]\n",
"\n",
" if epilogue_args is not None:\n",
" return []\n",
"\n",
" row_major_stride = (0, 0, 1)\n",
" col_major_stride = (0, 1, 0)\n",
" stride_combos = list(itertools.product([row_major_stride, col_major_stride], repeat=3))\n",
" alignment = 1\n",
"\n",
" kernels = []\n",
" for tile_shape in supported_tile_shapes:\n",
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
" for stride_A, stride_B, stride_out in stride_combos:\n",
" # Create TensorAttributes for A, B, and out tensors\n",
" a_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_A, alignment)\n",
" b_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_B, alignment)\n",
" out_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_out, alignment)\n",
" layout_str = cutlass_api.utils.strides_to_layout_string(stride_A, stride_B, stride_out)\n",
"\n",
" name = f\"F64GemmKernel_tile{tile_shape[0]}x{tile_shape[1]}_{layout_str}\"\n",
"\n",
" metadata = KernelMetadata(\n",
" kernel_name=name,\n",
" kernel_class=F64GemmKernel,\n",
" operands=cutlass_api.metadata.GemmOperandsMetadata(\n",
" a_attrs, b_attrs, out_attrs, accumulator_type=cutlass.Float64\n",
" ),\n",
" design=design_metadata,\n",
" min_cc=0,\n",
" )\n",
"\n",
" if metadata_filter(metadata):\n",
" kernels.append(F64GemmKernel(metadata))\n",
"\n",
" return kernels"
]
},
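{
"cell_type": "markdown",
"id": "5c8d0e1f",
"metadata": {},
"source": [
"As a quick sanity check on the size of this design space (assuming no\n",
"`metadata_filter` rejections): two tile shapes crossed with row/column-major\n",
"choices for each of the three operands yields 2 × 2^3 = 16 configurations:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4f6b2c1",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"tile_shapes = [(32, 32, 1), (16, 16, 1)]\n",
"stride_combos = list(itertools.product([(0, 0, 1), (0, 1, 0)], repeat=3))\n",
"print(len(tile_shapes) * len(stride_combos))  # 16 candidate configurations"
]
},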
{
"cell_type": "markdown",
"id": "c7cdbc66",
"metadata": {},
"source": [
"We also add a method for indicating whether a kernel instance in question\n",
"supports a set of arguments. The top-level `Kernel.supports` method will\n",
"already verify that the `args` passed in match the metadata with which\n",
"this `Kernel` instance was constructed. Here, we define additional\n",
"checks specific to this kernel, such as requiring that all\n",
"operands be rank 3:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54067d47",
"metadata": {},
"outputs": [],
"source": [
"def _supports(self, args: GemmArguments) -> Status:\n",
" if not (\n",
" len(args.A.shape) == 3 and # A should be (L, M, K)\n",
" len(args.B.shape) == 3 and # B should be (L, K, N)\n",
" len(args.out.shape) == 3 # out should be (L, M, N)\n",
" ):\n",
" return Status.fail(\"All operands must be rank 3.\")\n",
" return Status.success()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "edaf2cba",
"metadata": {},
"outputs": [],
"source": [
"# Assign methods to the class because we interspersed notebook markdown\n",
"# with the class definition. This is not needed in a real implementation.\n",
"F64GemmKernel.__init__ = __init__\n",
"F64GemmKernel.compile = compile\n",
"F64GemmKernel._run = _run\n",
"F64GemmKernel._supports = _supports\n",
"F64GemmKernel.generate_kernels = generate_kernels\n",
"F64GemmKernel.get_workspace_size = get_workspace_size"
]
},
{
"cell_type": "markdown",
"id": "c8fc84e9",
"metadata": {},
"source": [
"## Discovering instances of the kernel and using them\n",
"The CUTLASS API is now prepared to discover instances of this\n",
"kernel interface just as was done in previous examples.\n",
"\n",
"The one small modification is that we use a `metadata_filter`\n",
"to ensure that all returned kernels are instances of the\n",
"`F64GemmKernel` class we just implemented. This is needed\n",
"only for example/testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cec5431d",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(2025)\n",
"\n",
"L, M, N, K = 1, 256, 1024, 128\n",
"A = torch.randn(L, M, K, device=\"cuda\", dtype=torch.float64)\n",
"B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float64)\n",
"out = torch.empty(L, M, N, device=\"cuda\", dtype=torch.float64)\n",
"\n",
"args = GemmArguments(A, B, out, accumulator_type=torch.float64)\n",
"\n",
"def is_f64gemm_kernel(metadata):\n",
" return metadata.kernel_class == F64GemmKernel\n",
"\n",
"kernels = cutlass_api.get_kernels(args, metadata_filter=is_f64gemm_kernel)"
]
},
{
"cell_type": "markdown",
"id": "50e81a7d",
"metadata": {},
"source": [
"We can print the names of the first few kernels to confirm that\n",
"they come from our recently added kernel."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdb92b5e",
"metadata": {},
"outputs": [],
"source": [
"print(kernels[0].metadata.kernel_name)\n",
"print(kernels[1].metadata.kernel_name)"
]
},
{
"cell_type": "markdown",
"id": "697ee3c3",
"metadata": {},
"source": [
"We can now run an instance of our kernel and check its correctness:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5486244",
"metadata": {},
"outputs": [],
"source": [
"kernels[0].run(args)\n",
"torch.testing.assert_close(out, A @ B)"
]
},
{
"cell_type": "markdown",
"id": "8de96f7e",
"metadata": {},
"source": [
"We can also test the limits of our kernel's design space by providing a\n",
"metadata filter that expects a CTA tile size M of 256, which is not exposed\n",
"in the `generate_kernels` method of our recently-added kernel. We expect\n",
"no kernels of type `F64GemmKernel` to be returned."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "917c74e3",
"metadata": {},
"outputs": [],
"source": [
"def my_filter(metadata):\n",
" return (\n",
" is_f64gemm_kernel(metadata) and\n",
" isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata) and\n",
" metadata.design.tile_shape[0] == 256\n",
" )\n",
"kernels_ctam256 = cutlass_api.get_kernels(args, metadata_filter=my_filter)\n",
"\n",
"# No kernels should be found\n",
"assert len(kernels_ctam256) == 0"
]
},
{
"cell_type": "markdown",
"id": "caa80a7d",
"metadata": {},
"source": [
"## A note on contributing kernels and directory structure\n",
"This example showed how to define a kernel inline and add it to the\n",
"API for demonstration purposes. A kernel doesn't necessarily need to live\n",
"within the API's source code.\n",
"\n",
"We welcome contributions of kernels that do live within the CUTLASS\n",
"API's repository as well.\n",
"\n",
"Kernels in the repository are organized based on the \"provider\" in which they are\n",
"authored (i.e., the DSL). All kernels corresponding to a given\n",
"provider live in a directory corresponding to that provider under\n",
"`cutlass_api/providers`. For example, CuTe DSL kernels live\n",
"under `cutlass_api/providers/cutedsl`.\n",
"\n",
"Each provider can organize kernels differently. For CuTe DSL,\n",
"kernels are further split based on their logical operation,\n",
"with GEMM kernels under the `cutlass_api/providers/cutedsl/gemm`\n",
"directory.\n",
"\n",
"We recommend separating the implementation of the kernel from\n",
"its interface not just by using separate classes, as done in\n",
"this example, but also by separating the implementation and\n",
"interface into separate files. This makes it easier to update\n",
"each without affecting the other.\n",
"\n",
"For example, CuTe DSL GEMM kernels have the following organization:\n",
"```text\n",
"cutlass_api/\n",
" providers/\n",
" cutedsl/\n",
" gemm/\n",
" sm100_static_persistent.py\n",
" implementations/\n",
" sm100_static_persistent_impl.py\n",
"```"
]
}
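,
{
"cell_type": "markdown",
"id": "b2d4f6a8",
"metadata": {},
"source": [
"As a hypothetical sketch of that split (class and module names here are\n",
"illustrative, not actual files in the repository), the interface class\n",
"imports and wraps the implementation class rather than defining it inline:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3e5a7b9",
"metadata": {},
"outputs": [],
"source": [
"# implementations/my_gemm_impl.py (hypothetical): device-side details live here\n",
"class MyGemmKernelImplementation:\n",
"    def __init__(self, cta_tile_shape_mn):\n",
"        self.cta_tile_shape_mn = cta_tile_shape_mn\n",
"\n",
"# my_gemm.py (hypothetical): the interface imports the implementation,\n",
"# e.g., `from .implementations.my_gemm_impl import MyGemmKernelImplementation`\n",
"class MyGemmKernel:\n",
"    def __init__(self, cta_tile_shape_mn):\n",
"        self.impl = MyGemmKernelImplementation(cta_tile_shape_mn)"
]
}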
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -0,0 +1,521 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "f97e61c9",
"metadata": {},
"source": [
"# Best practices for reducing host-side latency"
]
},
{
"cell_type": "markdown",
"id": "a7a9c63c",
"metadata": {},
"source": [
"Overall performance depends on both device performance (i.e., that of the kernel) and host performance (i.e., that of the runtime).\n",
"This notebook focuses on the latter: techniques to minimize any overheads incurred from the CUTLASS API and underlying\n",
"DSL runtimes.\n",
"\n",
"This notebook does not discuss techniques for improving device-side performance. A future notebook may cover this topic."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e3ca9e40",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import torch\n",
"import cutlass_api"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "efaac09c",
"metadata": {},
"outputs": [],
"source": [
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
" import sys\n",
"\n",
" sys.exit(0)"
]
},
{
"cell_type": "markdown",
"id": "40de11ce",
"metadata": {},
"source": [
"We start with boilerplate initial setup to create tensors and pick a kernel.\n",
"\n",
"For the purposes of this notebook, we use a very small GEMM size of M=N=K=128\n",
"and L=1. This small size is chosen to magnify the impact of host latency on\n",
"end-to-end performance so as to better illustrate the effect of the techniques\n",
"described below."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b8c44947",
"metadata": {},
"outputs": [],
"source": [
"warmup_iterations = 10\n",
"profiling_iterations = 100\n",
"total_iterations = warmup_iterations + profiling_iterations\n",
"\n",
"# Use a small problem size to showcase host overheads\n",
"L, M, N, K = 1, 128, 128, 128\n",
"\n",
"# We use different operands in each iteration. Though not particularly relevant for\n",
"# host latency, this is a best practice when benchmarking GPU kernels to avoid\n",
"# unrealistic caching effects.\n",
"As = [torch.randint(-1, 2, (M, K), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"Bs = [torch.randint(-1, 2, (K, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"outs = [torch.empty((M, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
"\n",
"# Construct arguments outside of the benchmarking loop. We will later also consider\n",
"# cases in which they are constructed inside the benchmarking loop.\n",
"args = [cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16) for i in range(total_iterations)]\n",
"\n",
"references = [(As[i] @ Bs[i]).to(outs[i].dtype) for i in range(total_iterations)]\n",
"\n",
"kernels = cutlass_api.get_kernels(args[0], cc=100)\n",
"\n",
"assert len(kernels) > 0\n",
"kernel = kernels[0]"
]
},
{
"cell_type": "markdown",
"id": "f2e7eece",
"metadata": {},
"source": [
"We next set up a basic benchmarking routine."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2472eafa",
"metadata": {},
"outputs": [],
"source": [
"def benchmark(label, code, warmup_it=warmup_iterations, profiling_it=profiling_iterations):\n",
" total_it = warmup_it + profiling_it\n",
" assert total_it <= total_iterations, f\"Benchmark-local iteration count must be less than or equal to total iterations: {total_it} > {total_iterations}\"\n",
" # warmup\n",
" rets = [None] * total_it\n",
" for i in range(warmup_it):\n",
" rets[i] = code(i)\n",
" torch.cuda.synchronize()\n",
"\n",
" start = time.time()\n",
" for i in range(profiling_it):\n",
" idx = warmup_it + i\n",
" rets[idx] = code(idx)\n",
" torch.cuda.synchronize()\n",
" end = time.time()\n",
"\n",
" avg_time = (end - start) / profiling_it\n",
" print(f\"[{label:<30}] avg of {profiling_it} iterations: {avg_time:1.3e} seconds\")\n",
" return avg_time, rets"
]
},
{
"cell_type": "markdown",
"id": "4909a76b",
"metadata": {},
"source": [
"We now describe techniques for reducing host latency:\n",
"* Compile once, run many times\n",
"* Bypassing checks for argument-kernel compatibility\n",
"* Using [CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/)\n",
"* Using [TVM FFI](https://tvm.apache.org/ffi/)\n",
"\n",
"These techniques are complementary and should be used together whenever\n",
"an application permits."
]
},
{
"cell_type": "markdown",
"id": "06495033",
"metadata": {},
"source": [
"### Compile once, run many times\n",
"The `kernel.run` method takes in an optional `compiled_artifact` argument of type\n",
"`cutlass_api.artifact.CompiledArtifact`. When this argument is set, the kernel\n",
"will directly use the precompiled function within `compiled_artifact`. When\n",
"it is not set, the call to `kernel.run` will JIT compile the kernel on each\n",
"invocation.\n",
"\n",
"Precompiling the kernel is critical to achieving good performance."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6de11f56",
"metadata": {},
"outputs": [],
"source": [
"stream = torch.cuda.current_stream()\n",
"def no_compiled_artifact(i: int):\n",
" return kernel.run(args[i], stream=stream)\n",
"\n",
"# Compile the kernel once and reuse it for each iteration\n",
"compiled_artifact = kernel.compile(args[0])\n",
"\n",
"def with_compiled_artifact(i: int):\n",
" return kernel.run(args[i], stream=stream, compiled_artifact=compiled_artifact)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "350c9bd6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Without compiled artifact ] avg of 5 iterations: 1.376e+00 seconds\n",
"[With compiled artifact ] avg of 5 iterations: 1.016e-05 seconds\n"
]
}
],
"source": [
"time_no_artifact, _ = benchmark(f\"Without compiled artifact\", no_compiled_artifact, warmup_it=2, profiling_it=5)\n",
"time_w_artifact, _ = benchmark(f\"With compiled artifact\", with_compiled_artifact, warmup_it=2, profiling_it=5)"
]
},
{
"cell_type": "markdown",
"id": "5cfbc2d2",
"metadata": {},
"source": [
"### Bypassing checks for argument-kernel compatibility\n",
"By default, the call to `kernel.run` will check if the kernel supports the provided arguments.\n",
"Under the hood, this invokes `kernel.supports(args)`.\n",
"\n",
"While these checks are helpful for catching incompatible arguments, they are performed\n",
"in Python, and thus can add to host overhead.\n",
"\n",
"When confident that arguments will be compatible with a kernel, one should bypass\n",
"the `supports` check in `kernel.run` by setting the optional `assume_supported_args`\n",
"argument to `True`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5b93dfae",
"metadata": {},
"outputs": [],
"source": [
"def with_supports_check(i: int):\n",
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=False)\n",
"\n",
"def without_supports_check(i: int):\n",
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b282f437",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[With supports check ] avg of 100 iterations: 1.463e-05 seconds\n",
"[Bypass supports check ] avg of 100 iterations: 6.239e-06 seconds\n",
"Speedup with skip supports: 2.34x\n"
]
}
],
"source": [
"time_w_supports, _ = benchmark(\"With supports check\", with_supports_check)\n",
"time_wo_supports, _ = benchmark(\"Bypass supports check\", without_supports_check)\n",
"print(f\"Speedup with skip supports: {time_w_supports / time_wo_supports:.2f}x\")"
]
},
{
"cell_type": "markdown",
"id": "d74cb3e7",
"metadata": {},
"source": [
"### CUDA Graphs"
]
},
{
"cell_type": "markdown",
"id": "656d5e2c",
"metadata": {},
"source": [
"The CUTLASS API supports [CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/) with PyTorch in the usual way.\n",
"\n",
"Kernel compilation must happen outside the CUDA graph. We then create a graph using the usual PyTorch idioms and launch the kernel several times on the graph's stream."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e614509f",
"metadata": {},
"outputs": [],
"source": [
"num_launches = 20\n",
"\n",
"# Create a CUDA Graph to run our compiled kernel N times\n",
"g = torch.cuda.CUDAGraph()\n",
"with torch.cuda.graph(g):\n",
" # Run N iterations of our compiled kernel on the current stream\n",
" for i in range(num_launches):\n",
" kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" stream=torch.cuda.current_stream(),\n",
" assume_supported_args=True,\n",
" )\n",
"\n",
"# Zero the output so we don't refcheck stale results\n",
"_ = outs[0].zero_()"
]
},
{
"cell_type": "markdown",
"id": "8fc69c6e",
"metadata": {},
"source": [
"Once captured, we can replay the graph. This replays only the kernel launches placed on the CUDA stream.\n",
"Any other preparatory host-side work and the arguments passed in from Python are cached during capture."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d9c5d5c5",
"metadata": {},
"outputs": [],
"source": [
"# Replay captured graph and check first result\n",
"g.replay()\n",
"\n",
"torch.testing.assert_close(outs[0], references[0])"
]
},
{
"cell_type": "markdown",
"id": "388c8e02",
"metadata": {},
"source": [
"Let's compare the timing:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "45d4e739",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[20 launches without CUDA Graph] avg of 1 iterations: 4.699e-04 seconds\n",
"[20 launches with CUDA Graph ] avg of 1 iterations: 9.084e-05 seconds\n",
"Speedup with CUDA Graph: 5.17x\n"
]
}
],
"source": [
"def without_cuda_graph(x: int):\n",
" for i in range(num_launches):\n",
" kernel.run(\n",
" args[i],\n",
" compiled_artifact=compiled_artifact,\n",
" stream=torch.cuda.current_stream(),\n",
" assume_supported_args=True,\n",
" )\n",
"\n",
"def with_cuda_graph(x: int):\n",
" g.replay()\n",
"\n",
"\n",
"time_wo_cuda_graph, _ = benchmark(f\"{num_launches} launches without CUDA Graph\", without_cuda_graph, warmup_it=0, profiling_it=1)\n",
"time_w_cuda_graph, _ = benchmark(f\"{num_launches} launches with CUDA Graph\", with_cuda_graph, warmup_it=0, profiling_it=1)\n",
"\n",
"print(f\"Speedup with CUDA Graph: {time_wo_cuda_graph / time_w_cuda_graph:.2f}x\")"
]
},
{
"cell_type": "markdown",
"id": "fe5c3168",
"metadata": {},
"source": [
"### TVM FFI"
]
},
{
"cell_type": "markdown",
"id": "ee7f9fd2",
"metadata": {},
"source": [
"When applicable, CUTLASS API uses [Apache TVM FFI](https://tvm.apache.org/ffi/) under the hood for invoking compiled DSL kernels from Python.\n",
"Apache TVM FFI is an open ABI and FFI for machine learning systems.\n",
"\n",
"TVM FFI is enabled by default in CUTLASS API, and is recommended for best performance."
]
},
{
"cell_type": "markdown",
"id": "1690bbed",
"metadata": {},
"source": [
"`cutlass_api.config.GlobalOptions().use_tvm_ffi` controls whether TVM-FFI is used by the CUTLASS API."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "993c60ae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"print(cutlass_api.config.GlobalOptions().use_tvm_ffi)"
]
},
{
"cell_type": "markdown",
"id": "00ed9a40",
"metadata": {},
"source": [
"If for some reason you do not wish to use it, simply set this option to `False`; no other change is needed. The code below compares performance with and without TVM-FFI."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e8f56be3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[TVM-FFI ON ] Create args ] avg of 100 iterations: 8.367e-05 seconds\n",
"[[TVM-FFI ON ] Compile kernel ] avg of 5 iterations: 1.352e+00 seconds\n",
"[[TVM-FFI ON ] Run kernel ] avg of 100 iterations: 6.509e-06 seconds\n"
]
}
],
"source": [
"cutlass_api.config.GlobalOptions().use_tvm_ffi = True\n",
"\n",
"def run_iteration(i):\n",
" args = cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
" return kernel.run(\n",
" args,\n",
" compiled_artifact=compiled_artifact,\n",
" stream=torch.cuda.current_stream(),\n",
" assume_supported_args=True,\n",
" )\n",
"\n",
"def create_arguments(i: int):\n",
" return cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
"\n",
"args_creation_on, args = benchmark(\"[TVM-FFI ON ] Create args\", create_arguments)\n",
"compilation_on, compiled = benchmark(\"[TVM-FFI ON ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
"compiled_artifact = compiled[0]\n",
"run_on, _ = benchmark(\"[TVM-FFI ON ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5a4c2db4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[TVM-FFI OFF ] Create args ] avg of 100 iterations: 1.255e-04 seconds\n",
"[[TVM-FFI OFF ] Compile kernel ] avg of 5 iterations: 1.278e+00 seconds\n",
"[[TVM-FFI OFF ] Run kernel ] avg of 100 iterations: 4.519e-05 seconds\n"
]
}
],
"source": [
"cutlass_api.config.GlobalOptions().use_tvm_ffi = False\n",
"args_creation_off, args = benchmark(\"[TVM-FFI OFF ] Create args\", create_arguments)\n",
"compilation_off, compiled = benchmark(\"[TVM-FFI OFF ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
"compiled_artifact = compiled[0]\n",
"run_off, _ = benchmark(\"[TVM-FFI OFF ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "17b43718",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Speedups with TVM-FFI: \n",
"Arg creation: 1.50x\n",
"Compilation: 0.95x\n",
"Run: 6.94x\n"
]
}
],
"source": [
"print(\"Speedups with TVM-FFI: \")\n",
"print(f\"Arg creation: {args_creation_off / args_creation_on:.2f}x\")\n",
"print(f\"Compilation: {compilation_off / compilation_on:.2f}x\")\n",
"print(f\"Run: {run_off / run_on:.2f}x\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -0,0 +1,76 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "cutlass_api"
dynamic = ["version"]
requires-python = ">=3.12"
dependencies = [
"nvidia-cutlass-dsl",
"apache-tvm-ffi", # Required for optimal performance. To opt-out: uninstall or set cutlass_api.config.GlobalOptions().use_tvm_ffi=False
]
[project.optional-dependencies]
torch = [
"torch",
"torch-c-dlpack-ext",
]
test = [
"jupyter",
"pytest",
"cutlass_api[torch]",
]
dev = [
"cutlass_api[test]",
]
[tool.setuptools.packages.find]
where = ["."]
[tool.setuptools.dynamic]
version = {attr = "cutlass_api.__version__"}
[tool.ruff]
target-version = "py312"
src = ["cutlass_api"]
# Exclude fusion/ from all ruff checks
extend-exclude = [
"cutlass_api/fusion",
]
[tool.ruff.lint]
select = [
"F", # pyflakes
"E", # pycodestyle errors
"W", # pycodestyle warnings
"I", # isort
"UP", # pyupgrade
"B", # flake8-bugbear
"SIM", # flake8-simplify
"C4", # flake8-comprehensions
"PTH", # flake8-use-pathlib
"FA", # flake8-future-annotations
"PERF", # perflint - performance anti-patterns
"FURB", # refurb - modern Python idioms
"PIE", # flake8-pie - misc improvements
"TCH", # flake8-type-checking
"PLE", # pylint errors
"RSE", # flake8-raise
]
ignore = [
"SIM103", # needless-bool - can be noisy
"B905", # zip(strict=) - too strict
]
[tool.ruff.lint.isort]
section-order = ["future", "standard-library", "third-party", "cutlass", "first-party", "local-folder"]
known-first-party = ["cutlass_api"]
[tool.ruff.lint.isort.sections]
cutlass = ["cutlass"]


@@ -0,0 +1,166 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import pytest
import torch
from torch.cuda import current_stream
import cutlass_api
from cutlass_api.utils import is_device_cc_supported
@pytest.mark.parametrize(
"M, N, K",
[
(256, 512, 1024),
],
)
@pytest.mark.parametrize(
"ab_dtype, c_dtype, accumulator_type",
[
(torch.float16, torch.float16, torch.float16),
],
)
@pytest.mark.parametrize(
"n_iterations",
[
20,
],
)
@pytest.mark.skipif(
not is_device_cc_supported({100})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_sm100(
M: int,
N: int,
K: int,
ab_dtype: torch.dtype,
c_dtype: torch.dtype,
accumulator_type: torch.dtype,
n_iterations: int,
):
A = torch.randint(-1, 2, (M, K), device="cuda").to(ab_dtype)
B = torch.randint(-1, 2, (K, N), device="cuda").to(ab_dtype)
D = torch.randint(-1, 2, (M, N), device="cuda").to(c_dtype)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernel = kernels[0]
"""
Compile the kernel and capture CUDA Graph.
The kernel needs to be compiled outside the CUDA graph.
"""
assert kernel.supports(args)
compiled_artifact = kernel.compile(args)
stream = torch.cuda.Stream()
# Create a CUDA Graph to run our compiled kernel N times
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
# Run N iterations of our compiled kernel on the current stream
for _ in range(n_iterations):
kernel.run(
args,
compiled_artifact=compiled_artifact,
stream=current_stream(),
assume_supported_args=True,
)
# Zero the output so we don't refcheck stale results
D.zero_()
# Replay captured graph & check result
g.replay()
torch.cuda.synchronize()
reference = A @ B
assert torch.allclose(D, reference.to(D.dtype)), "Refcheck failed!"
"""
Run with & without graph capture to compare overhead
"""
# Create CUDA events for measuring performance
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
# Warmup the GPU
for _ in range(n_iterations):
kernel.run(
args,
compiled_artifact=compiled_artifact,
stream=stream,
assume_supported_args=True,
)
# Run without CUDA graph and time it
start.record()
for _ in range(n_iterations):
kernel.run(
args,
compiled_artifact=compiled_artifact,
stream=stream,
assume_supported_args=True,
)
end.record()
torch.cuda.synchronize()
without_graph_time = start.elapsed_time(end)
# Warmup again
for _ in range(n_iterations):
kernel.run(
args,
compiled_artifact=compiled_artifact,
stream=stream,
assume_supported_args=True,
)
# Run with CUDA graph and time it
start.record()
g.replay()
end.record()
torch.cuda.synchronize()
with_graph_time = start.elapsed_time(end)
percent_speedup = (without_graph_time - with_graph_time) / with_graph_time
print("-" * 80)
print(f"Number of launches     : {n_iterations}")
print(f"Time without CUDA graph: {without_graph_time:.2f} ms")
print(f"Time with CUDA graph   : {with_graph_time:.2f} ms")
print(f"Speedup                : {percent_speedup * 100.0:.2f}%")
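The zero-then-replay check above works because a captured graph re-executes the recorded work against the same buffers on every replay. A toy capture/replay model (purely illustrative Python, not the CUDA Graph API) shows the same contract:

```python
class ToyGraph:
    """Records calls during capture and re-executes them on replay."""

    def __init__(self):
        self._ops = []

    def capture(self, fn, *args):
        # During capture, the work is both executed and recorded.
        self._ops.append((fn, args))
        fn(*args)

    def replay(self):
        # Replay re-runs the recorded work against the same buffers.
        for fn, args in self._ops:
            fn(*args)


buf = [0.0]

def accumulate(out, value):
    out[0] += value

g = ToyGraph()
g.capture(accumulate, buf, 2.5)

buf[0] = 0.0   # zero the output, as D.zero_() does before g.replay()
g.replay()
print(buf[0])  # -> 2.5
```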


@@ -0,0 +1,69 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
import torch
import cutlass_api
@pytest.mark.parametrize(
"M, N",
[
(256, 512),
(1024, 8192),
],
)
@pytest.mark.parametrize(
"dtype",
[
torch.float32,
torch.float16,
],
)
def test_elementwise_add(
M: int,
N: int,
dtype: torch.dtype,
):
A = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
B = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
D = torch.empty((M, N), device="cuda", dtype=dtype)
args = cutlass_api.arguments.ElementwiseArguments(A=A, B=B, out=D)
kernels = cutlass_api.get_kernels(args)
assert len(kernels) > 0
kernel = kernels[0]
kernel.run(args)
reference = A + B
assert torch.allclose(D, reference)


@@ -0,0 +1,211 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import logging
import os
from importlib.util import find_spec
from pprint import pformat
import pytest
import torch
import cutlass
import cutlass_api
from cutlass_api.config import GlobalOptions
from cutlass_api.utils import is_device_cc_supported
torch.manual_seed(2025)
logger = logging.getLogger(__name__)
@pytest.mark.parametrize(
"M, N, K, L",
[
(256, 512, 1024, 1),
(256, 512, 64, 1),
(256, 512, 64, 2),
],
)
@pytest.mark.parametrize(
"ab_dtype, c_dtype, accumulator_type",
[
(torch.float16, torch.float32, torch.float32),
(torch.float16, torch.float16, torch.float16),
(torch.bfloat16, torch.bfloat16, torch.float32),
],
)
@pytest.mark.parametrize(
"use_tvm_ffi",
[True, False],
)
@pytest.mark.skipif(
not is_device_cc_supported({100})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_sm100(
M: int,
N: int,
K: int,
L: int,
ab_dtype: torch.dtype,
c_dtype: torch.dtype,
accumulator_type: torch.dtype,
use_tvm_ffi: bool,
):
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
GlobalOptions().use_tvm_ffi = use_tvm_ffi
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernel = kernels[0]
logger.debug(f"Picked kernel: {kernel.metadata.kernel_name}")
logger.debug(f"Kernel metadata:\n{pformat(kernel.metadata)}")
kernel.run(args)
reference = A @ B
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.skipif(
not is_device_cc_supported({100})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a"]),
reason="Requires compute capability 100 and to be compiled with sm_100a",
)
def test_gemm_sm100_int8():
M, N, K = 256, 512, 128
A = torch.randint(-1, 2, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(-1, 2, (K, N), device="cuda", dtype=torch.int8)
D = torch.empty((M, N), device="cuda", dtype=torch.int32)
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.int32)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernel = kernels[0]
assert kernel.supports(args)
compiled_artifact = kernel.compile(args)
kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)
reference = torch._int_mm(A, B)
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.skipif(
find_spec("tvm_ffi") is None,
reason="FP8 currently requires TVM FFI to be installed",
)
@pytest.mark.skipif(
not is_device_cc_supported({100})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a"]),
reason="Requires compute capability 100 and to be compiled with sm_100a",
)
def test_gemm_sm100_fp8():
# Currently, FP8 is only supported via TVM FFI.
GlobalOptions().use_tvm_ffi = True
M, N, K = 256, 512, 128
# Create torch fp8 tensors for A and B
A = torch.randint(-1, 2, (M, K), device="cuda").to(torch.float8_e4m3fn)
D = torch.empty((M, N), device="cuda", dtype=torch.float32)
# Transpose B because torch._scaled_mm expects B to be column-major
B = torch.randint(-1, 2, (N, K), device="cuda").to(torch.float8_e4m3fn).T
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float32)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernel = kernels[0]
assert kernel.supports(args)
compiled_artifact = kernel.compile(args)
kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)
identity_scale = torch.ones(1, device="cuda", dtype=torch.float32)
reference = torch._scaled_mm(
A, B, identity_scale, identity_scale, out_dtype=torch.float32
)
torch.testing.assert_close(D, reference)
def test_no_gemms_available():
M = N = K = 128
L = 1
A = torch.empty((L, M, K)).to(torch.float32)
B = torch.empty((L, K, N)).to(torch.float32)
D = torch.empty((L, M, N)).to(torch.float32)
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float32)
kernels = cutlass_api.get_kernels(args, cc=70)
# There are currently no kernels available for compute capability 70.
assert len(kernels) == 0
@pytest.mark.skipif(
not is_device_cc_supported({100}),
reason="Requires compute capability 100",
)
def test_metadata_filter():
# Test supplying metadata filter only
def tile_size_m_filter(metadata: cutlass_api.metadata.KernelMetadata) -> bool:
if not isinstance(metadata.design, cutlass_api.metadata.Sm100DesignMetadata):
return False
return metadata.design.tile_shape[0] == 64
kernels = cutlass_api.get_kernels(cc=100, metadata_filter=tile_size_m_filter)
for kernel in kernels:
assert kernel.metadata.design.tile_shape[0] == 64, (
f"Kernel {kernel.metadata.kernel_name} has tile shape {kernel.metadata.design.tile_shape}"
)
# Test supplying metadata filter and arguments
A = torch.randint(-1, 2, (1, 256, 256), device="cuda").to(torch.float16)
B = torch.randint(-1, 2, (1, 256, 256), device="cuda").to(torch.float16)
D = torch.empty((1, 256, 256), device="cuda").to(torch.float16)
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float16)
kernels = cutlass_api.get_kernels(
args=args, cc=100, metadata_filter=tile_size_m_filter
)
for kernel in kernels:
assert kernel.metadata.design.tile_shape[0] == 64, (
f"Kernel {kernel.metadata.kernel_name} has tile shape {kernel.metadata.design.tile_shape}"
)
assert kernel.metadata.operands.A.dtype == cutlass.Float16
assert kernel.metadata.operands.B.dtype == cutlass.Float16
assert kernel.metadata.operands.out.dtype == cutlass.Float16
assert kernel.metadata.operands.accumulator_type == cutlass.Float16
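The predicate passed as `metadata_filter` above simply maps kernel metadata to a bool. A standalone sketch of that contract, using stand-in metadata types (`FakeDesign` and `FakeMetadata` are hypothetical illustrations, not cutlass_api classes), is:

```python
from dataclasses import dataclass

@dataclass
class FakeDesign:
    tile_shape: tuple

@dataclass
class FakeMetadata:
    kernel_name: str
    design: FakeDesign

def tile_size_m_filter(metadata: FakeMetadata) -> bool:
    # Keep only kernels whose M tile size is 64, mirroring the test above.
    return metadata.design.tile_shape[0] == 64

catalog = [
    FakeMetadata("gemm_64x128", FakeDesign((64, 128))),
    FakeMetadata("gemm_128x128", FakeDesign((128, 128))),
]
selected = [md.kernel_name for md in catalog if tile_size_m_filter(md)]
print(selected)  # -> ['gemm_64x128']
```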


@@ -0,0 +1,833 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import pytest
import torch
import cutlass_api
from cutlass_api.utils import device_cc
from cutlass_api.config import GlobalOptions
torch.manual_seed(2025)
def problem_sizes():
"""
Problem sizes for tests
"""
return [
(256, 512, 1024, 1),
(256, 512, 128, 1),
(256, 512, 128, 2),
]
def base_data_types():
"""
Data types for (ab, c, d, accumulator)
"""
return [
(torch.float16, torch.float32, torch.float32, torch.float32),
(torch.float16, torch.float16, torch.float16, torch.float16),
(torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32),
]
def supports_sm100af():
return device_cc() == 100 and (
os.getenv("CUTE_DSL_ARCH", "") in ["", "sm_100a", "sm_100f"]
)
# Unary operation strings and functions
identity = ("", lambda x: x)
relu = ("relu", torch.relu)
tanh = ("tanh", torch.tanh)
sigmoid = ("sigmoid", torch.sigmoid)
exp = ("exp", torch.exp)
unary_ops = [identity, relu, tanh, sigmoid, exp]
# Binary operation strings and functions
add = (lambda a, b: f"{a} + {b}", lambda a, b: a + b)
sub = (lambda a, b: f"{a} - {b}", lambda a, b: a - b)
mul = (lambda a, b: f"{a} * {b}", lambda a, b: a * b)
# Don't include divide in main binary ops due to issues with division by zero in refchecks
binary_ops = [add, sub, mul]
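Each tuple above pairs a string form of an op with the matching callable, so a test can build the `epi_str` source and the reference function from one description. A standalone sketch of the same construction, composing `relu` with `add` (scalar stand-ins shown here instead of torch ops):

```python
# (name, callable) pair for a unary op; scalar stand-in for torch.relu
relu = ("relu", lambda x: max(x, 0.0))
# (formatter, callable) pair for a binary op, as in the tuples above
add = (lambda a, b: f"{a} + {b}", lambda a, b: a + b)

unary_str, unary_op = relu
binary_str, binary_op = add

# Same construction as test_gemm_fusion_unary_binary_composition
epi_str = f"def epi(accum, C): z = {unary_str}(accum); D = {binary_str('z', 'C')}; return D"
print(epi_str)
# -> def epi(accum, C): z = relu(accum); D = z + C; return D

# The callables give the reference result for the same expression
assert binary_op(unary_op(-3.0), 5.0) == 5.0
```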
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
@pytest.mark.parametrize(
"ab_dtype, d_dtype, accumulator_type",
[(torch.float16, torch.float16, torch.float16)],
)
@pytest.mark.parametrize("unary_str, unary_op", unary_ops)
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_unary(
M, N, K, L, ab_dtype, d_dtype, accumulator_type, unary_str, unary_op
):
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
def epi(accum):
D = unary_op(accum)
return D
epi_str = f"def epi(accum): D = {unary_str}(accum); return D"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(A @ B)
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.parametrize("M, N, K, L", [(256, 512, 128, 2)])
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
@pytest.mark.parametrize(
"ab_dtype, d_dtype, accumulator_type",
[(torch.float16, torch.float16, torch.float16)],
)
@pytest.mark.parametrize("unary_str, unary_op", [relu])
@pytest.mark.parametrize("unary_str2, unary_op2", [sigmoid, tanh])
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_unary_composition(
M,
N,
K,
L,
ab_dtype,
d_dtype,
accumulator_type,
unary_str,
unary_op,
unary_str2,
unary_op2,
):
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
def epi(accum):
D = unary_op2(unary_op(accum))
return D
epi_str = f"def epi(accum): D = {unary_str2}({unary_str}(accum)); return D"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(A @ B)
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
@pytest.mark.parametrize(
"ab_dtype, d_dtype, accumulator_type",
[(torch.float16, torch.float16, torch.float16)],
)
# Restrict unary to identity and relu to avoid rounding errors
@pytest.mark.parametrize("unary_str, unary_op", [identity, relu])
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_unary_literal(
M, N, K, L, ab_dtype, d_dtype, accumulator_type, unary_str, unary_op
):
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
def epi(accum):
D = unary_op(accum) * 3.0 - 1.234
return D
epi_str = f"def epi(accum): D = {unary_str}(accum) * 3.0 - 1.234; return D"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(A @ B)
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
@pytest.mark.parametrize(
"ab_dtype, c_dtype, d_dtype, accumulator_type", base_data_types()
)
@pytest.mark.parametrize("unary_str, unary_op", [identity, relu])
@pytest.mark.parametrize("binary_str, binary_op", binary_ops)
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_unary_binary_composition(
M,
N,
K,
L,
ab_dtype,
c_dtype,
d_dtype,
accumulator_type,
unary_str,
unary_op,
binary_str,
binary_op,
):
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
def epi(accum, C):
z = unary_op(accum)
D = binary_op(z, C)
return D
epi_str = f"def epi(accum, C): z = {unary_str}(accum); D = {binary_str('z', 'C')}; return D"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C=C, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(A @ B, C)
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
@pytest.mark.parametrize(
"ab_dtype, d_dtype, accumulator_type",
[(torch.float16, torch.float16, torch.float16)],
)
@pytest.mark.parametrize("c0_dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("c1_dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("binary_str0, binary_op0", [add, sub])
@pytest.mark.parametrize("binary_str1, binary_op1", [add, sub])
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_binary_binary_composition(
M,
N,
K,
L,
ab_dtype,
d_dtype,
accumulator_type,
c0_dtype,
c1_dtype,
binary_str0,
binary_op0,
binary_str1,
binary_op1,
):
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
C0 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c0_dtype)
C1 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c1_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
def epi(accum, C0, C1):
z = torch.relu(accum)
z1 = binary_op0(z, C0)
D = binary_op1(z1, C1)
return D
epi_str = f"def epi(accum, C0, C1): z = relu(accum); z1 = {binary_str0('z', 'C0')}; D = {binary_str1('z1', 'C1')}; return D"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C0=C0, C1=C1, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(A @ B, C0, C1)
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_division():
M, N, K, L = 256, 512, 128, 2
ab_dtype = torch.float16
d_dtype = torch.float16
accumulator_type = torch.float16
# Specifically initialize A and B with ones to avoid division by zero in refchecks
A = torch.ones((L, M, K), device="cuda", dtype=ab_dtype)
B = torch.ones((L, K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
#########################################################
# Test division by a literal
#########################################################
def epi(accum):
D = accum / 2.0
return D
epi_str = "def epi(accum): D = accum / 2.0; return D"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(A @ B)
torch.testing.assert_close(D, reference.to(D.dtype))
#########################################################
# Test division by an input
#########################################################
def epi(accum, scalar):
D = accum / scalar
return D
epi_str = "def epi(accum, scalar): D = accum / scalar; return D"
scalar = 4.0
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, scalar=scalar, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(A @ B, scalar)
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
@pytest.mark.parametrize(
"ab_dtype, d_dtype, accumulator_type",
[(torch.float16, torch.float16, torch.float16)],
)
@pytest.mark.parametrize("unary_str, unary_op", [sigmoid, tanh])
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_unary_multi_output(
M, N, K, L, ab_dtype, d_dtype, accumulator_type, unary_str, unary_op
):
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
z = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
def epi(accum):
z0 = torch.relu(accum)
z = unary_op(z0)
D = z + z0
return D, z
epi_str = f"def epi(accum): z0 = relu(accum); z = {unary_str}(z0); D = z + z0; return D, z"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, z=z, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
ref_D, ref_z = epi(A @ B)
torch.testing.assert_close(D, ref_D.to(D.dtype))
torch.testing.assert_close(z, ref_z.to(z.dtype))
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
@pytest.mark.parametrize(
"ab_dtype, d_dtype, accumulator_type",
[(torch.float16, torch.float16, torch.float16)],
)
@pytest.mark.parametrize("c_dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("binary_str, binary_op", binary_ops)
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_binary_multi_output(
M, N, K, L, ab_dtype, d_dtype, accumulator_type, c_dtype, binary_str, binary_op
):
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
z = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
def epi(accum, C):
z0 = torch.relu(accum)
z = binary_op(z0, C)
D = z + z0
return D, z
epi_str = f"def epi(accum, C): z0 = relu(accum); z = {binary_str('z0', 'C')}; D = z + z0; return D, z"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C=C, z=z, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
ref_D, ref_z = epi(A @ B, C)
torch.testing.assert_close(D, ref_D.to(D.dtype))
torch.testing.assert_close(z, ref_z.to(z.dtype))
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_return_acc():
M, N, K, L = 256, 512, 128, 2
ab_dtype = torch.float16
c_dtype = torch.float32
d_dtype = torch.float16
accumulator_type = torch.float16
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
accum = torch.empty((L, M, N), device="cuda", dtype=accumulator_type)
def epi(accum, C):
D = torch.relu(accum) + C
return D, accum
epi_str = "def epi(accum, C): D = relu(accum) + C; return D, accum"
epi_args = cutlass_api.arguments.EpilogueArguments(
epi_str, C=C, D=D, accum=accum
)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference, ref_accum = epi(A @ B, C)
torch.testing.assert_close(D, reference.to(D.dtype))
torch.testing.assert_close(accum, ref_accum.to(accum.dtype))
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_acc_as_multiple_input():
M, N, K, L = 256, 512, 128, 2
ab_dtype = torch.float16
c_dtype = torch.float32
d_dtype = torch.float16
accumulator_type = torch.float16
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
#########################################################
# Test binary op inside
#########################################################
def epi(accum, C):
D = torch.relu(torch.relu(accum) * C) + accum
return D
epi_str = "def epi(accum, C): D = relu(relu(accum) * C) + accum; return D"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C=C, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(A @ B, C)
torch.testing.assert_close(D, reference.to(D.dtype))
#########################################################
# Test unary op inside
#########################################################
def epi(accum):
D = torch.relu(torch.sigmoid(torch.relu(accum))) + accum
return D
epi_str = "def epi(accum): D = relu(sigmoid(relu(accum))) + accum; return D"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(A @ B)
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_matmul_input_as_aux():
M, N, K, L = 1024, 1024, 1024, 2
ab_dtype = torch.float16
c_dtype = torch.float32
d_dtype = torch.float16
accumulator_type = torch.float16
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
#########################################################
# Test binary op inside
#########################################################
def epi(accum, C, A):
D = torch.sigmoid(torch.relu(accum) * C) + A
return D
epi_str = "def epi(accum, C, A): D = sigmoid(relu(accum) * C) + A; return D"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C=C, A=A, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(A @ B, C, A)
torch.testing.assert_close(D, reference.to(D.dtype))
#########################################################
# Test unary op inside
#########################################################
def epi(accum, A):
D = torch.tanh(torch.relu(accum)) + A
return D
epi_str = "def epi(accum, A): D = tanh(relu(accum)) + A; return D"
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, A=A, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(A @ B, A)
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
@pytest.mark.parametrize(
"ab_dtype, c_dtype, d_dtype, accumulator_type", base_data_types()
)
@pytest.mark.parametrize(
"use_tvm_ffi",
[True, False],
)
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_alpha_beta(
M, N, K, L, ab_dtype, c_dtype, d_dtype, accumulator_type, use_tvm_ffi
):
GlobalOptions().use_tvm_ffi = use_tvm_ffi
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
def epi(accum, C, alpha, beta):
D = alpha * accum + beta * C
return D
alpha = 0.5
beta = 0.5
epi_args = cutlass_api.arguments.EpilogueArguments(
epi, C=C, alpha=alpha, beta=beta, D=D
)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernel = kernels[0]
for a, b in [(0.5, 0.5), (1.0, 0.0), (0.0, 1.0)]:
epi_args = cutlass_api.arguments.EpilogueArguments(
epi, C=C, alpha=a, beta=b, D=D
)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernel.run(args)
reference = epi(A @ B, C, a, b)
torch.testing.assert_close(D, reference.to(D.dtype))
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_big_epi():
M, N, K, L = 256, 512, 128, 2
ab_dtype = torch.float16
c_dtype = torch.float32
d_dtype = torch.bfloat16
accumulator_type = torch.float16
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
    In0, In1, In2, In3, In4, In5, In6, In7 = (
        torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype) for _ in range(8)
    )
    Out0, Out1, Out2, Out3, Out4, Out5, Out6 = (
        torch.empty((L, M, N), device="cuda", dtype=d_dtype) for _ in range(7)
    )
    D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
    sc0, sc1, sc2, sc3, sc4, sc5, sc6, sc7 = (float(i) for i in range(1, 9))
def epi(
accum,
In0,
In1,
In2,
In3,
In4,
In5,
In6,
In7,
sc0,
sc1,
sc2,
sc3,
sc4,
sc5,
sc6,
sc7,
):
Out0 = accum * sc0 + In0
Out1 = Out0 + In1 * sc1
Out2 = Out1 - In2 * sc2
Out3 = Out2 + In3 * sc3
Out4 = Out3 - In4 * sc4
Out5 = Out4 + In5 * sc5
Out6 = Out5 - In6 * sc6
D = Out6 + In7 * sc7
return Out0, Out1, Out2, Out3, Out4, Out5, Out6, D
epi_args = cutlass_api.arguments.EpilogueArguments(
epi,
In0=In0,
In1=In1,
In2=In2,
In3=In3,
In4=In4,
In5=In5,
In6=In6,
In7=In7,
Out0=Out0,
Out1=Out1,
Out2=Out2,
Out3=Out3,
Out4=Out4,
Out5=Out5,
Out6=Out6,
D=D,
sc0=sc0,
sc1=sc1,
sc2=sc2,
sc3=sc3,
sc4=sc4,
sc5=sc5,
sc6=sc6,
sc7=sc7,
)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernels[0].run(args)
reference = epi(
A @ B,
In0,
In1,
In2,
In3,
In4,
In5,
In6,
In7,
sc0,
sc1,
sc2,
sc3,
sc4,
sc5,
sc6,
sc7,
)
for out, ref in zip([Out0, Out1, Out2, Out3, Out4, Out5, Out6, D], reference):
torch.testing.assert_close(out, ref.to(out.dtype))
@pytest.mark.skipif(
not supports_sm100af(),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_fusion_not_available():
M = 256
N = 512
K = 1024
A = torch.randint(-1, 2, (M, K), device="cuda", dtype=torch.float16)
B = torch.randint(-1, 2, (K, N), device="cuda", dtype=torch.float16)
D = torch.empty((M, N), device="cuda", dtype=torch.float16)
# Non-scalar broadcasts are currently not supported
bias = torch.randint(-1, 2, (M, 1), device="cuda", dtype=torch.float16)
def epi(accum, bias):
D = accum + bias
return D
epi_args = cutlass_api.arguments.EpilogueArguments(epi, bias=bias, D=D)
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=torch.float16, epilogue=epi_args
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) == 0


@@ -0,0 +1,86 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import pytest
import cutlass_api
@pytest.mark.parametrize(
"notebook_name, supported_ccs",
[
("000_gemm.ipynb", [100, 103]),
("001_gemm_with_fused_epilogue.ipynb", [100, 103]),
("002_bring_your_own_kernel.ipynb", [80, 89, 90, 100, 103, 120, 121]),
("003_host_latency_best_practices.ipynb", [100, 103]),
],
)
def test_notebooks(notebook_name, supported_ccs):
    # CCs requiring architecture-conditional (sm_XXa) or family-conditional (sm_XXf)
    # compilation targets
    arch_conditional_ccs = [90, 100, 101, 103, 120, 121]
    family_conditional_ccs = [100, 103]
    possible_cc_strs = []
    for cc in supported_ccs:
        if cc in family_conditional_ccs:
            possible_cc_strs.extend([f"sm_{cc}f", f"sm_{cc}a"])
        elif cc in arch_conditional_ccs:
            possible_cc_strs.append(f"sm_{cc}a")
        else:
            possible_cc_strs.append(f"sm_{cc}")
cute_dsl_arch = os.environ.get("CUTE_DSL_ARCH", "")
# Add empty string to allow running without macro set
possible_cc_strs.append("")
if (
cutlass_api.utils.device_cc() not in supported_ccs
or cute_dsl_arch not in possible_cc_strs
):
# Each test should gracefully exit(0) for CCs that are not supported, but
# the nbconvert-based runner below has issues with this. Thus, we manually
# skip here.
pytest.skip(
f"This notebook requires a GPU with compute capability {supported_ccs}"
)
notebook_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
full_notebook_path = os.path.join(notebook_dir, notebook_name)
import nbconvert
import nbformat
with open(full_notebook_path, "r") as file:
notebook = nbformat.read(file, as_version=4)
ep = nbconvert.preprocessors.ExecutePreprocessor(timeout=600, kernel_name="python3")
# Execute the notebook. This call will error out on any assertions or errors in the
# notebook itself. Allow these to propagate up so the test will fail on notebook failure.
ep.preprocess(notebook, {"metadata": {"path": notebook_dir}})


@@ -0,0 +1,178 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import cuda.bindings.driver as cuda
import pytest
import torch
from torch.cuda import current_stream
import cutlass_api
from cutlass_api.config import GlobalOptions
from cutlass_api.utils import is_device_cc_supported
def benchmark(label, code, n_iterations):
    """Time `code` on the GPU and return the average runtime in milliseconds."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # Warm up so that JIT compilation and caching effects are excluded from timing
    for _ in range(n_iterations):
        code()
    start.record()
    for _ in range(n_iterations):
        code()
    end.record()
    torch.cuda.synchronize()
    avg_time = start.elapsed_time(end) / n_iterations
    print(f"[{label:20}] avg of {n_iterations} iterations: {avg_time:1.3e} ms")
    return avg_time
@pytest.mark.parametrize(
"M, N, K",
[
(256, 512, 1024),
],
)
@pytest.mark.parametrize(
"ab_dtype, c_dtype, accumulator_type",
[
(torch.float16, torch.float16, torch.float16),
],
)
@pytest.mark.parametrize(
"n_iterations",
[
50,
],
)
@pytest.mark.skipif(
not is_device_cc_supported({100})
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
)
def test_gemm_sm100(
M: int,
N: int,
K: int,
ab_dtype: torch.dtype,
c_dtype: torch.dtype,
accumulator_type: torch.dtype,
n_iterations: int,
):
print()
    A = torch.randint(-1, 2, (M, K), device="cuda", dtype=ab_dtype)
    B = torch.randint(-1, 2, (K, N), device="cuda", dtype=ab_dtype)
    D = torch.empty((M, N), device="cuda", dtype=c_dtype)
GlobalOptions().use_tvm_ffi = True
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type
)
kernels = cutlass_api.get_kernels(args, cc=100)
assert len(kernels) > 0
kernel = kernels[0]
    # Compile & run the kernel with TVM-FFI.
assert kernel.supports(args)
compiled_artifact_with_tvm_ffi = kernel.compile(args)
kernel.run(
args,
compiled_artifact=compiled_artifact_with_tvm_ffi,
stream=current_stream(),
assume_supported_args=True,
)
reference = A @ B
assert torch.allclose(D, reference.to(D.dtype)), "Refcheck failed!"
# Also works with CUDA graphs
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
kernel.run(
args,
compiled_artifact=compiled_artifact_with_tvm_ffi,
stream=current_stream(),
assume_supported_args=True,
)
D.zero_()
g.replay()
torch.cuda.synchronize()
assert torch.allclose(D, reference.to(D.dtype)), "Refcheck failed!"
    # Create args & run with and without TVM-FFI to compare host-side overhead.
args = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type
)
compiled_artifact_with_tvm_ffi = kernel.compile(args)
# Run with TVM-FFI and time it
avg_time_with_tvm_ffi = benchmark(
"Run with TVM-FFI",
lambda: kernel.run(
args,
compiled_artifact=compiled_artifact_with_tvm_ffi,
stream=current_stream(),
assume_supported_args=True,
),
n_iterations,
)
# Run without TVM-FFI and time it
GlobalOptions().use_tvm_ffi = False
args_without_tvm_ffi = cutlass_api.arguments.GemmArguments(
A=A, B=B, out=D, accumulator_type=accumulator_type
)
compiled_artifact_without_tvm_ffi = kernel.compile(args_without_tvm_ffi)
stream = cuda.CUstream(current_stream().cuda_stream)
avg_time_without_tvm_ffi = benchmark(
"Run without TVM-FFI",
lambda: kernel.run(
args_without_tvm_ffi,
compiled_artifact=compiled_artifact_without_tvm_ffi,
stream=stream,
assume_supported_args=True,
),
n_iterations,
)
speedup = avg_time_without_tvm_ffi / avg_time_with_tvm_ffi
print(f"Speedup with TVM-FFI: {speedup:.3f}x")


@@ -0,0 +1,95 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
import torch
import cutlass_api
pytestmark = pytest.mark.arch("80")
def test_basic_keywords():
def epi(accum, C, alpha, beta):
D = (alpha * accum) + (beta * C)
return D
cutlass_api.arguments.EpilogueArguments(
epilogue_fn=epi,
D=torch.randn(10, 10),
C=torch.randn(10, 10),
alpha=1.0,
beta=1.0,
)
def test_missing_keywords():
epifn = """
def epi(accum, C, alpha, beta):
D = (alpha * accum) + (beta * C)
F = relu(D)
return D, F
"""
    # Output F is returned by the epilogue but missing from the keyword arguments
    with pytest.raises(ValueError, match="F"):
        cutlass_api.arguments.EpilogueArguments(
            epilogue_fn=epifn,
            D=torch.randn(10, 10),
            C=torch.randn(10, 10),
            alpha=1.0,
            beta=1.0,
        )
def test_extra_keywords():
epifn = """
def epi(accum, C, alpha, beta):
D = (alpha * accum) + (beta * C)
F = relu(D)
return D, F
"""
    # gamma is not an input or output of the epilogue and must be rejected
    with pytest.raises(ValueError, match="gamma"):
        cutlass_api.arguments.EpilogueArguments(
            epilogue_fn=epifn,
            D=torch.randn(10, 10),
            C=torch.randn(10, 10),
            alpha=1.0,
            beta=1.0,
            F=torch.randn(10, 10),
            gamma=3.0,
        )